Commit ce1aa67

perf: Rechunk result pages client side
1 parent f3fd7e2 commit ce1aa67

File tree

8 files changed: +182 additions, -43 deletions


‎bigframes/core/blocks.py

Lines changed: 3 additions & 3 deletions
@@ -595,10 +595,10 @@ def to_pandas_batches(
             self.expr,
             ordered=True,
             use_explicit_destination=allow_large_results,
-            page_size=page_size,
-            max_results=max_results,
         )
-        for df in execute_result.to_pandas_batches():
+        for df in execute_result.to_pandas_batches(
+            page_size=page_size, max_results=max_results
+        ):
             self._copy_index_to_pandas(df)
             if squeeze:
                 yield df.squeeze(axis=1)

‎bigframes/core/pyarrow_utils.py

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Iterable, Iterator
+
+import pyarrow as pa
+
+
+class BatchBuffer:
+    """
+    FIFO buffer of pyarrow Record batches
+
+    Not thread-safe.
+    """
+
+    def __init__(self):
+        self._buffer: list[pa.RecordBatch] = []
+        self._buffer_size: int = 0
+
+    def __len__(self):
+        return self._buffer_size
+
+    def append_batch(self, batch: pa.RecordBatch) -> None:
+        self._buffer.append(batch)
+        self._buffer_size += batch.num_rows
+
+    def take_as_batches(self, n: int) -> tuple[pa.RecordBatch, ...]:
+        if n > len(self):
+            raise ValueError(f"Cannot take {n} rows, only {len(self)} rows in buffer.")
+        rows_taken = 0
+        sub_batches: list[pa.RecordBatch] = []
+        while rows_taken < n:
+            batch = self._buffer.pop(0)
+            if batch.num_rows > (n - rows_taken):
+                sub_batches.append(batch.slice(length=n - rows_taken))
+                self._buffer.insert(0, batch.slice(offset=n - rows_taken))
+                rows_taken += n - rows_taken
+            else:
+                sub_batches.append(batch)
+                rows_taken += batch.num_rows
+
+        self._buffer_size -= n
+        return tuple(sub_batches)
+
+    def take_rechunked(self, n: int) -> pa.RecordBatch:
+        return (
+            pa.Table.from_batches(self.take_as_batches(n))
+            .combine_chunks()
+            .to_batches()[0]
+        )
+
+
+def chunk_by_row_count(
+    batches: Iterable[pa.RecordBatch], page_size: int
+) -> Iterator[tuple[pa.RecordBatch, ...]]:
+    buffer = BatchBuffer()
+    for batch in batches:
+        buffer.append_batch(batch)
+        while len(buffer) >= page_size:
+            yield buffer.take_as_batches(page_size)
+
+    # emit final page, maybe smaller
+    if len(buffer) > 0:
+        yield buffer.take_as_batches(len(buffer))
+
+
+def truncate_pyarrow_iterable(
+    batches: Iterable[pa.RecordBatch], max_results: int
+) -> Iterator[pa.RecordBatch]:
+    total_yielded = 0
+    for batch in batches:
+        if batch.num_rows >= (max_results - total_yielded):
+            yield batch.slice(length=max_results - total_yielded)
+            return
+        else:
+            yield batch
+            total_yielded += batch.num_rows
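A quick, hedged illustration of how these two helpers behave (this sketch is not part of the commit; the toy table, column name, and sizes below are invented for demonstration):

import pyarrow as pa

from bigframes.core import pyarrow_utils

# A hypothetical 10-row table arriving as uneven RecordBatches (3, 3, 3, 1 rows).
table = pa.table({"x": list(range(10))})
batches = table.to_batches(max_chunksize=3)

# chunk_by_row_count regroups the stream into pages of exactly 4 rows each;
# only the final page may be smaller.
for page in pyarrow_utils.chunk_by_row_count(batches, page_size=4):
    print(sum(b.num_rows for b in page))  # prints 4, 4, 2

# truncate_pyarrow_iterable stops the stream after max_results total rows.
truncated = pyarrow_utils.truncate_pyarrow_iterable(
    table.to_batches(max_chunksize=3), max_results=5
)
print(sum(b.num_rows for b in truncated))  # prints 5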

‎bigframes/session/_io/bigquery/__init__.py

Lines changed: 1 addition & 9 deletions
@@ -222,8 +222,6 @@ def start_query_with_client(
     job_config: bigquery.job.QueryJobConfig,
     location: Optional[str] = None,
     project: Optional[str] = None,
-    max_results: Optional[int] = None,
-    page_size: Optional[int] = None,
     timeout: Optional[float] = None,
     api_name: Optional[str] = None,
     metrics: Optional[bigframes.session.metrics.ExecutionMetrics] = None,
@@ -244,8 +242,6 @@ def start_query_with_client(
                 location=location,
                 project=project,
                 api_timeout=timeout,
-                page_size=page_size,
-                max_results=max_results,
             )
             if metrics is not None:
                 metrics.count_job_stats(row_iterator=results_iterator)
@@ -267,14 +263,10 @@ def start_query_with_client(
     if opts.progress_bar is not None and not query_job.configuration.dry_run:
         results_iterator = formatting_helpers.wait_for_query_job(
             query_job,
-            max_results=max_results,
             progress_bar=opts.progress_bar,
-            page_size=page_size,
         )
     else:
-        results_iterator = query_job.result(
-            max_results=max_results, page_size=page_size
-        )
+        results_iterator = query_job.result()

     if metrics is not None:
         metrics.count_job_stats(query_job=query_job)

‎bigframes/session/bq_caching_executor.py

Lines changed: 2 additions & 21 deletions
@@ -106,8 +106,6 @@ def execute(
         *,
         ordered: bool = True,
         use_explicit_destination: Optional[bool] = None,
-        page_size: Optional[int] = None,
-        max_results: Optional[int] = None,
     ) -> executor.ExecuteResult:
         if use_explicit_destination is None:
             use_explicit_destination = bigframes.options.bigquery.allow_large_results
@@ -127,8 +125,6 @@ def execute(
         return self._execute_plan(
             plan,
             ordered=ordered,
-            page_size=page_size,
-            max_results=max_results,
             destination=destination_table,
         )

@@ -290,8 +286,6 @@ def _run_execute_query(
         sql: str,
         job_config: Optional[bq_job.QueryJobConfig] = None,
         api_name: Optional[str] = None,
-        page_size: Optional[int] = None,
-        max_results: Optional[int] = None,
         query_with_job: bool = True,
     ) -> Tuple[bq_table.RowIterator, Optional[bigquery.QueryJob]]:
         """
@@ -312,8 +306,6 @@ def _run_execute_query(
                 sql,
                 job_config=job_config,
                 api_name=api_name,
-                max_results=max_results,
-                page_size=page_size,
                 metrics=self.metrics,
                 query_with_job=query_with_job,
             )
@@ -488,16 +480,13 @@ def _execute_plan(
         self,
         plan: nodes.BigFrameNode,
         ordered: bool,
-        page_size: Optional[int] = None,
-        max_results: Optional[int] = None,
         destination: Optional[bq_table.TableReference] = None,
         peek: Optional[int] = None,
     ):
         """Just execute whatever plan as is, without further caching or decomposition."""

         # First try to execute fast-paths
-        # TODO: Allow page_size and max_results by rechunking/truncating results
-        if (not page_size) and (not max_results) and (not destination) and (not peek):
+        if (not destination) and (not peek):
             for semi_executor in self._semi_executors:
                 maybe_result = semi_executor.execute(plan, ordered=ordered)
                 if maybe_result:
@@ -513,20 +502,12 @@ def _execute_plan(
         iterator, query_job = self._run_execute_query(
             sql=sql,
             job_config=job_config,
-            page_size=page_size,
-            max_results=max_results,
             query_with_job=(destination is not None),
         )

         # Though we provide the read client, iterator may or may not use it based on what is efficient for the result
         def iterator_supplier():
-            # Workaround issue fixed by: https://github.com/googleapis/python-bigquery/pull/2154
-            if iterator._page_size is not None or iterator.max_results is not None:
-                return iterator.to_arrow_iterable(bqstorage_client=None)
-            else:
-                return iterator.to_arrow_iterable(
-                    bqstorage_client=self.bqstoragereadclient
-                )
+            return iterator.to_arrow_iterable(bqstorage_client=self.bqstoragereadclient)

         if query_job:
             size_bytes = self.bqclient.get_table(query_job.destination).num_bytes

‎bigframes/session/executor.py

Lines changed: 21 additions & 4 deletions
@@ -25,6 +25,7 @@
 import pyarrow

 import bigframes.core
+from bigframes.core import pyarrow_utils
 import bigframes.core.schema
 import bigframes.session._io.pandas as io_pandas

@@ -55,10 +56,28 @@ def to_arrow_table(self) -> pyarrow.Table:
     def to_pandas(self) -> pd.DataFrame:
         return io_pandas.arrow_to_pandas(self.to_arrow_table(), self.schema)

-    def to_pandas_batches(self) -> Iterator[pd.DataFrame]:
+    def to_pandas_batches(
+        self, page_size: Optional[int] = None, max_results: Optional[int] = None
+    ) -> Iterator[pd.DataFrame]:
+        assert (page_size is None) or (page_size > 0)
+        assert (max_results is None) or (max_results > 0)
+        batch_iter: Iterator[
+            Union[pyarrow.Table, pyarrow.RecordBatch]
+        ] = self.arrow_batches()
+        if max_results is not None:
+            batch_iter = pyarrow_utils.truncate_pyarrow_iterable(
+                batch_iter, max_results
+            )
+
+        if page_size is not None:
+            batches_iter = pyarrow_utils.chunk_by_row_count(batch_iter, page_size)
+            batch_iter = map(
+                lambda batches: pyarrow.Table.from_batches(batches), batches_iter
+            )
+
         yield from map(
             functools.partial(io_pandas.arrow_to_pandas, schema=self.schema),
-            self.arrow_batches(),
+            batch_iter,
         )

     def to_py_scalar(self):
@@ -96,8 +115,6 @@ def execute(
         *,
         ordered: bool = True,
         use_explicit_destination: Optional[bool] = False,
-        page_size: Optional[int] = None,
-        max_results: Optional[int] = None,
     ) -> ExecuteResult:
         """
         Execute the ArrayValue, storing the result to a temporary session-owned table.
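Taken together with the blocks.py change above, page_size and max_results are now applied to the Arrow batch stream on the client instead of being pushed into the BigQuery request. A rough sketch of the resulting call pattern at the DataFrame level (hedged: the table name and sizes below are placeholders for illustration, not taken from this commit):

import bigframes.pandas as bpd

df = bpd.read_gbq("bigquery-public-data.usa_names.usa_1910_2013")  # placeholder table

# Each yielded pandas DataFrame should hold page_size rows (the final page may be
# smaller), and iteration stops after max_results rows; both limits are enforced
# client side by truncating and rechunking the Arrow batches.
for page in df.to_pandas_batches(page_size=10_000, max_results=50_000):
    print(len(page))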

‎bigframes/session/loader.py

Lines changed: 0 additions & 2 deletions
@@ -809,7 +809,6 @@ def _start_query(
         self,
         sql: str,
         job_config: Optional[google.cloud.bigquery.QueryJobConfig] = None,
-        max_results: Optional[int] = None,
         timeout: Optional[float] = None,
         api_name: Optional[str] = None,
     ) -> Tuple[google.cloud.bigquery.table.RowIterator, bigquery.QueryJob]:
@@ -828,7 +827,6 @@ def _start_query(
             self._bqclient,
             sql,
             job_config=job_config,
-            max_results=max_results,
             timeout=timeout,
             api_name=api_name,
         )

‎tests/unit/core/test_pyarrow_utils.py

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+
+import numpy as np
+import pyarrow as pa
+import pytest
+
+from bigframes.core import pyarrow_utils
+
+PA_TABLE = pa.table({f"col_{i}": np.random.rand(1000) for i in range(10)})
+
+# 17, 3, 929 coprime
+N = 17
+MANY_SMALL_BATCHES = PA_TABLE.to_batches(max_chunksize=3)
+FEW_BIG_BATCHES = PA_TABLE.to_batches(max_chunksize=929)
+
+
+@pytest.mark.parametrize(
+    ["batches", "page_size"],
+    [
+        (MANY_SMALL_BATCHES, N),
+        (FEW_BIG_BATCHES, N),
+    ],
+)
+def test_chunk_by_row_count(batches, page_size):
+    results = list(pyarrow_utils.chunk_by_row_count(batches, page_size=page_size))
+
+    for i, batches in enumerate(results):
+        if i != len(results) - 1:
+            assert sum(map(lambda x: x.num_rows, batches)) == page_size
+        else:
+            # final page can be smaller
+            assert sum(map(lambda x: x.num_rows, batches)) <= page_size
+
+    reconstructed = pa.Table.from_batches(itertools.chain.from_iterable(results))
+    assert reconstructed.equals(PA_TABLE)
+
+
+@pytest.mark.parametrize(
+    ["batches", "max_rows"],
+    [
+        (MANY_SMALL_BATCHES, N),
+        (FEW_BIG_BATCHES, N),
+    ],
+)
+def test_truncate_pyarrow_iterable(batches, max_rows):
+    results = list(
+        pyarrow_utils.truncate_pyarrow_iterable(batches, max_results=max_rows)
+    )
+
+    reconstructed = pa.Table.from_batches(results)
+    assert reconstructed.equals(PA_TABLE.slice(length=max_rows))

‎tests/unit/session/test_io_bigquery.py

Lines changed: 3 additions & 4 deletions
@@ -199,11 +199,11 @@ def test_add_and_trim_labels_length_limit_met():


 @pytest.mark.parametrize(
-    ("max_results", "timeout", "api_name"),
-    [(None, None, None), (100, 30.0, "test_api")],
+    ("timeout", "api_name"),
+    [(None, None), (30.0, "test_api")],
 )
 def test_start_query_with_client_labels_length_limit_met(
-    mock_bq_client, max_results, timeout, api_name
+    mock_bq_client, timeout, api_name
 ):
     sql = "select * from abc"
     cur_labels = {
@@ -230,7 +230,6 @@ def test_start_query_with_client_labels_length_limit_met(
         mock_bq_client,
         sql,
         job_config,
-        max_results=max_results,
         timeout=timeout,
         api_name=api_name,
     )
