perf: Rechunk result pages client side #1680

Merged: 2 commits, May 6, 2025

6 changes: 3 additions & 3 deletions bigframes/core/blocks.py
@@ -586,10 +586,10 @@ def to_pandas_batches(
self.expr,
ordered=True,
use_explicit_destination=allow_large_results,
page_size=page_size,
max_results=max_results,
)
for df in execute_result.to_pandas_batches():
for df in execute_result.to_pandas_batches(
page_size=page_size, max_results=max_results
):
self._copy_index_to_pandas(df)
if squeeze:
yield df.squeeze(axis=1)
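
For callers, the effect of this change is that page_size and max_results are applied client side while iterating the Arrow stream instead of being pushed into the BigQuery job. A hedged sketch of the calling pattern (illustrative only: it assumes an authenticated bigframes session and that DataFrame.to_pandas_batches forwards these arguments to Block.to_pandas_batches as shown above):

import bigframes.pandas as bpd

df = bpd.read_gbq("bigquery-public-data.usa_names.usa_1910_2013")
for page in df.to_pandas_batches(page_size=10_000, max_results=50_000):
    # Each page is a pandas DataFrame. With page_size set, every page except
    # possibly the last holds exactly 10,000 rows; max_results caps the total.
    print(len(page))
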
87 changes: 87 additions & 0 deletions bigframes/core/pyarrow_utils.py
@@ -0,0 +1,87 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterable, Iterator

import pyarrow as pa


class BatchBuffer:
"""
FIFO buffer of pyarrow RecordBatches

Not thread-safe.
"""

def __init__(self):
self._buffer: list[pa.RecordBatch] = []
self._buffer_size: int = 0

def __len__(self):
return self._buffer_size

def append_batch(self, batch: pa.RecordBatch) -> None:
self._buffer.append(batch)
self._buffer_size += batch.num_rows

def take_as_batches(self, n: int) -> tuple[pa.RecordBatch, ...]:
if n > len(self):
raise ValueError(f"Cannot take {n} rows, only {len(self)} rows in buffer.")
rows_taken = 0
sub_batches: list[pa.RecordBatch] = []
while rows_taken < n:
batch = self._buffer.pop(0)
if batch.num_rows > (n - rows_taken):
sub_batches.append(batch.slice(length=n - rows_taken))
self._buffer.insert(0, batch.slice(offset=n - rows_taken))
rows_taken += n - rows_taken
else:
sub_batches.append(batch)
rows_taken += batch.num_rows

self._buffer_size -= n
return tuple(sub_batches)

def take_rechunked(self, n: int) -> pa.RecordBatch:
return (
pa.Table.from_batches(self.take_as_batches(n))
.combine_chunks()
.to_batches()[0]
)


def chunk_by_row_count(
batches: Iterable[pa.RecordBatch], page_size: int
) -> Iterator[tuple[pa.RecordBatch, ...]]:
buffer = BatchBuffer()
for batch in batches:
buffer.append_batch(batch)
while len(buffer) >= page_size:
yield buffer.take_as_batches(page_size)

# emit final page, maybe smaller
if len(buffer) > 0:
yield buffer.take_as_batches(len(buffer))


def truncate_pyarrow_iterable(
batches: Iterable[pa.RecordBatch], max_results: int
) -> Iterator[pa.RecordBatch]:
total_yielded = 0
for batch in batches:
if batch.num_rows >= (max_results - total_yielded):
yield batch.slice(length=max_results - total_yielded)
return
else:
yield batch
total_yielded += batch.num_rows
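
A minimal usage sketch (not part of the diff) showing how the two new helpers compose: truncate the incoming RecordBatch stream to max_results first, then regroup what remains into fixed-size pages, where only the final page may fall short of page_size.

import pyarrow as pa

from bigframes.core import pyarrow_utils

table = pa.table({"x": list(range(10))})
batches = table.to_batches(max_chunksize=3)  # input batches of 3, 3, 3, 1 rows

truncated = pyarrow_utils.truncate_pyarrow_iterable(batches, max_results=8)
pages = pyarrow_utils.chunk_by_row_count(truncated, page_size=4)

for page in pages:
    # Each page is a tuple of RecordBatches whose row counts sum to page_size;
    # here the 8 truncated rows come out as two pages of exactly 4 rows each.
    print(sum(batch.num_rows for batch in page))
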
10 changes: 1 addition & 9 deletions bigframes/session/_io/bigquery/__init__.py
@@ -222,8 +222,6 @@ def start_query_with_client(
job_config: bigquery.job.QueryJobConfig,
location: Optional[str] = None,
project: Optional[str] = None,
max_results: Optional[int] = None,
page_size: Optional[int] = None,
timeout: Optional[float] = None,
api_name: Optional[str] = None,
metrics: Optional[bigframes.session.metrics.ExecutionMetrics] = None,
@@ -244,8 +242,6 @@ def start_query_with_client(
location=location,
project=project,
api_timeout=timeout,
page_size=page_size,
max_results=max_results,
)
if metrics is not None:
metrics.count_job_stats(row_iterator=results_iterator)
@@ -267,14 +263,10 @@ def start_query_with_client(
if opts.progress_bar is not None and not query_job.configuration.dry_run:
results_iterator = formatting_helpers.wait_for_query_job(
query_job,
max_results=max_results,
progress_bar=opts.progress_bar,
page_size=page_size,
)
else:
results_iterator = query_job.result(
max_results=max_results, page_size=page_size
)
results_iterator = query_job.result()

if metrics is not None:
metrics.count_job_stats(query_job=query_job)
23 changes: 2 additions & 21 deletions bigframes/session/bq_caching_executor.py
@@ -106,8 +106,6 @@ def execute(
*,
ordered: bool = True,
use_explicit_destination: Optional[bool] = None,
page_size: Optional[int] = None,
max_results: Optional[int] = None,
) -> executor.ExecuteResult:
if use_explicit_destination is None:
use_explicit_destination = bigframes.options.bigquery.allow_large_results
@@ -127,8 +125,6 @@ def execute(
return self._execute_plan(
plan,
ordered=ordered,
page_size=page_size,
max_results=max_results,
destination=destination_table,
)

@@ -281,8 +277,6 @@ def _run_execute_query(
sql: str,
job_config: Optional[bq_job.QueryJobConfig] = None,
api_name: Optional[str] = None,
page_size: Optional[int] = None,
max_results: Optional[int] = None,
query_with_job: bool = True,
) -> Tuple[bq_table.RowIterator, Optional[bigquery.QueryJob]]:
"""
@@ -303,8 +297,6 @@ def _run_execute_query(
sql,
job_config=job_config,
api_name=api_name,
max_results=max_results,
page_size=page_size,
metrics=self.metrics,
query_with_job=query_with_job,
)
@@ -479,16 +471,13 @@ def _execute_plan(
self,
plan: nodes.BigFrameNode,
ordered: bool,
page_size: Optional[int] = None,
max_results: Optional[int] = None,
destination: Optional[bq_table.TableReference] = None,
peek: Optional[int] = None,
):
"""Just execute whatever plan as is, without further caching or decomposition."""

# First try to execute fast-paths
# TODO: Allow page_size and max_results by rechunking/truncating results
if (not page_size) and (not max_results) and (not destination) and (not peek):
if (not destination) and (not peek):
for semi_executor in self._semi_executors:
maybe_result = semi_executor.execute(plan, ordered=ordered)
if maybe_result:
@@ -504,20 +493,12 @@ def _execute_plan(
iterator, query_job = self._run_execute_query(
sql=sql,
job_config=job_config,
page_size=page_size,
max_results=max_results,
query_with_job=(destination is not None),
)

# Though we provide the read client, iterator may or may not use it based on what is efficient for the result
def iterator_supplier():
# Workaround issue fixed by: https://github.com/googleapis/python-bigquery/pull/2154
if iterator._page_size is not None or iterator.max_results is not None:
return iterator.to_arrow_iterable(bqstorage_client=None)
else:
return iterator.to_arrow_iterable(
bqstorage_client=self.bqstoragereadclient
)
return iterator.to_arrow_iterable(bqstorage_client=self.bqstoragereadclient)

if query_job:
size_bytes = self.bqclient.get_table(query_job.destination).num_bytes
25 changes: 21 additions & 4 deletions bigframes/session/executor.py
@@ -25,6 +25,7 @@
import pyarrow

import bigframes.core
from bigframes.core import pyarrow_utils
import bigframes.core.schema
import bigframes.session._io.pandas as io_pandas

@@ -55,10 +56,28 @@ def to_arrow_table(self) -> pyarrow.Table:
def to_pandas(self) -> pd.DataFrame:
return io_pandas.arrow_to_pandas(self.to_arrow_table(), self.schema)

def to_pandas_batches(self) -> Iterator[pd.DataFrame]:
def to_pandas_batches(
self, page_size: Optional[int] = None, max_results: Optional[int] = None
) -> Iterator[pd.DataFrame]:
assert (page_size is None) or (page_size > 0)
assert (max_results is None) or (max_results > 0)
batch_iter: Iterator[
Union[pyarrow.Table, pyarrow.RecordBatch]
] = self.arrow_batches()
if max_results is not None:
batch_iter = pyarrow_utils.truncate_pyarrow_iterable(
batch_iter, max_results
)

if page_size is not None:
batches_iter = pyarrow_utils.chunk_by_row_count(batch_iter, page_size)
batch_iter = map(
lambda batches: pyarrow.Table.from_batches(batches), batches_iter
)

yield from map(
functools.partial(io_pandas.arrow_to_pandas, schema=self.schema),
self.arrow_batches(),
batch_iter,
)

def to_py_scalar(self):
@@ -107,8 +126,6 @@ def execute(
*,
ordered: bool = True,
use_explicit_destination: Optional[bool] = False,
page_size: Optional[int] = None,
max_results: Optional[int] = None,
) -> ExecuteResult:
"""
Execute the ArrayValue, storing the result to a temporary session-owned table.
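
Taken together, ExecuteResult.to_pandas_batches now builds the pipeline arrow_batches -> truncate_pyarrow_iterable -> chunk_by_row_count -> arrow_to_pandas. A hedged sketch of what a caller observes (result is assumed to be an ExecuteResult returned by an executor's execute(); the row counts are illustrative):

# result: bigframes.session.executor.ExecuteResult with, say, 10,000 total rows
for frame in result.to_pandas_batches(page_size=1_000, max_results=2_500):
    # Truncation happens before rechunking, so the caller sees pandas DataFrames
    # of 1,000, 1,000, and finally 500 rows, then iteration stops.
    assert len(frame) <= 1_000
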
2 changes: 0 additions & 2 deletions bigframes/session/loader.py
@@ -906,7 +906,6 @@ def _start_query(
self,
sql: str,
job_config: Optional[google.cloud.bigquery.QueryJobConfig] = None,
max_results: Optional[int] = None,
timeout: Optional[float] = None,
api_name: Optional[str] = None,
) -> Tuple[google.cloud.bigquery.table.RowIterator, bigquery.QueryJob]:
@@ -925,7 +924,6 @@
self._bqclient,
sql,
job_config=job_config,
max_results=max_results,
timeout=timeout,
api_name=api_name,
)
65 changes: 65 additions & 0 deletions tests/unit/core/test_pyarrow_utils.py
@@ -0,0 +1,65 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools

import numpy as np
import pyarrow as pa
import pytest

from bigframes.core import pyarrow_utils

PA_TABLE = pa.table({f"col_{i}": np.random.rand(1000) for i in range(10)})

# 17, 3, 929 coprime
N = 17
MANY_SMALL_BATCHES = PA_TABLE.to_batches(max_chunksize=3)
FEW_BIG_BATCHES = PA_TABLE.to_batches(max_chunksize=929)


@pytest.mark.parametrize(
["batches", "page_size"],
[
(MANY_SMALL_BATCHES, N),
(FEW_BIG_BATCHES, N),
],
)
def test_chunk_by_row_count(batches, page_size):
results = list(pyarrow_utils.chunk_by_row_count(batches, page_size=page_size))

for i, batches in enumerate(results):
if i != len(results) - 1:
assert sum(map(lambda x: x.num_rows, batches)) == page_size
else:
# final page can be smaller
assert sum(map(lambda x: x.num_rows, batches)) <= page_size

reconstructed = pa.Table.from_batches(itertools.chain.from_iterable(results))
assert reconstructed.equals(PA_TABLE)


@pytest.mark.parametrize(
["batches", "max_rows"],
[
(MANY_SMALL_BATCHES, N),
(FEW_BIG_BATCHES, N),
],
)
def test_truncate_pyarrow_iterable(batches, max_rows):
results = list(
pyarrow_utils.truncate_pyarrow_iterable(batches, max_results=max_rows)
)

reconstructed = pa.Table.from_batches(results)
assert reconstructed.equals(PA_TABLE.slice(length=max_rows))
7 changes: 3 additions & 4 deletions tests/unit/session/test_io_bigquery.py
@@ -199,11 +199,11 @@ def test_add_and_trim_labels_length_limit_met():


@pytest.mark.parametrize(
("max_results", "timeout", "api_name"),
[(None, None, None), (100, 30.0, "test_api")],
("timeout", "api_name"),
[(None, None), (30.0, "test_api")],
)
def test_start_query_with_client_labels_length_limit_met(
mock_bq_client, max_results, timeout, api_name
mock_bq_client, timeout, api_name
):
sql = "select * from abc"
cur_labels = {
@@ -230,7 +230,6 @@ def test_start_query_with_client_labels_length_limit_met(
mock_bq_client,
sql,
job_config,
max_results=max_results,
timeout=timeout,
api_name=api_name,
)