perf: Speed up tree transforms during sql compile #1071

Merged: 2 commits, Oct 15, 2024
19 changes: 16 additions & 3 deletions bigframes/core/compile/compiler.py
@@ -36,6 +36,7 @@
 import bigframes.core.identifiers as ids
 import bigframes.core.nodes as nodes
 import bigframes.core.ordering as bf_ordering
+import bigframes.core.rewrite as rewrites

 if typing.TYPE_CHECKING:
     import bigframes.core
@@ -48,20 +49,32 @@ class Compiler:
     # In unstrict mode, ordering from ReadTable or after joins may be ambiguous to improve query performance.
     strict: bool = True
     scalar_op_compiler = compile_scalar.ScalarOpCompiler()
+    enable_pruning: bool = False
+
+    def _preprocess(self, node: nodes.BigFrameNode):
+        if self.enable_pruning:
+            used_fields = frozenset(field.id for field in node.fields)
+            node = node.prune(used_fields)
+        node = functools.cache(rewrites.replace_slice_ops)(node)
+        return node

     def compile_ordered_ir(self, node: nodes.BigFrameNode) -> compiled.OrderedIR:
-        ir = typing.cast(compiled.OrderedIR, self.compile_node(node, True))
+        ir = typing.cast(
+            compiled.OrderedIR, self.compile_node(self._preprocess(node), True)
+        )
         if self.strict:
             assert ir.has_total_order
         return ir

     def compile_unordered_ir(self, node: nodes.BigFrameNode) -> compiled.UnorderedIR:
-        return typing.cast(compiled.UnorderedIR, self.compile_node(node, False))
+        return typing.cast(
+            compiled.UnorderedIR, self.compile_node(self._preprocess(node), False)
+        )

     def compile_peak_sql(
         self, node: nodes.BigFrameNode, n_rows: int
     ) -> typing.Optional[str]:
-        return self.compile_unordered_ir(node).peek_sql(n_rows)
+        return self.compile_unordered_ir(self._preprocess(node)).peek_sql(n_rows)

     # TODO: Remove cache when schema no longer requires compilation to derive schema (and therefor only compiles for execution)
     @functools.lru_cache(maxsize=5000)
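The new _preprocess hook pays off when rewrite results can be shared between compilations, which relies on plan nodes being immutable and hashable (BigFrameNode subclasses are frozen dataclasses). Below is a minimal sketch of the memoized-rewrite idea, with a hypothetical Node type and rewrite body rather than the real bigframes API, and with the cache hoisted to the function definition so hits persist across calls:

import dataclasses
import functools


@dataclasses.dataclass(frozen=True)
class Node:
    name: str
    children: tuple["Node", ...] = ()


@functools.cache  # keyed by node hash/equality, so an equal tree skips the rewrite
def replace_slice_ops(node: Node) -> Node:
    # Stand-in rewrite: recurse and rebuild only when a child actually changed.
    new_children = tuple(replace_slice_ops(child) for child in node.children)
    if new_children == node.children:
        return node  # same identity-reuse trick as transform_children in this PR
    return dataclasses.replace(node, children=new_children)


plan = Node("limit", (Node("scan"),))
assert replace_slice_ops(plan) is replace_slice_ops(plan)  # second call is a cache hit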
24 changes: 20 additions & 4 deletions bigframes/core/nodes.py
@@ -263,7 +263,11 @@ def explicitly_ordered(self) -> bool:
     def transform_children(
         self, t: Callable[[BigFrameNode], BigFrameNode]
     ) -> BigFrameNode:
-        return replace(self, child=t(self.child))
+        transformed = replace(self, child=t(self.child))
+        if self == transformed:
+            # reusing existing object speeds up eq, and saves a small amount of memory
+            return self
+        return transformed

     @property
     def order_ambiguous(self) -> bool:
@@ -350,9 +354,13 @@ def joins(self) -> bool:
     def transform_children(
         self, t: Callable[[BigFrameNode], BigFrameNode]
     ) -> BigFrameNode:
-        return replace(
+        transformed = replace(
             self, left_child=t(self.left_child), right_child=t(self.right_child)
         )
+        if self == transformed:
+            # reusing existing object speeds up eq, and saves a small amount of memory
+            return self
+        return transformed

     @property
     def defines_namespace(self) -> bool:
@@ -407,7 +415,11 @@ def variables_introduced(self) -> int:
     def transform_children(
         self, t: Callable[[BigFrameNode], BigFrameNode]
     ) -> BigFrameNode:
-        return replace(self, children=tuple(t(child) for child in self.children))
+        transformed = replace(self, children=tuple(t(child) for child in self.children))
+        if self == transformed:
+            # reusing existing object speeds up eq, and saves a small amount of memory
+            return self
+        return transformed

     def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
         # TODO: Make concat prunable, probably by redefining
@@ -451,7 +463,11 @@ def variables_introduced(self) -> int:
     def transform_children(
         self, t: Callable[[BigFrameNode], BigFrameNode]
     ) -> BigFrameNode:
-        return replace(self, start=t(self.start), end=t(self.end))
+        transformed = replace(self, start=t(self.start), end=t(self.end))
+        if self == transformed:
+            # reusing existing object speeds up eq, and saves a small amount of memory
+            return self
+        return transformed

     def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
         # TODO: Make FromRangeNode prunable (or convert to other node types)
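The change repeated across these transform_children implementations is a standard persistent-tree optimization: when a transform turns out to be a no-op, returning the original object instead of an equal copy lets later equality checks short-circuit on identity (CPython's container comparisons take an `is` fast path) and lets hash-based caches land on the same entry. A minimal sketch with a hypothetical Tree node, not the real BigFrameNode:

import dataclasses
from typing import Callable, Optional


@dataclasses.dataclass(frozen=True)
class Tree:
    label: str
    child: Optional["Tree"] = None

    def transform_children(self, t: Callable[["Tree"], "Tree"]) -> "Tree":
        if self.child is None:
            return self
        transformed = dataclasses.replace(self, child=t(self.child))
        # A no-op transform otherwise allocates a fresh node at every level,
        # and each later == must walk both trees field by field.
        if self == transformed:
            return self
        return transformed


deep = Tree("a", Tree("b", Tree("c")))
unchanged = deep.transform_children(lambda c: c.transform_children(lambda g: g))
assert unchanged is deep  # later equality checks short-circuit on identity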
3 changes: 3 additions & 0 deletions bigframes/core/tree_properties.py
@@ -113,6 +113,9 @@ def _node_counts_inner(

     node_counts = _node_counts_inner(root)

+    if len(node_counts) == 0:
+        raise ValueError("node counts should be non-zero")
+
     return max(
         node_counts.keys(),
         key=lambda node: heuristic(
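The guard makes the invariant explicit: without it, an empty counts dict would surface later as a generic error from max() with no hint about which invariant broke. An illustrative sketch of the failure mode (hypothetical names, not the real helper):

node_counts: dict[str, int] = {}
try:
    max(node_counts.keys(), key=lambda node: node_counts[node])
except ValueError as err:
    print(err)  # e.g. "max() arg is an empty sequence" (wording varies by Python version)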
34 changes: 13 additions & 21 deletions bigframes/session/executor.py
@@ -45,7 +45,6 @@
 import bigframes.core.identifiers
 import bigframes.core.nodes as nodes
 import bigframes.core.ordering as order
-import bigframes.core.rewrite as rewrites
 import bigframes.core.schema
 import bigframes.core.tree_properties as tree_properties
 import bigframes.features
import bigframes.features
Expand Down Expand Up @@ -128,7 +127,7 @@ def to_sql(
col_id_overrides = dict(col_id_overrides)
col_id_overrides[internal_offset_col] = offset_column
node = (
self._get_optimized_plan(array_value.node)
self._sub_cache_subtrees(array_value.node)
if enable_cache
else array_value.node
)
@@ -279,7 +278,7 @@ def peek(
         """
         A 'peek' efficiently accesses a small number of rows in the dataframe.
         """
-        plan = self._get_optimized_plan(array_value.node)
+        plan = self._sub_cache_subtrees(array_value.node)
         if not tree_properties.can_fast_peek(plan):
             warnings.warn("Peeking this value cannot be done efficiently.")

@@ -314,15 +313,15 @@ def head(
             # No user-provided ordering, so just get any N rows, its faster!
             return self.peek(array_value, n_rows)

-        plan = self._get_optimized_plan(array_value.node)
+        plan = self._sub_cache_subtrees(array_value.node)
         if not tree_properties.can_fast_head(plan):
             # If can't get head fast, we are going to need to execute the whole query
             # Will want to do this in a way such that the result is reusable, but the first
             # N values can be easily extracted.
             # This currently requires clustering on offsets.
             self._cache_with_offsets(array_value)
             # Get a new optimized plan after caching
-            plan = self._get_optimized_plan(array_value.node)
+            plan = self._sub_cache_subtrees(array_value.node)
         assert tree_properties.can_fast_head(plan)

         head_plan = generate_head_plan(plan, n_rows)
@@ -347,7 +346,7 @@ def get_row_count(self, array_value: bigframes.core.ArrayValue) -> int:
         if count is not None:
             return count
         else:
-            row_count_plan = self._get_optimized_plan(
+            row_count_plan = self._sub_cache_subtrees(
                 generate_row_count_plan(array_value.node)
             )
             sql = self.compiler.compile_unordered(row_count_plan)
@@ -359,7 +358,7 @@ def _local_get_row_count(
     ) -> Optional[int]:
         # optimized plan has cache materializations which will have row count metadata
         # that is more likely to be usable than original leaf nodes.
-        plan = self._get_optimized_plan(array_value.node)
+        plan = self._sub_cache_subtrees(array_value.node)
         return tree_properties.row_count(plan)

     # Helpers
@@ -424,21 +423,14 @@ def _wait_on_job(
         self.metrics.count_job_stats(query_job)
         return results_iterator

-    def _get_optimized_plan(self, node: nodes.BigFrameNode) -> nodes.BigFrameNode:
+    def _sub_cache_subtrees(self, node: nodes.BigFrameNode) -> nodes.BigFrameNode:
         """
         Takes the original expression tree and applies optimizations to accelerate execution.

         At present, the only optimization is to replace subtress with cached previous materializations.
         """
-        # Apply any rewrites *after* applying cache, as cache is sensitive to exact tree structure
-        optimized_plan = tree_properties.replace_nodes(
-            node, (dict(self._cached_executions))
-        )
-        if ENABLE_PRUNING:
-            used_fields = frozenset(field.id for field in optimized_plan.fields)
-            optimized_plan = optimized_plan.prune(used_fields)
-            optimized_plan = rewrites.replace_slice_ops(optimized_plan)
-        return optimized_plan
+        return tree_properties.replace_nodes(node, (dict(self._cached_executions)))

     def _is_trivially_executable(self, array_value: bigframes.core.ArrayValue):
         """
@@ -448,7 +440,7 @@ def _is_trivially_executable(self, array_value: bigframes.core.ArrayValue):
         # Once rewriting is available, will want to rewrite before
         # evaluating execution cost.
         return tree_properties.is_trivially_executable(
-            self._get_optimized_plan(array_value.node)
+            self._sub_cache_subtrees(array_value.node)
         )

     def _cache_with_cluster_cols(
@@ -457,7 +449,7 @@ def _cache_with_cluster_cols(
         """Executes the query and uses the resulting table to rewrite future executions."""

         sql, schema, ordering_info = self.compiler.compile_raw(
-            self._get_optimized_plan(array_value.node)
+            self._sub_cache_subtrees(array_value.node)
         )
         tmp_table = self._sql_as_cached_temp_table(
             sql,
@@ -474,7 +466,7 @@ def _cache_with_offsets(self, array_value: bigframes.core.ArrayValue):
         """Executes the query and uses the resulting table to rewrite future executions."""
         offset_column = bigframes.core.guid.generate_guid("bigframes_offsets")
         w_offsets, offset_column = array_value.promote_offsets()
-        sql = self.compiler.compile_unordered(self._get_optimized_plan(w_offsets.node))
+        sql = self.compiler.compile_unordered(self._sub_cache_subtrees(w_offsets.node))

         tmp_table = self._sql_as_cached_temp_table(
             sql,
@@ -510,7 +502,7 @@ def _simplify_with_caching(self, array_value: bigframes.core.ArrayValue):
         """Attempts to handle the complexity by caching duplicated subtrees and breaking the query into pieces."""
         # Apply existing caching first
         for _ in range(MAX_SUBTREE_FACTORINGS):
-            node_with_cache = self._get_optimized_plan(array_value.node)
+            node_with_cache = self._sub_cache_subtrees(array_value.node)
             if node_with_cache.planning_complexity < QUERY_COMPLEXITY_LIMIT:
                 return

@@ -567,7 +559,7 @@ def _validate_result_schema(
     ):
         actual_schema = tuple(bq_schema)
         ibis_schema = bigframes.core.compile.test_only_ibis_inferred_schema(
-            self._get_optimized_plan(array_value.node)
+            self._sub_cache_subtrees(array_value.node)
         )
         internal_schema = array_value.schema
         if not bigframes.features.PANDAS_VERSIONS.is_arrow_list_dtype_usable:
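With pruning and slice rewrites relocated into the compiler, the renamed _sub_cache_subtrees helper is left with a single job: swapping subtrees for their cached materializations. A rough sketch of that substitution under illustrative types — the real mapping lives in self._cached_executions, and tree_properties.replace_nodes may differ in detail:

import dataclasses
from typing import Callable


@dataclasses.dataclass(frozen=True)
class Node:
    name: str
    children: tuple["Node", ...] = ()

    def transform_children(self, t: Callable[["Node"], "Node"]) -> "Node":
        transformed = dataclasses.replace(
            self, children=tuple(t(c) for c in self.children)
        )
        return self if self == transformed else transformed


def replace_nodes(node: Node, cache: dict[Node, Node]) -> Node:
    # Check the whole subtree first; only recurse when there is no cached hit.
    if node in cache:
        return cache[node]
    return node.transform_children(lambda child: replace_nodes(child, cache))


expensive = Node("join", (Node("scan_a"), Node("scan_b")))
plan = Node("filter", (expensive,))
cached = {expensive: Node("cached_table")}
assert replace_nodes(plan, cached) == Node("filter", (Node("cached_table"),))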