Skip to content

Commit c0cefd3

Browse files
feat: Add isin local execution to hybrid engine (#1915)
1 parent c57f04e commit c0cefd3

File tree

4 files changed

+54
-0
lines changed

4 files changed

+54
-0
lines changed

‎bigframes/core/compile/polars/compiler.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -513,6 +513,30 @@ def compile_join(self, node: nodes.JoinNode):
513513
left, right, node.type, left_on, right_on, node.joins_nulls
514514
)
515515

516+
@compile_node.register
def compile_isin(self, node: nodes.InNode):
    """Compile an ``InNode`` as a left join that produces a boolean indicator.

    The result keeps every column of the compiled left child and appends one
    column named ``node.indicator_col.sql`` that is ``True`` where the left
    key found a match among the right keys and ``False`` otherwise.

    Args:
        node: The ``InNode`` to compile; its ``left_col``/``right_col`` are
            the keys to compare and ``joins_nulls`` controls whether nulls
            match each other.

    Returns:
        The joined frame projected down to the left columns plus the
        indicator column.
    """
    left = self.compile_node(node.left_child)
    # Deduplicate the right side on its key so a left row is not repeated
    # once per duplicate right value.
    # NOTE(review): uniqueness is taken on the raw right column, before the
    # coercion applied below; if coercion can map distinct raw values to the
    # same comparable value, duplicates could reappear -- confirm.
    right = self.compile_node(node.right_child).unique(node.right_col.id.sql)
    # Every surviving right row carries True; unmatched left rows get null
    # from the left join and are turned into False further down.
    right = right.with_columns(pl.lit(True).alias(node.indicator_col.sql))

    # Bring both keys to a common comparable type before joining.
    left_ex, right_ex = lowering._coerce_comparables(node.left_col, node.right_col)

    left_pl_ex = self.expr_compiler.compile_expression(left_ex)
    right_pl_ex = self.expr_compiler.compile_expression(right_ex)

    joined = left.join(
        right,
        how="left",
        left_on=left_pl_ex,
        right_on=right_pl_ex,
        # Note: join_nulls renamed to nulls_equal for polars 1.24
        join_nulls=node.joins_nulls,  # type: ignore
        coalesce=False,
    )
    # Only the left columns pass through; `col_name` avoids shadowing the
    # builtin `id`.
    passthrough = [pl.col(col_name) for col_name in left.columns]
    indicator = pl.col(node.indicator_col.sql).fill_null(False)
    return joined.select((*passthrough, indicator))
539+
516540
def _ordered_join(
517541
self,
518542
left_frame: pl.LazyFrame,

‎bigframes/core/rewrite/schema_binding.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,16 @@ def bind_schema_to_node(
6565
node,
6666
conditions=conditions,
6767
)
68+
if isinstance(node, nodes.InNode):
69+
return dataclasses.replace(
70+
node,
71+
left_col=ex.ResolvedDerefOp.from_field(
72+
node.left_child.field_by_id[node.left_col.id]
73+
),
74+
right_col=ex.ResolvedDerefOp.from_field(
75+
node.right_child.field_by_id[node.right_col.id]
76+
),
77+
)
6878

6979
if isinstance(node, nodes.AggregateNode):
7080
aggregations = []

‎bigframes/session/polars_executor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
nodes.FilterNode,
4040
nodes.ConcatNode,
4141
nodes.JoinNode,
42+
nodes.InNode,
4243
)
4344

4445
_COMPATIBLE_SCALAR_OPS = (

‎tests/system/small/engines/test_join.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,22 @@ def test_engines_cross_join(
8888
result, _ = scalars_array_value.relational_join(scalars_array_value, type="cross")
8989

9090
assert_equivalence_execution(result.node, REFERENCE_ENGINE, engine)
91+
92+
93+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
@pytest.mark.parametrize(
    ("left_key", "right_key"),
    [
        ("int64_col", "float64_col"),
        ("float64_col", "int64_col"),
        ("int64_too", "int64_col"),
    ],
)
def test_engines_isin(
    scalars_array_value: array_value.ArrayValue, engine, left_key, right_key
):
    """Check that ``isin`` of the scalars table against itself executes
    equivalently on the engine under test and on the reference engine,
    for same-typed and mixed int/float key pairs.
    """
    # Only the resulting array value is needed; the second element is unused.
    isin_value, _unused = scalars_array_value.isin(
        scalars_array_value,
        lcol=left_key,
        rcol=right_key,
    )

    assert_equivalence_execution(isin_value.node, REFERENCE_ENGINE, engine)

0 commit comments

Comments
 (0)