Skip to content

Commit e5a8866

Browse files
feat: Hybrid engine local join support
1 parent 05cb7d0 commit e5a8866

File tree

3 files changed

+101
-4
lines changed

3 files changed

+101
-4
lines changed

‎bigframes/core/compile/polars/compiler.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -487,8 +487,14 @@ def compile_offsets(self, node: nodes.PromoteOffsetsNode):
487487
def compile_join(self, node: nodes.JoinNode):
488488
left = self.compile_node(node.left_child)
489489
right = self.compile_node(node.right_child)
490-
left_on = [l_name.id.sql for l_name, _ in node.conditions]
491-
right_on = [r_name.id.sql for _, r_name in node.conditions]
490+
491+
left_on = []
492+
right_on = []
493+
for left_ex, right_ex in node.conditions:
494+
left_ex, right_ex = lowering._coerce_comparables(left_ex, right_ex)
495+
left_on.append(self.expr_compiler.compile_expression(left_ex))
496+
right_on.append(self.expr_compiler.compile_expression(right_ex))
497+
492498
if node.type == "right":
493499
return self._ordered_join(
494500
right, left, "left", right_on, left_on, node.joins_nulls
@@ -502,8 +508,8 @@ def _ordered_join(
502508
left_frame: pl.LazyFrame,
503509
right_frame: pl.LazyFrame,
504510
how: Literal["inner", "outer", "left", "cross"],
505-
left_on: Sequence[str],
506-
right_on: Sequence[str],
511+
left_on: Sequence[pl.Expr],
512+
right_on: Sequence[pl.Expr],
507513
join_nulls: bool,
508514
):
509515
if how == "right":

‎bigframes/session/polars_executor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
nodes.AggregateNode,
3838
nodes.FilterNode,
3939
nodes.ConcatNode,
40+
nodes.JoinNode,
4041
)
4142

4243
_COMPATIBLE_SCALAR_OPS = (
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Literal
16+
17+
import pytest
18+
19+
from bigframes import operations as ops
20+
from bigframes.core import array_value, expression, ordering
21+
from bigframes.session import polars_executor
22+
from bigframes.testing.engine_utils import assert_equivalence_execution
23+
24+
pytest.importorskip("polars")
25+
26+
# Polars is used as the reference because it's fast and local. In general, though, prefer the gbq engine where they disagree.
27+
REFERENCE_ENGINE = polars_executor.PolarsExecutor()
28+
29+
30+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
@pytest.mark.parametrize("join_type", ["left", "inner", "right", "outer"])
def test_engines_join_on_key(
    scalars_array_value: array_value.ArrayValue,
    engine,
    join_type: Literal["inner", "outer", "left", "right"],
):
    """Self-join the scalars fixture on ``int64_col`` for each join type.

    The joined plan is passed to ``assert_equivalence_execution`` together
    with the reference engine and the parametrized engine.
    """
    # Same column on both sides of the single join condition.
    key_pairs = (("int64_col", "int64_col"),)
    joined, _ = scalars_array_value.relational_join(
        scalars_array_value,
        conditions=key_pairs,
        type=join_type,
    )

    assert_equivalence_execution(joined.node, REFERENCE_ENGINE, engine)
42+
43+
44+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
@pytest.mark.parametrize("join_type", ["left", "inner", "right", "outer"])
def test_engines_join_on_coerced_key(
    scalars_array_value: array_value.ArrayValue,
    engine,
    join_type: Literal["inner", "outer", "left", "right"],
):
    """Self-join the scalars fixture, pairing ``int64_col`` with ``float64_col``.

    The joined plan is passed to ``assert_equivalence_execution`` together
    with the reference engine and the parametrized engine.
    """
    # Left key and right key are different columns of the same input.
    key_pairs = (("int64_col", "float64_col"),)
    joined, _ = scalars_array_value.relational_join(
        scalars_array_value,
        conditions=key_pairs,
        type=join_type,
    )

    assert_equivalence_execution(joined.node, REFERENCE_ENGINE, engine)
56+
57+
58+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
@pytest.mark.parametrize("join_type", ["left", "inner", "right", "outer"])
def test_engines_join_multi_key(
    scalars_array_value: array_value.ArrayValue,
    engine,
    join_type: Literal["inner", "outer", "left", "right"],
):
    """Join two derived inputs on two computed key columns.

    The left input is ordered by ``float64_col`` and then extended with
    ``int64_col % 2`` and the inverse of ``bool_col``; the right input is
    extended with ``int64_col % 3`` and a constant ``True``. The computed
    columns are paired positionally to form the join conditions, and the
    joined plan is passed to ``assert_equivalence_execution`` together with
    the reference engine and the parametrized engine.
    """
    l_input = scalars_array_value.order_by([ordering.ascending_over("float64_col")])
    # Derive the left keys from the *ordered* input. Calling compute_values on
    # scalars_array_value here would overwrite l_input and silently drop the
    # order_by above.
    l_input, l_join_cols = l_input.compute_values(
        [
            ops.mod_op.as_expr("int64_col", expression.const(2)),
            ops.invert_op.as_expr("bool_col"),
        ]
    )
    r_input, r_join_cols = scalars_array_value.compute_values(
        [ops.mod_op.as_expr("int64_col", expression.const(3)), expression.const(True)]
    )

    # Pair the i-th left key with the i-th right key.
    conditions = tuple(zip(l_join_cols, r_join_cols))

    result, _ = l_input.relational_join(r_input, conditions=conditions, type=join_type)

    assert_equivalence_execution(result.node, REFERENCE_ENGINE, engine)
81+
82+
83+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
def test_engines_cross_join(
    scalars_array_value: array_value.ArrayValue,
    engine,
):
    """Cross-join the scalars fixture with itself.

    The joined plan is passed to ``assert_equivalence_execution`` together
    with the reference engine and the parametrized engine.
    """
    # No join conditions are supplied for the "cross" join type.
    crossed, _ = scalars_array_value.relational_join(
        scalars_array_value,
        type="cross",
    )

    assert_equivalence_execution(crossed.node, REFERENCE_ENGINE, engine)

0 commit comments

Comments
 (0)