Skip to content

perf: Reduce schema tracking overhead #1056

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 8, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 48 additions & 37 deletions bigframes/core/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,10 @@ def roots(self) -> typing.Set[BigFrameNode]:
)
return set(roots)

# TODO: For deep trees, this can create a lot of overhead, maybe use zero-copy persistent datastructure?
# TODO: Store some local data lazily for select, aggregate nodes.
@property
@abc.abstractmethod
def fields(self) -> Tuple[Field, ...]:
def fields(self) -> Iterable[Field]:
...

@property
Expand Down Expand Up @@ -252,8 +252,8 @@ class UnaryNode(BigFrameNode):
def child_nodes(self) -> typing.Sequence[BigFrameNode]:
return (self.child,)

@functools.cached_property
def fields(self) -> Tuple[Field, ...]:
@property
def fields(self) -> Iterable[Field]:
return self.child.fields

@property
Expand Down Expand Up @@ -303,9 +303,9 @@ def explicitly_ordered(self) -> bool:
# Do not consider user pre-join ordering intent - they need to re-order post-join in unordered mode.
return False

@functools.cached_property
def fields(self) -> Tuple[Field, ...]:
return tuple(itertools.chain(self.left_child.fields, self.right_child.fields))
@property
def fields(self) -> Iterable[Field]:
return itertools.chain(self.left_child.fields, self.right_child.fields)

@functools.cached_property
def variables_introduced(self) -> int:
Expand Down Expand Up @@ -360,10 +360,10 @@ def explicitly_ordered(self) -> bool:
# Consider concat as an ordered operation (even though input frames may not be ordered)
return True

@functools.cached_property
def fields(self) -> Tuple[Field, ...]:
@property
def fields(self) -> Iterable[Field]:
# TODO: Output names should probably be aligned beforehand or be part of concat definition
return tuple(
return (
Field(bfet_ids.ColumnId(f"column_{i}"), field.dtype)
for i, field in enumerate(self.children[0].fields)
)
Expand Down Expand Up @@ -407,8 +407,10 @@ def explicitly_ordered(self) -> bool:
return True

@functools.cached_property
def fields(self) -> Tuple[Field, ...]:
return (Field(bfet_ids.ColumnId("labels"), self.start.fields[0].dtype),)
def fields(self) -> Iterable[Field]:
return (
Field(bfet_ids.ColumnId("labels"), next(iter(self.start.fields)).dtype),
)

@functools.cached_property
def variables_introduced(self) -> int:
Expand Down Expand Up @@ -469,11 +471,11 @@ class ReadLocalNode(LeafNode):
scan_list: ScanList
session: typing.Optional[bigframes.session.Session] = None

@functools.cached_property
def fields(self) -> Tuple[Field, ...]:
return tuple(Field(col_id, dtype) for col_id, dtype, _ in self.scan_list.items)
@property
def fields(self) -> Iterable[Field]:
return (Field(col_id, dtype) for col_id, dtype, _ in self.scan_list.items)

@functools.cached_property
@property
def variables_introduced(self) -> int:
"""Defines the number of variables generated by the current node. Used to estimate query planning complexity."""
return len(self.scan_list.items) + 1
Expand Down Expand Up @@ -576,9 +578,9 @@ def __post_init__(self):
def session(self):
return self.table_session

@functools.cached_property
def fields(self) -> Tuple[Field, ...]:
return tuple(Field(col_id, dtype) for col_id, dtype, _ in self.scan_list.items)
@property
def fields(self) -> Iterable[Field]:
return (Field(col_id, dtype) for col_id, dtype, _ in self.scan_list.items)

@property
def relation_ops_created(self) -> int:
Expand Down Expand Up @@ -644,8 +646,10 @@ def non_local(self) -> bool:
return True

@property
def fields(self) -> Tuple[Field, ...]:
return (*self.child.fields, Field(self.col_id, bigframes.dtypes.INT_DTYPE))
def fields(self) -> Iterable[Field]:
return itertools.chain(
self.child.fields, [Field(self.col_id, bigframes.dtypes.INT_DTYPE)]
)

@property
def relation_ops_created(self) -> int:
Expand Down Expand Up @@ -729,7 +733,7 @@ class SelectionNode(UnaryNode):
]

@functools.cached_property
def fields(self) -> Tuple[Field, ...]:
def fields(self) -> Iterable[Field]:
return tuple(
Field(output, self.child.get_type(input.id))
for input, output in self.input_output_pairs
Expand Down Expand Up @@ -774,13 +778,16 @@ def __post_init__(self):
assert all(name not in self.child.schema.names for _, name in self.assignments)

@functools.cached_property
def fields(self) -> Tuple[Field, ...]:
def added_fields(self) -> Tuple[Field, ...]:
input_types = self.child._dtype_lookup
new_fields = (
return tuple(
Field(id, bigframes.dtypes.dtype_for_etype(ex.output_type(input_types)))
for ex, id in self.assignments
)
return (*self.child.fields, *new_fields)

@property
def fields(self) -> Iterable[Field]:
return itertools.chain(self.child.fields, self.added_fields)

@property
def variables_introduced(self) -> int:
Expand Down Expand Up @@ -811,8 +818,8 @@ def row_preserving(self) -> bool:
def non_local(self) -> bool:
return True

@functools.cached_property
def fields(self) -> Tuple[Field, ...]:
@property
def fields(self) -> Iterable[Field]:
return (Field(bfet_ids.ColumnId("count"), bigframes.dtypes.INT_DTYPE),)

@property
Expand Down Expand Up @@ -841,7 +848,7 @@ def non_local(self) -> bool:
return True

@functools.cached_property
def fields(self) -> Tuple[Field, ...]:
def fields(self) -> Iterable[Field]:
by_items = (
Field(ref.id, self.child.get_type(ref.id)) for ref in self.by_column_ids
)
Expand All @@ -854,7 +861,7 @@ def fields(self) -> Tuple[Field, ...]:
)
for agg, id in self.aggregations
)
return (*by_items, *agg_items)
return tuple(itertools.chain(by_items, agg_items))

@property
def variables_introduced(self) -> int:
Expand Down Expand Up @@ -896,11 +903,9 @@ class WindowOpNode(UnaryNode):
def non_local(self) -> bool:
return True

@functools.cached_property
def fields(self) -> Tuple[Field, ...]:
input_type = self.child.get_type(self.column_name.id)
new_item_dtype = self.op.output_type(input_type)
return (*self.child.fields, Field(self.output_name, new_item_dtype))
@property
def fields(self) -> Iterable[Field]:
return itertools.chain(self.child.fields, [self.added_field])

@property
def variables_introduced(self) -> int:
Expand All @@ -911,6 +916,12 @@ def relation_ops_created(self) -> int:
# Assume that if not reprojecting, that there is a sequence of window operations sharing the same window
return 0 if self.skip_reproject_unsafe else 4

@functools.cached_property
def added_field(self) -> Field:
input_type = self.child.get_type(self.column_name.id)
new_item_dtype = self.op.output_type(input_type)
return Field(self.output_name, new_item_dtype)

def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
if self.output_name not in used_cols:
return self.child
Expand Down Expand Up @@ -959,9 +970,9 @@ class ExplodeNode(UnaryNode):
def row_preserving(self) -> bool:
return False

@functools.cached_property
def fields(self) -> Tuple[Field, ...]:
return tuple(
@property
def fields(self) -> Iterable[Field]:
return (
Field(
field.id,
bigframes.dtypes.arrow_dtype_to_bigframes_dtype(
Expand Down