feat: Add total_rows property to pandas batches iterator #1888

Merged: 1 commit, Jul 8, 2025

bigframes/core/blocks.py (55 additions, 42 deletions)

@@ -29,7 +29,17 @@
 import random
 import textwrap
 import typing
-from typing import Iterable, List, Literal, Mapping, Optional, Sequence, Tuple, Union
+from typing import (
+    Iterable,
+    Iterator,
+    List,
+    Literal,
+    Mapping,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)
 import warnings
 
 import bigframes_vendored.constants as constants
@@ -87,14 +97,22 @@
 LevelsType = typing.Union[LevelType, typing.Sequence[LevelType]]
 
 
-class BlockHolder(typing.Protocol):
+@dataclasses.dataclass
+class PandasBatches(Iterator[pd.DataFrame]):
     """Interface for mutable objects with state represented by a block value object."""
 
-    def _set_block(self, block: Block):
-        """Set the underlying block value of the object"""
+    def __init__(
+        self, pandas_batches: Iterator[pd.DataFrame], total_rows: Optional[int] = 0
+    ):
+        self._dataframes: Iterator[pd.DataFrame] = pandas_batches
+        self._total_rows: Optional[int] = total_rows
 
-    def _get_block(self) -> Block:
-        """Get the underlying block value of the object"""
+    @property
+    def total_rows(self) -> Optional[int]:
+        return self._total_rows
+
+    def __next__(self) -> pd.DataFrame:
+        return next(self._dataframes)
 
 
 @dataclasses.dataclass()
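
Note: PandasBatches keeps the ordinary iterator protocol while exposing the expected row count up front. A minimal sketch of how downstream code might use it; consume_batches is a hypothetical helper, not part of this PR, and assumes a PandasBatches object returned by Block.to_pandas_batches:

import pandas as pd

def consume_batches(batches) -> pd.DataFrame:
    # total_rows may be None when the engine cannot determine the count
    # up front, so treat it as a hint rather than a guarantee.
    expected = batches.total_rows
    seen = 0
    frames = []
    for df in batches:  # __next__ simply delegates to the wrapped iterator
        seen += len(df)
        if expected is not None:
            print(f"downloaded {seen} of {expected} rows")
        frames.append(df)
    return pd.concat(frames)
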
@@ -599,8 +617,7 @@ def try_peek(
                 self.expr, n, use_explicit_destination=allow_large_results
             )
             df = result.to_pandas()
-            self._copy_index_to_pandas(df)
-            return df
+            return self._copy_index_to_pandas(df)
         else:
             return None
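
Note: this call-site rewrite follows from the _copy_index_to_pandas change later in this file, which now returns a new DataFrame instead of mutating its argument in place, so every caller must switch to the return value. A toy illustration of the two styles; the names are illustrative, not from the PR:

import pandas as pd

def set_index_inplace(df: pd.DataFrame) -> None:
    df.set_index("id", inplace=True)  # old style: mutate the caller's frame

def set_index_copy(df: pd.DataFrame) -> pd.DataFrame:
    return df.set_index("id")  # new style: the argument stays untouched

raw = pd.DataFrame({"id": [1, 2], "x": [10.0, 20.0]})
set_index_inplace(raw.copy())  # effect is lost: nothing kept the mutated frame
indexed = set_index_copy(raw)  # callers must bind the returned copy
assert list(indexed.index) == [1, 2]
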

@@ -609,8 +626,7 @@ def to_pandas_batches(
         page_size: Optional[int] = None,
         max_results: Optional[int] = None,
         allow_large_results: Optional[bool] = None,
-        squeeze: Optional[bool] = False,
-    ):
+    ) -> Iterator[pd.DataFrame]:
         """Download results one message at a time.
 
         page_size and max_results determine the size and number of batches,
@@ -621,43 +637,43 @@ def to_pandas_batches(
             use_explicit_destination=allow_large_results,
         )
 
-        total_batches = 0
-        for df in execute_result.to_pandas_batches(
-            page_size=page_size, max_results=max_results
-        ):
-            total_batches += 1
-            self._copy_index_to_pandas(df)
-            if squeeze:
-                yield df.squeeze(axis=1)
-            else:
-                yield df
-
         # To reduce the number of edge cases to consider when working with the
         # results of this, always return at least one DataFrame. See:
         # b/428918844.
-        if total_batches == 0:
-            df = pd.DataFrame(
-                {
-                    col: pd.Series([], dtype=self.expr.get_column_type(col))
-                    for col in itertools.chain(self.value_columns, self.index_columns)
-                }
-            )
-            self._copy_index_to_pandas(df)
-            yield df
+        empty_val = pd.DataFrame(
+            {
+                col: pd.Series([], dtype=self.expr.get_column_type(col))
+                for col in itertools.chain(self.value_columns, self.index_columns)
+            }
+        )
+        dfs = map(
+            lambda a: a[0],
+            itertools.zip_longest(
+                execute_result.to_pandas_batches(page_size, max_results),
+                [0],
+                fillvalue=empty_val,
+            ),
+        )
+        dfs = iter(map(self._copy_index_to_pandas, dfs))
+
+        total_rows = execute_result.total_rows
+        if (total_rows is not None) and (max_results is not None):
+            total_rows = min(total_rows, max_results)
+
+        return PandasBatches(dfs, total_rows)
 
-    def _copy_index_to_pandas(self, df: pd.DataFrame):
-        """Set the index on pandas DataFrame to match this block.
-
-        Warning: This method modifies ``df`` inplace.
-        """
+    def _copy_index_to_pandas(self, df: pd.DataFrame) -> pd.DataFrame:
+        """Set the index on pandas DataFrame to match this block."""
         # Note: If BigQuery DataFrame has null index, a default one will be created for the local materialization.
+        new_df = df.copy()
         if len(self.index_columns) > 0:
-            df.set_index(list(self.index_columns), inplace=True)
+            new_df.set_index(list(self.index_columns), inplace=True)
             # Pandas names is annotated as list[str] rather than the more
             # general Sequence[Label] that BigQuery DataFrames has.
             # See: https://github.com/pandas-dev/pandas-stubs/issues/804
-            df.index.names = self.index.names  # type: ignore
-        df.columns = self.column_labels
+            new_df.index.names = self.index.names  # type: ignore
+        new_df.columns = self.column_labels
+        return new_df
 
     def _materialize_local(
         self, materialize_options: MaterializationOptions = MaterializationOptions()
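
Note: the itertools.zip_longest construction is the subtle part of this hunk. Zipping the batch iterator against the one-element list [0] with a fillvalue guarantees at least one item comes out, so an empty result still yields exactly one empty DataFrame (the b/428918844 requirement) without buffering anything. A self-contained sketch of the same pattern on plain strings:

import itertools

def at_least_one(batches, empty_value):
    # If `batches` is exhausted immediately, zip_longest pads it against
    # the single element of [0], producing one (fillvalue, 0) pair; if it
    # is non-empty, every real batch appears as a[0] and the padding only
    # affects the unused right-hand side.
    return map(
        lambda a: a[0],
        itertools.zip_longest(batches, [0], fillvalue=empty_value),
    )

print(list(at_least_one(iter([]), "EMPTY")))            # ['EMPTY']
print(list(at_least_one(iter(["b1", "b2"]), "EMPTY")))  # ['b1', 'b2']
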
@@ -724,9 +740,7 @@ def _materialize_local(
             )
         else:
             df = execute_result.to_pandas()
-            self._copy_index_to_pandas(df)
-
-        return df, execute_result.query_job
+        return self._copy_index_to_pandas(df), execute_result.query_job
 
     def _downsample(
         self, total_rows: int, sampling_method: str, fraction: float, random_state
@@ -1591,8 +1605,7 @@ def retrieve_repr_request_results(
         row_count = self.session._executor.execute(self.expr.row_count()).to_py_scalar()
 
         head_df = head_result.to_pandas()
-        self._copy_index_to_pandas(head_df)
-        return head_df, row_count, head_result.query_job
+        return self._copy_index_to_pandas(head_df), row_count, head_result.query_job
 
     def promote_offsets(self, label: Label = None) -> typing.Tuple[Block, str]:
         expr, result_id = self._expr.promote_offsets()
bigframes/series.py (2 additions, 3 deletions)

@@ -648,13 +648,12 @@ def to_pandas_batches(
             form the original Series. Results stream from bigquery,
             see https://cloud.google.com/python/docs/reference/bigquery/latest/google.cloud.bigquery.table.RowIterator#google_cloud_bigquery_table_RowIterator_to_arrow_iterable
         """
-        df = self._block.to_pandas_batches(
+        batches = self._block.to_pandas_batches(
             page_size=page_size,
             max_results=max_results,
             allow_large_results=allow_large_results,
-            squeeze=True,
         )
-        return df
+        return map(lambda df: cast(pandas.Series, df.squeeze(1)), batches)
 
     def _compute_dry_run(self) -> bigquery.QueryJob:
         _, query_job = self._block._compute_dry_run((self._value_column,))
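
Note: rather than threading a squeeze flag through the Block layer, the Series path now squeezes each single-column batch itself. The map object returned here does not expose total_rows; that property lives on the DataFrame-level PandasBatches iterator. A quick illustration of the DataFrame.squeeze(axis=1) behavior this relies on, with an illustrative column name:

import pandas as pd

batch = pd.DataFrame({"bool_col": [True, False, True]})
squeezed = batch.squeeze(1)  # axis=1 collapses the lone column to a Series
assert isinstance(squeezed, pd.Series)
assert squeezed.name == "bool_col"  # the column label becomes the Series name
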
tests/system/small/test_dataframe.py (15 additions, 0 deletions)

@@ -871,6 +871,21 @@ def test_filter_df(scalars_dfs):
     assert_pandas_df_equal(bf_result, pd_result)
 
 
+def test_df_to_pandas_batches(scalars_dfs):
+    scalars_df, scalars_pandas_df = scalars_dfs
+
+    capped_unfiltered_batches = scalars_df.to_pandas_batches(page_size=2, max_results=6)
+    bf_bool_series = scalars_df["bool_col"]
+    filtered_batches = scalars_df[bf_bool_series].to_pandas_batches()
+
+    pd_bool_series = scalars_pandas_df["bool_col"]
+    pd_result = scalars_pandas_df[pd_bool_series]
+
+    assert 6 == capped_unfiltered_batches.total_rows
+    assert len(pd_result) == filtered_batches.total_rows
+    assert_pandas_df_equal(pd.concat(filtered_batches), pd_result)
+
+
 def test_assign_new_column(scalars_dfs):
     scalars_df, scalars_pandas_df = scalars_dfs
     kwargs = {"new_col": 2}
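
Note: the `assert 6 == capped_unfiltered_batches.total_rows` expectation follows from the capping added in blocks.py: the scalars table holds at least six rows, so the reported total is min(total_rows, max_results). A standalone sketch of that arithmetic; the helper name is illustrative:

from typing import Optional

def capped_total_rows(
    total_rows: Optional[int], max_results: Optional[int]
) -> Optional[int]:
    # Mirrors Block.to_pandas_batches: an unknown count stays unknown,
    # otherwise max_results bounds the reported total.
    if total_rows is not None and max_results is not None:
        return min(total_rows, max_results)
    return total_rows

assert capped_total_rows(9, 6) == 6        # capped by max_results
assert capped_total_rows(4, 6) == 4        # fewer rows than the cap
assert capped_total_rows(None, 6) is None  # unknown stays unknown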