Commit e3f5e65
feat: Add total_rows property to pandas batches iterator (#1888)
1 parent f63caf2 commit e3f5e65

File tree: 3 files changed, +72 −45 lines

bigframes/core/blocks.py

Lines changed: 55 additions & 42 deletions

@@ -29,7 +29,17 @@
 import random
 import textwrap
 import typing
-from typing import Iterable, List, Literal, Mapping, Optional, Sequence, Tuple, Union
+from typing import (
+    Iterable,
+    Iterator,
+    List,
+    Literal,
+    Mapping,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)
 import warnings

 import bigframes_vendored.constants as constants
@@ -87,14 +97,22 @@
 LevelsType = typing.Union[LevelType, typing.Sequence[LevelType]]


-class BlockHolder(typing.Protocol):
+@dataclasses.dataclass
+class PandasBatches(Iterator[pd.DataFrame]):
     """Interface for mutable objects with state represented by a block value object."""

-    def _set_block(self, block: Block):
-        """Set the underlying block value of the object"""
+    def __init__(
+        self, pandas_batches: Iterator[pd.DataFrame], total_rows: Optional[int] = 0
+    ):
+        self._dataframes: Iterator[pd.DataFrame] = pandas_batches
+        self._total_rows: Optional[int] = total_rows
+
+    @property
+    def total_rows(self) -> Optional[int]:
+        return self._total_rows

-    def _get_block(self) -> Block:
-        """Get the underlying block value of the object"""
+    def __next__(self) -> pd.DataFrame:
+        return next(self._dataframes)


 @dataclasses.dataclass()
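Because PandasBatches subclasses Iterator[pd.DataFrame] (backed by the collections.abc.Iterator ABC), it only needs to define __next__; the ABC mixin supplies __iter__. A minimal runnable sketch of the same wrapper pattern, using a hypothetical _BatchesSketch class rather than the commit's code:

    from typing import Iterator, Optional

    import pandas as pd


    class _BatchesSketch(Iterator[pd.DataFrame]):
        """Iterate like the wrapped iterator, but carry a row count as metadata."""

        def __init__(self, batches: Iterator[pd.DataFrame], total_rows: Optional[int]):
            self._batches = batches
            self._total_rows = total_rows

        @property
        def total_rows(self) -> Optional[int]:
            return self._total_rows

        def __next__(self) -> pd.DataFrame:
            # Delegate; the Iterator ABC provides __iter__ returning self.
            return next(self._batches)


    batches = _BatchesSketch(
        iter([pd.DataFrame({"a": [1, 2]}), pd.DataFrame({"a": [3]})]), total_rows=3
    )
    assert batches.total_rows == 3
    assert [len(df) for df in batches] == [2, 1]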
@@ -599,8 +617,7 @@ def try_peek(
                 self.expr, n, use_explicit_destination=allow_large_results
             )
             df = result.to_pandas()
-            self._copy_index_to_pandas(df)
-            return df
+            return self._copy_index_to_pandas(df)
         else:
             return None
@@ -609,8 +626,7 @@ def to_pandas_batches(
         page_size: Optional[int] = None,
         max_results: Optional[int] = None,
         allow_large_results: Optional[bool] = None,
-        squeeze: Optional[bool] = False,
-    ):
+    ) -> Iterator[pd.DataFrame]:
         """Download results one message at a time.

         page_size and max_results determine the size and number of batches,
@@ -621,43 +637,43 @@ def to_pandas_batches(
             use_explicit_destination=allow_large_results,
         )

-        total_batches = 0
-        for df in execute_result.to_pandas_batches(
-            page_size=page_size, max_results=max_results
-        ):
-            total_batches += 1
-            self._copy_index_to_pandas(df)
-            if squeeze:
-                yield df.squeeze(axis=1)
-            else:
-                yield df
-
         # To reduce the number of edge cases to consider when working with the
         # results of this, always return at least one DataFrame. See:
         # b/428918844.
-        if total_batches == 0:
-            df = pd.DataFrame(
-                {
-                    col: pd.Series([], dtype=self.expr.get_column_type(col))
-                    for col in itertools.chain(self.value_columns, self.index_columns)
-                }
-            )
-            self._copy_index_to_pandas(df)
-            yield df
+        empty_val = pd.DataFrame(
+            {
+                col: pd.Series([], dtype=self.expr.get_column_type(col))
+                for col in itertools.chain(self.value_columns, self.index_columns)
+            }
+        )
+        dfs = map(
+            lambda a: a[0],
+            itertools.zip_longest(
+                execute_result.to_pandas_batches(page_size, max_results),
+                [0],
+                fillvalue=empty_val,
+            ),
+        )
+        dfs = iter(map(self._copy_index_to_pandas, dfs))

-    def _copy_index_to_pandas(self, df: pd.DataFrame):
-        """Set the index on pandas DataFrame to match this block.
+        total_rows = execute_result.total_rows
+        if (total_rows is not None) and (max_results is not None):
+            total_rows = min(total_rows, max_results)

-        Warning: This method modifies ``df`` inplace.
-        """
+        return PandasBatches(dfs, total_rows)
+
+    def _copy_index_to_pandas(self, df: pd.DataFrame) -> pd.DataFrame:
+        """Set the index on pandas DataFrame to match this block."""
         # Note: If BigQuery DataFrame has null index, a default one will be created for the local materialization.
+        new_df = df.copy()
         if len(self.index_columns) > 0:
-            df.set_index(list(self.index_columns), inplace=True)
+            new_df.set_index(list(self.index_columns), inplace=True)
             # Pandas names is annotated as list[str] rather than the more
             # general Sequence[Label] that BigQuery DataFrames has.
             # See: https://github.com/pandas-dev/pandas-stubs/issues/804
-            df.index.names = self.index.names  # type: ignore
-        df.columns = self.column_labels
+            new_df.index.names = self.index.names  # type: ignore
+        new_df.columns = self.column_labels
+        return new_df

     def _materialize_local(
         self, materialize_options: MaterializationOptions = MaterializationOptions()
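The rewritten body replaces the generator with a lazy pipeline, and the itertools.zip_longest pairing against a one-element list is the key trick: it guarantees at least one DataFrame comes out even when the query produces zero batches (the b/428918844 edge case noted above). A standalone sketch of that idiom, with illustrative names only:

    import itertools
    from typing import Iterator

    import pandas as pd


    def at_least_one(
        batches: Iterator[pd.DataFrame], empty_val: pd.DataFrame
    ) -> Iterator[pd.DataFrame]:
        # zip_longest stops only when *both* inputs are exhausted, so the
        # one-element sentinel list [0] forces one padded pair even if
        # `batches` is empty; fillvalue supplies the empty DataFrame then.
        pairs = itertools.zip_longest(batches, [0], fillvalue=empty_val)
        return map(lambda pair: pair[0], pairs)


    empty = pd.DataFrame({"col": pd.Series([], dtype="Int64")})

    # Zero input batches: exactly one empty DataFrame still comes out.
    assert [len(df) for df in at_least_one(iter([]), empty)] == [0]

    # Two input batches: both pass through unchanged; the sentinel is dropped.
    two = iter([pd.DataFrame({"col": [1, 2]}), pd.DataFrame({"col": [3]})])
    assert [len(df) for df in at_least_one(two, empty)] == [2, 1]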
@@ -724,9 +740,7 @@ def _materialize_local(
             )
         else:
             df = execute_result.to_pandas()
-            self._copy_index_to_pandas(df)
-
-        return df, execute_result.query_job
+            return self._copy_index_to_pandas(df), execute_result.query_job

     def _downsample(
         self, total_rows: int, sampling_method: str, fraction: float, random_state
@@ -1591,8 +1605,7 @@ def retrieve_repr_request_results(
         row_count = self.session._executor.execute(self.expr.row_count()).to_py_scalar()

         head_df = head_result.to_pandas()
-        self._copy_index_to_pandas(head_df)
-        return head_df, row_count, head_result.query_job
+        return self._copy_index_to_pandas(head_df), row_count, head_result.query_job

     def promote_offsets(self, label: Label = None) -> typing.Tuple[Block, str]:
         expr, result_id = self._expr.promote_offsets()
bigframes/series.py

Lines changed: 2 additions & 3 deletions

@@ -648,13 +648,12 @@ def to_pandas_batches(
             form the original Series. Results stream from bigquery,
             see https://cloud.google.com/python/docs/reference/bigquery/latest/google.cloud.bigquery.table.RowIterator#google_cloud_bigquery_table_RowIterator_to_arrow_iterable
         """
-        df = self._block.to_pandas_batches(
+        batches = self._block.to_pandas_batches(
             page_size=page_size,
             max_results=max_results,
             allow_large_results=allow_large_results,
-            squeeze=True,
         )
-        return df
+        return map(lambda df: cast(pandas.Series, df.squeeze(1)), batches)

     def _compute_dry_run(self) -> bigquery.QueryJob:
         _, query_job = self._block._compute_dry_run((self._value_column,))
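With the block-level squeeze parameter removed, Series.to_pandas_batches now maps df.squeeze(axis=1) over the DataFrame batches itself. How squeeze behaves on a one-column frame, in isolation:

    import pandas as pd

    batch = pd.DataFrame({"bool_col": [True, False, True]})
    series = batch.squeeze(axis=1)  # collapse the single column into a Series

    assert isinstance(series, pd.Series)
    assert series.name == "bool_col"  # the column label becomes the Series name
    assert len(series) == 3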

tests/system/small/test_dataframe.py

Lines changed: 15 additions & 0 deletions

@@ -871,6 +871,21 @@ def test_filter_df(scalars_dfs):
     assert_pandas_df_equal(bf_result, pd_result)


+def test_df_to_pandas_batches(scalars_dfs):
+    scalars_df, scalars_pandas_df = scalars_dfs
+
+    capped_unfiltered_batches = scalars_df.to_pandas_batches(page_size=2, max_results=6)
+    bf_bool_series = scalars_df["bool_col"]
+    filtered_batches = scalars_df[bf_bool_series].to_pandas_batches()
+
+    pd_bool_series = scalars_pandas_df["bool_col"]
+    pd_result = scalars_pandas_df[pd_bool_series]
+
+    assert 6 == capped_unfiltered_batches.total_rows
+    assert len(pd_result) == filtered_batches.total_rows
+    assert_pandas_df_equal(pd.concat(filtered_batches), pd_result)
+
+
 def test_assign_new_column(scalars_dfs):
     scalars_df, scalars_pandas_df = scalars_dfs
     kwargs = {"new_col": 2}
