chore: support timestamp subtractions #1346

Merged: 13 commits, Feb 5, 2025
4 changes: 4 additions & 0 deletions bigframes/core/compile/compiler.py
@@ -58,6 +58,7 @@ def compile_sql(
# TODO: get rid of output_ids arg
assert len(output_ids) == len(list(node.fields))
node = set_output_names(node, output_ids)
node = nodes.top_down(node, rewrites.rewrite_timedelta_ops)
if ordered:
node, limit = rewrites.pullup_limit_from_slice(node)
node = nodes.bottom_up(node, rewrites.rewrite_slice)
@@ -81,6 +82,7 @@ def compile_sql(
def compile_peek_sql(self, node: nodes.BigFrameNode, n_rows: int) -> str:
ids = [id.sql for id in node.ids]
node = nodes.bottom_up(node, rewrites.rewrite_slice)
node = nodes.top_down(node, rewrites.rewrite_timedelta_ops)
node, _ = rewrites.pull_up_order(
node, order_root=False, ordered_joins=self.strict
)
@@ -93,13 +95,15 @@ def compile_raw(
str, typing.Sequence[google.cloud.bigquery.SchemaField], bf_ordering.RowOrdering
]:
node = nodes.bottom_up(node, rewrites.rewrite_slice)
node = nodes.top_down(node, rewrites.rewrite_timedelta_ops)
node, ordering = rewrites.pull_up_order(node, ordered_joins=self.strict)
ir = self.compile_node(node)
sql = ir.to_sql()
return sql, node.schema.to_bigquery(), ordering

def _preprocess(self, node: nodes.BigFrameNode):
node = nodes.bottom_up(node, rewrites.rewrite_slice)
node = nodes.top_down(node, rewrites.rewrite_timedelta_ops)
node, _ = rewrites.pull_up_order(
node, order_root=False, ordered_joins=self.strict
)
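
A note on the new pass: every compile entry point now runs nodes.top_down(node, rewrites.rewrite_timedelta_ops) before ordering is pulled up. As a rough, hedged sketch of what a top-down pass over the node tree looks like (transform_children is a hypothetical stand-in for the real BigFrameNode child-mapping API, not an exact bigframes call):

from typing import Any, Callable

def top_down(node: Any, rewrite: Callable[[Any], Any]) -> Any:
    node = rewrite(node)  # rewrite the node itself first...
    # ...then recurse into the (possibly rewritten) node's children
    return node.transform_children(lambda child: top_down(child, rewrite))
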
2 changes: 1 addition & 1 deletion bigframes/core/compile/ibis_types.py
@@ -79,7 +79,7 @@
BIGFRAMES_TO_IBIS: Dict[bigframes.dtypes.Dtype, ibis_dtypes.DataType] = {
pandas: ibis for ibis, pandas in BIDIRECTIONAL_MAPPINGS
}
BIGFRAMES_TO_IBIS.update({bigframes.dtypes.TIMEDETLA_DTYPE: ibis_dtypes.int64})
BIGFRAMES_TO_IBIS.update({bigframes.dtypes.TIMEDELTA_DTYPE: ibis_dtypes.int64})
IBIS_TO_BIGFRAMES: Dict[ibis_dtypes.DataType, bigframes.dtypes.Dtype] = {
ibis: pandas for ibis, pandas in BIDIRECTIONAL_MAPPINGS
}
5 changes: 5 additions & 0 deletions bigframes/core/compile/scalar_op_compiler.py
@@ -737,6 +737,11 @@ def unix_millis_op_impl(x: ibis_types.TimestampValue):
return unix_millis(x)


@scalar_op_compiler.register_binary_op(ops.timestamp_diff_op)
def timestamp_diff_op_impl(x: ibis_types.TimestampValue, y: ibis_types.TimestampValue):
return x.delta(y, "microsecond")


@scalar_op_compiler.register_unary_op(ops.FloorDtOp, pass_op=True)
def floor_dt_op_impl(x: ibis_types.Value, op: ops.FloorDtOp):
supported_freqs = ["Y", "Q", "M", "W", "D", "h", "min", "s", "ms", "us", "ns"]
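
The new timestamp_diff implementation delegates to ibis's delta with a "microsecond" unit, so the subtraction comes back as an integer count of microseconds (on BigQuery this presumably renders as a TIMESTAMP_DIFF at microsecond granularity). A plain-Python sketch of the intended semantics, not bigframes code:

# Semantics sketch only: the compiled op yields the signed difference in whole
# microseconds, which the timedelta rewrite later surfaces as a duration("us") value.
import datetime

end = datetime.datetime(2025, 1, 2, tzinfo=datetime.timezone.utc)
start = datetime.datetime(2025, 1, 1, tzinfo=datetime.timezone.utc)
diff_us = round((end - start).total_seconds() * 1_000_000)  # 86_400_000_000
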
2 changes: 2 additions & 0 deletions bigframes/core/rewrite/__init__.py
@@ -15,13 +15,15 @@
from bigframes.core.rewrite.identifiers import remap_variables
from bigframes.core.rewrite.implicit_align import try_row_join
from bigframes.core.rewrite.legacy_align import legacy_join_as_projection
from bigframes.core.rewrite.operators import rewrite_timedelta_ops
from bigframes.core.rewrite.order import pull_up_order
from bigframes.core.rewrite.slices import pullup_limit_from_slice, rewrite_slice

__all__ = [
"legacy_join_as_projection",
"try_row_join",
"rewrite_slice",
"rewrite_timedelta_ops",
"pullup_limit_from_slice",
"remap_variables",
"pull_up_order",
82 changes: 82 additions & 0 deletions bigframes/core/rewrite/operators.py
@@ -0,0 +1,82 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
import functools
import typing

from bigframes import dtypes
from bigframes import operations as ops
from bigframes.core import expression as ex
from bigframes.core import nodes, schema


@dataclasses.dataclass
class _TypedExpr:
expr: ex.Expression
dtype: dtypes.Dtype


def rewrite_timedelta_ops(root: nodes.BigFrameNode) -> nodes.BigFrameNode:
"""
Rewrites expressions to properly handle timedelta values, because this type does not exist
in the SQL world.
"""
if isinstance(root, nodes.ProjectionNode):
updated_assignments = tuple(
(_rewrite_expressions(expr, root.schema).expr, column_id)
for expr, column_id in root.assignments
)
root = nodes.ProjectionNode(root.child, updated_assignments)

# TODO(b/394354614): FilterByNode and OrderNode also contain expressions. Need to update them too.
return root

Comment on lines +43 to +44

Contributor: as long as we get support for those nodes before anybody starts using this!

Contributor Author: PR soon to follow!


@functools.cache
def _rewrite_expressions(expr: ex.Expression, schema: schema.ArraySchema) -> _TypedExpr:
if isinstance(expr, ex.DerefOp):
return _TypedExpr(expr, schema.get_type(expr.id.sql))

if isinstance(expr, ex.ScalarConstantExpression):
return _TypedExpr(expr, expr.dtype)

if isinstance(expr, ex.OpExpression):
updated_inputs = tuple(
map(lambda x: _rewrite_expressions(x, schema), expr.inputs)
)
return _rewrite_op_expr(expr, updated_inputs)

Comment on lines +55 to +59

Contributor: I believe this will also need to be top-down rather than bottom-up.

Contributor Author (sycai, Feb 5, 2025): I don't think it's possible to do this top-down, because we cannot get the input types by first processing the parent node. The parent node's output type can only be decided once we have rewritten all the subtrees.

raise AssertionError(f"Unexpected expression type: {type(expr)}")


def _rewrite_op_expr(
expr: ex.OpExpression, inputs: typing.Tuple[_TypedExpr, ...]
) -> _TypedExpr:
if isinstance(expr.op, ops.SubOp):
return _rewrite_sub_op(inputs[0], inputs[1])

input_types = tuple(map(lambda x: x.dtype, inputs))
return _TypedExpr(expr, expr.op.output_type(*input_types))


def _rewrite_sub_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
result_op: ops.BinaryOp = ops.sub_op
if dtypes.is_datetime_like(left.dtype) and dtypes.is_datetime_like(right.dtype):
result_op = ops.timestamp_diff_op

return _TypedExpr(
result_op.as_expr(left.expr, right.expr),
result_op.output_type(left.dtype, right.dtype),
)
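
To make the dispatch in _rewrite_sub_op concrete, here is a standalone mimic of the rule, written with plain strings instead of the real bigframes dtype objects (purely illustrative):

# Standalone mimic of the _rewrite_sub_op dispatch rule; the dtype strings are
# illustrative stand-ins, not the actual bigframes dtypes.
DATETIME_LIKE = {"timestamp[us, tz=UTC]", "timestamp[us]"}

def pick_sub_op(left_dtype: str, right_dtype: str) -> str:
    if left_dtype in DATETIME_LIKE and right_dtype in DATETIME_LIKE:
        return "timestamp_diff_op"  # result dtype becomes duration[us] (timedelta)
    return "sub_op"  # everything else keeps ordinary subtraction

assert pick_sub_op("timestamp[us]", "timestamp[us]") == "timestamp_diff_op"
assert pick_sub_op("int64", "int64") == "sub_op"
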
2 changes: 1 addition & 1 deletion bigframes/dtypes.py
@@ -56,7 +56,7 @@
TIME_DTYPE = pd.ArrowDtype(pa.time64("us"))
DATETIME_DTYPE = pd.ArrowDtype(pa.timestamp("us"))
TIMESTAMP_DTYPE = pd.ArrowDtype(pa.timestamp("us", tz="UTC"))
TIMEDETLA_DTYPE = pd.ArrowDtype(pa.duration("us"))
TIMEDELTA_DTYPE = pd.ArrowDtype(pa.duration("us"))
NUMERIC_DTYPE = pd.ArrowDtype(pa.decimal128(38, 9))
BIGNUMERIC_DTYPE = pd.ArrowDtype(pa.decimal256(76, 38))
# No arrow equivalent
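
For reference (plain pandas, independent of bigframes), the corrected constant is a pandas ArrowDtype wrapping Arrow's microsecond duration type:

# What TIMEDELTA_DTYPE resolves to: a pandas ArrowDtype over pyarrow duration("us").
import pandas as pd
import pyarrow as pa

TIMEDELTA_DTYPE = pd.ArrowDtype(pa.duration("us"))
s = pd.Series([pd.Timedelta(minutes=90)], dtype=TIMEDELTA_DTYPE)  # dtype: duration[us][pyarrow]
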
4 changes: 4 additions & 0 deletions bigframes/operations/__init__.py
@@ -49,6 +49,7 @@
date_op,
StrftimeOp,
time_op,
timestamp_diff_op,
ToDatetimeOp,
ToTimestampOp,
UnixMicros,
@@ -125,6 +126,7 @@
sinh_op,
sqrt_op,
sub_op,
SubOp,
tan_op,
tanh_op,
unsafe_pow_op,
@@ -246,6 +248,7 @@
# Datetime ops
"date_op",
"time_op",
"timestamp_diff_op",
"ToDatetimeOp",
"ToTimestampOp",
"StrftimeOp",
@@ -283,6 +286,7 @@
"sinh_op",
"sqrt_op",
"sub_op",
"SubOp",
"tan_op",
"tanh_op",
"unsafe_pow_op",
19 changes: 19 additions & 0 deletions bigframes/operations/datetime_ops.py
@@ -107,3 +107,22 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
if input_types[0] is not dtypes.TIMESTAMP_DTYPE:
raise TypeError("expected timestamp input")
return dtypes.INT_DTYPE


@dataclasses.dataclass(frozen=True)
class TimestampDiff(base_ops.BinaryOp):
name: typing.ClassVar[str] = "timestamp_diff"

def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
if input_types[0] is not input_types[1]:
raise TypeError(
f"two inputs have different types. left: {input_types[0]}, right: {input_types[1]}"
)

if not dtypes.is_datetime_like(input_types[0]):
raise TypeError("expected timestamp input")

return dtypes.TIMEDELTA_DTYPE


timestamp_diff_op = TimestampDiff()
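
A hedged sanity check of the new op's type rule, assuming a bigframes build that includes this PR:

# Exercises TimestampDiff.output_type as defined above (requires this PR).
from bigframes import dtypes
from bigframes.operations import timestamp_diff_op

assert timestamp_diff_op.output_type(dtypes.TIMESTAMP_DTYPE, dtypes.TIMESTAMP_DTYPE) == dtypes.TIMEDELTA_DTYPE
# Mixing tz-aware TIMESTAMP with tz-naive DATETIME trips the "different types" check:
# timestamp_diff_op.output_type(dtypes.TIMESTAMP_DTYPE, dtypes.DATETIME_DTYPE)  # raises TypeError
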
5 changes: 4 additions & 1 deletion bigframes/operations/numeric_ops.py
@@ -141,7 +141,10 @@ def output_type(self, *input_types):
):
# Numeric subtraction
return dtypes.coerce_to_common(left_type, right_type)
# TODO: Add temporal addition once delta types supported

if dtypes.is_datetime_like(left_type) and dtypes.is_datetime_like(right_type):
return dtypes.TIMEDELTA_DTYPE

raise TypeError(f"Cannot subtract dtypes {left_type} and {right_type}")


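
Correspondingly, SubOp.output_type now resolves to the timedelta dtype for datetime-like operands while numeric subtraction is unchanged. A hedged check, again assuming this PR's build:

from bigframes import dtypes
from bigframes.operations import sub_op

assert sub_op.output_type(dtypes.DATETIME_DTYPE, dtypes.DATETIME_DTYPE) == dtypes.TIMEDELTA_DTYPE
assert sub_op.output_type(dtypes.INT_DTYPE, dtypes.INT_DTYPE) == dtypes.INT_DTYPE  # numeric path unchanged
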
2 changes: 1 addition & 1 deletion bigframes/operations/timedelta_ops.py
@@ -28,4 +28,4 @@ class ToTimedeltaOp(base_ops.UnaryOp):
def output_type(self, *input_types):
if input_types[0] is not dtypes.INT_DTYPE:
raise TypeError("expected integer input")
return dtypes.TIMEDETLA_DTYPE
return dtypes.TIMEDELTA_DTYPE
4 changes: 2 additions & 2 deletions bigframes/series.py
@@ -805,10 +805,10 @@ def __rsub__(self, other: float | int | Series) -> Series:

__rsub__.__doc__ = inspect.getdoc(vendored_pandas_series.Series.__rsub__)

def sub(self, other: float | int | Series) -> Series:
def sub(self, other) -> Series:
return self._apply_binary_op(other, ops.sub_op)

def rsub(self, other: float | int | Series) -> Series:
def rsub(self, other) -> Series:
return self._apply_binary_op(other, ops.sub_op, reverse=True)

subtract = sub
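
At the user level, the relaxed sub/rsub signatures are what let timestamp operands flow through to the rewritten operator. A hedged usage sketch (project, dataset, table, and column names are made up):

import pandas as pd
import bigframes.pandas as bpd

df = bpd.read_gbq("my-project.my_dataset.events")             # hypothetical table
anchor = pd.Timestamp("2025-01-01 00:00:01", tz="America/New_York")

elapsed = df["timestamp_col"] - anchor        # timestamp - literal -> timedelta Series
reverse = df["timestamp_col"].rsub(anchor)    # literal - timestamp, via rsub/__rsub__
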
4 changes: 2 additions & 2 deletions bigframes/session/loader.py
@@ -177,7 +177,7 @@ def read_pandas_load_job(

destination_table = self._bqclient.get_table(load_table_destination)
col_type_overrides: typing.Dict[str, bigframes.dtypes.Dtype] = {
col: bigframes.dtypes.TIMEDETLA_DTYPE
col: bigframes.dtypes.TIMEDELTA_DTYPE
for col in df_and_labels.timedelta_cols
}
array_value = core.ArrayValue.from_table(
@@ -236,7 +236,7 @@ def read_pandas_streaming(
)

col_type_overrides: typing.Dict[str, bigframes.dtypes.Dtype] = {
col: bigframes.dtypes.TIMEDETLA_DTYPE
col: bigframes.dtypes.TIMEDELTA_DTYPE
for col in df_and_labels.timedelta_cols
}
array_value = (
81 changes: 81 additions & 0 deletions tests/system/small/operations/test_datetimes.py
@@ -14,6 +14,8 @@

import datetime

import numpy
from pandas import testing
import pandas as pd
import pytest

@@ -367,3 +369,82 @@ def test_dt_clip_coerce_str_timestamp(scalars_dfs):
pd_result,
bf_result,
)


@pytest.mark.parametrize("column", ["timestamp_col", "datetime_col"])
def test_timestamp_diff_two_series(scalars_dfs, column):
bf_df, pd_df = scalars_dfs
bf_series = bf_df[column]
pd_series = pd_df[column]

actual_result = (bf_series - bf_series).to_pandas()

expected_result = pd_series - pd_series
assert_series_equal(actual_result, expected_result)


@pytest.mark.parametrize("column", ["timestamp_col", "datetime_col"])
def test_timestamp_diff_two_series_with_numpy_ops(scalars_dfs, column):
bf_df, pd_df = scalars_dfs
bf_series = bf_df[column]
pd_series = pd_df[column]

actual_result = numpy.subtract(bf_series, bf_series).to_pandas()

expected_result = numpy.subtract(pd_series, pd_series)
assert_series_equal(actual_result, expected_result)


def test_timestamp_diff_two_dataframes(scalars_dfs):
columns = ["timestamp_col", "datetime_col"]
bf_df, pd_df = scalars_dfs
bf_df = bf_df[columns]
pd_df = pd_df[columns]

actual_result = (bf_df - bf_df).to_pandas()

expected_result = pd_df - pd_df
testing.assert_frame_equal(actual_result, expected_result)


def test_timestamp_diff_two_series_with_different_types_raise_error(scalars_dfs):
df, _ = scalars_dfs

with pytest.raises(TypeError):
(df["timestamp_col"] - df["datetime_col"]).to_pandas()


@pytest.mark.parametrize(
("column", "value"),
[
("timestamp_col", pd.Timestamp("2025-01-01 00:00:01", tz="America/New_York")),
("datetime_col", datetime.datetime(2025, 1, 1, 0, 0, 1)),
],
)
def test_timestamp_diff_series_sub_literal(scalars_dfs, column, value):
bf_df, pd_df = scalars_dfs
bf_series = bf_df[column]
pd_series = pd_df[column]

actual_result = (bf_series - value).to_pandas()

expected_result = pd_series - value
assert_series_equal(actual_result, expected_result)


@pytest.mark.parametrize(
("column", "value"),
[
("timestamp_col", pd.Timestamp("2025-01-01 00:00:01", tz="America/New_York")),
("datetime_col", datetime.datetime(2025, 1, 1, 0, 0, 1)),
],
)
def test_timestamp_diff_literal_sub_series(scalars_dfs, column, value):
bf_df, pd_df = scalars_dfs
bf_series = bf_df[column]
pd_series = pd_df[column]

actual_result = (value - bf_series).to_pandas()

expected_result = value - pd_series
assert_series_equal(actual_result, expected_result)