
Commit af2245d (parent: 30a6237)

feat: support names parameter in read_csv for bigquery engine

File tree: 6 files changed, +204 -48 lines

bigframes/core/utils.py

Lines changed: 4 additions & 2 deletions
@@ -41,8 +41,10 @@ def get_axis_number(axis: typing.Union[str, int]) -> typing.Literal[0, 1]:
     raise ValueError(f"Not a valid axis: {axis}")
 
 
-def is_list_like(obj: typing.Any) -> typing_extensions.TypeGuard[typing.Sequence]:
-    return pd.api.types.is_list_like(obj)
+def is_list_like(
+    obj: typing.Any, allow_sets: bool = True
+) -> typing_extensions.TypeGuard[typing.Sequence]:
+    return pd.api.types.is_list_like(obj, allow_sets=allow_sets)
 
 
 def is_dict_like(obj: typing.Any) -> typing_extensions.TypeGuard[typing.Mapping]:
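
The new `allow_sets` flag simply forwards to the pandas helper this function wraps. With `allow_sets=False`, a `set` no longer counts as list-like, which matters for `names` because sets have no stable order. A minimal sketch of the underlying pandas behavior:

import pandas as pd

# By default, sets count as list-like.
assert pd.api.types.is_list_like({"a", "b"})
# With allow_sets=False they are rejected, since sets are unordered.
assert not pd.api.types.is_list_like({"a", "b"}, allow_sets=False)
# Ordered sequences pass either way.
assert pd.api.types.is_list_like(["a", "b"], allow_sets=False)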

bigframes/session/__init__.py

Lines changed: 20 additions & 16 deletions
@@ -16,6 +16,7 @@
 
 from __future__ import annotations
 
+from collections import abc
 import datetime
 import logging
 import os
@@ -576,7 +577,7 @@ def read_gbq_table(
             columns = col_order
 
         return self._loader.read_gbq_table(
-            query=query,
+            table_id=query,
             index_col=index_col,
             columns=columns,
             max_results=max_results,
@@ -960,14 +961,21 @@ def _read_csv_w_bigquery_engine(
         native CSV loading capabilities, making it suitable for large datasets
         that may not fit into local memory.
         """
-
-        if any(param is not None for param in (dtype, names)):
-            not_supported = ("dtype", "names")
+        if dtype is not None:
             raise NotImplementedError(
-                f"BigQuery engine does not support these arguments: {not_supported}. "
+                f"BigQuery engine does not support the `dtype` argument. "
                 f"{constants.FEEDBACK_LINK}"
             )
 
+        if names is not None:
+            if len(names) != len(set(names)):
+                raise ValueError("Duplicated names are not allowed.")
+            if not (
+                bigframes.core.utils.is_list_like(names, allow_sets=False)
+                or isinstance(names, abc.KeysView)
+            ):
+                raise ValueError("Names should be an ordered collection.")
+
         if index_col is True:
            raise ValueError("The value of index_col couldn't be 'True'")

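The effect of the two new checks: `names` must be duplicate-free and an ordered collection (list, tuple, or dict keys view; sets are rejected). A sketch of what passes and what raises, using the same predicates as the code above:

from collections import abc
import bigframes.core.utils as utils

ok = ["a", "b", "c"]                 # list: ordered, passes
also_ok = {"a": 1, "b": 2}.keys()    # dict KeysView is explicitly allowed

assert len(ok) == len(set(ok))                   # no duplicates -> no ValueError
assert utils.is_list_like(ok, allow_sets=False)  # ordered -> no ValueError
assert isinstance(also_ok, abc.KeysView)

bad = {"a", "b"}                     # set: unordered, would raise ValueError
assert not utils.is_list_like(bad, allow_sets=False)
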
@@ -1011,11 +1019,9 @@ def _read_csv_w_bigquery_engine(
         elif header > 0:
             job_config.skip_leading_rows = header + 1
 
-        return self._loader.read_bigquery_load_job(
-            filepath_or_buffer,
-            job_config=job_config,
-            index_col=index_col,
-            columns=columns,
+        table_id = self._loader.load_file(filepath_or_buffer, job_config=job_config)
+        return self._loader.read_gbq_table(
+            table_id, index_col=index_col, columns=columns, names=names
         )
 
     def read_pickle(
@@ -1056,8 +1062,8 @@ def read_parquet(
             job_config = bigquery.LoadJobConfig()
             job_config.source_format = bigquery.SourceFormat.PARQUET
             job_config.labels = {"bigframes-api": "read_parquet"}
-
-            return self._loader.read_bigquery_load_job(path, job_config=job_config)
+            table_id = self._loader.load_file(path, job_config=job_config)
+            return self._loader.read_gbq_table(table_id)
         else:
             if "*" in path:
                 raise ValueError(
@@ -1128,10 +1134,8 @@ def read_json(
             job_config.encoding = encoding
             job_config.labels = {"bigframes-api": "read_json"}
 
-            return self._loader.read_bigquery_load_job(
-                path_or_buf,
-                job_config=job_config,
-            )
+            table_id = self._loader.load_file(path_or_buf, job_config=job_config)
+            return self._loader.read_gbq_table(table_id)
         else:
             if any(arg in kwargs for arg in ("chunksize", "iterator")):
                 raise NotImplementedError(

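With the loader round-trip in place, the BigQuery engine now honors `names` end to end. A usage sketch (the bucket path and column names here are hypothetical):

import bigframes.pandas as bpd

df = bpd.read_csv(
    "gs://my-bucket/data.csv",   # hypothetical path
    engine="bigquery",
    names=["id", "name", "score"],
    index_col="id",              # index_col refers to the new names
)
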
bigframes/session/_io/bigquery/read_gbq_table.py

Lines changed: 14 additions & 0 deletions
@@ -235,6 +235,8 @@ def get_index_cols(
     | Iterable[int]
     | int
    | bigframes.enums.DefaultIndexKind,
+    *,
+    names: Optional[Iterable[str]] = None,
 ) -> List[str]:
     """
     If we can get a total ordering from the table, such as via primary key
@@ -245,6 +247,14 @@ def get_index_cols(
     # Transform index_col -> index_cols so we have a variable that is
     # always a list of column names (possibly empty).
     schema_len = len(table.schema)
+
+    # If `names` is provided, the index_col supplied by the user refers to the
+    # new names, so rename it back to the original name in the table schema.
+    renamed_schema: Optional[Dict[str, str]] = None
+    if names is not None:
+        assert len(list(names)) == schema_len
+        renamed_schema = {name: field.name for name, field in zip(names, table.schema)}
+
     index_cols: List[str] = []
     if isinstance(index_col, bigframes.enums.DefaultIndexKind):
         if index_col == bigframes.enums.DefaultIndexKind.SEQUENTIAL_INT64:
@@ -261,6 +271,8 @@ def get_index_cols(
             f"Got unexpected index_col {repr(index_col)}. {constants.FEEDBACK_LINK}"
         )
     elif isinstance(index_col, str):
+        if renamed_schema is not None:
+            index_col = renamed_schema.get(index_col, index_col)
        index_cols = [index_col]
    elif isinstance(index_col, int):
        if not 0 <= index_col < schema_len:
@@ -272,6 +284,8 @@ def get_index_cols(
    elif isinstance(index_col, Iterable):
        for item in index_col:
            if isinstance(item, str):
+                if renamed_schema is not None:
+                    item = renamed_schema.get(item, item)
                index_cols.append(item)
            elif isinstance(item, int):
                if not 0 <= item < schema_len:

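Because the user's `index_col` refers to the new names while the table schema still carries the original column names, the map is built new-name to original-name and consulted before any lookup. A standalone sketch with a stand-in schema (`Field` here is a placeholder for `google.cloud.bigquery.SchemaField`):

from dataclasses import dataclass
from typing import Dict, List

@dataclass
class Field:  # stand-in for google.cloud.bigquery.SchemaField
    name: str

schema: List[Field] = [Field("col_0"), Field("col_1"), Field("col_2")]
names = ["id", "name", "score"]

# Same construction as in get_index_cols: new name -> original name.
renamed_schema: Dict[str, str] = {
    name: field.name for name, field in zip(names, schema)
}
assert renamed_schema.get("id", "id") == "col_0"  # "id" resolves to "col_0"
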
bigframes/session/loader.py

Lines changed: 42 additions & 22 deletions
@@ -348,14 +348,15 @@ def _start_generic_job(self, job: formatting_helpers.GenericJob):
 
     def read_gbq_table(
         self,
-        query: str,
+        table_id: str,
         *,
         index_col: Iterable[str]
         | str
         | Iterable[int]
         | int
         | bigframes.enums.DefaultIndexKind = (),
         columns: Iterable[str] = (),
+        names: Optional[Iterable[str]] = None,
         max_results: Optional[int] = None,
         api_name: str = "read_gbq_table",
         use_cache: bool = True,
@@ -375,7 +376,7 @@ def read_gbq_table(
         )
 
         table_ref = google.cloud.bigquery.table.TableReference.from_string(
-            query, default_project=self._bqclient.project
+            table_id, default_project=self._bqclient.project
         )
 
         columns = list(columns)
@@ -411,12 +412,37 @@ def read_gbq_table(
                 f"Column '{key}' of `columns` not found in this table. Did you mean '{possibility}'?"
             )
 
+        # TODO: check if columns is not None
+        if names is not None:
+            len_names = len(list(names))
+            len_columns = len(table.schema)
+            if len_names > len_columns:
+                raise ValueError(
+                    f"Too many columns specified: expected {len_columns}"
+                    f" and found {len_names}"
+                )
+            elif len_names < len_columns:
+                if (
+                    isinstance(index_col, bigframes.enums.DefaultIndexKind)
+                    or index_col != ()
+                ):
+                    raise KeyError(
+                        "When providing both `index_col` and `names`, ensure the "
+                        "number of `names` matches the number of columns in your "
+                        "data."
+                    )
+                index_col = range(len_columns - len_names)
+                names = [
+                    field.name for field in table.schema[: len_columns - len_names]
+                ] + list(names)
+
         # Converting index_col into a list of column names requires
         # the table metadata because we might use the primary keys
         # when constructing the index.
         index_cols = bf_read_gbq_table.get_index_cols(
             table=table,
             index_col=index_col,
+            names=names,
         )
         _check_column_duplicates(index_cols, columns)

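When fewer `names` than schema columns are supplied (and no explicit `index_col`), the leading unnamed columns are promoted to the index, mirroring pandas' `read_csv` behavior: the first `len_columns - len_names` fields keep their schema names and `index_col` is set to their positions. The arithmetic, worked on a hypothetical 4-column schema:

schema_names = ["f0", "f1", "f2", "f3"]   # hypothetical table schema
names = ["x", "y"]                        # user supplies only two names

n_extra = len(schema_names) - len(names)  # 2 unnamed leading columns
index_col = range(n_extra)                # columns 0 and 1 become the index
names = schema_names[:n_extra] + names    # first fields keep schema names
assert list(index_col) == [0, 1] and names == ["f0", "f1", "x", "y"]
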
@@ -443,15 +469,15 @@ def read_gbq_table(
        # TODO(b/338419730): We don't need to fallback to a query for wildcard
        # tables if we allow some non-determinism when time travel isn't supported.
        if max_results is not None or bf_io_bigquery.is_table_with_wildcard_suffix(
-            query
+            table_id
        ):
            # TODO(b/338111344): If we are running a query anyway, we might as
            # well generate ROW_NUMBER() at the same time.
            all_columns: Iterable[str] = (
                itertools.chain(index_cols, columns) if columns else ()
            )
            query = bf_io_bigquery.to_query(
-                query,
+                table_id,
                columns=all_columns,
                sql_predicate=bf_io_bigquery.compile_filters(filters)
                if filters
@@ -561,6 +587,15 @@ def read_gbq_table(
            index_names = [None]
 
        value_columns = [col for col in array_value.column_ids if col not in index_cols]
+        if names is not None:
+            renamed_cols: Dict[str, str] = {
+                col: new_name for col, new_name in zip(array_value.column_ids, names)
+            }
+            index_names = [
+                renamed_cols.get(index_col, index_col) for index_col in index_cols
+            ]
+            value_columns = [renamed_cols.get(col, col) for col in value_columns]
+
        block = blocks.Block(
            array_value,
            index_columns=index_cols,
@@ -576,18 +611,12 @@ def read_gbq_table(
            df.sort_index()
        return df
 
-    def read_bigquery_load_job(
+    def load_file(
        self,
        filepath_or_buffer: str | IO["bytes"],
        *,
        job_config: bigquery.LoadJobConfig,
-        index_col: Iterable[str]
-        | str
-        | Iterable[int]
-        | int
-        | bigframes.enums.DefaultIndexKind = (),
-        columns: Iterable[str] = (),
-    ) -> dataframe.DataFrame:
+    ) -> str:
        # Need to create session table beforehand
        table = self._storage_manager.create_temp_table(_PLACEHOLDER_SCHEMA)
        # but, we just overwrite the placeholder schema immediately with the load job
@@ -615,16 +644,7 @@ def load_file(
 
        self._start_generic_job(load_job)
        table_id = f"{table.project}.{table.dataset_id}.{table.table_id}"
-
-        # The BigQuery REST API for tables.get doesn't take a session ID, so we
-        # can't get the schema for a temp table that way.
-
-        return self.read_gbq_table(
-            query=table_id,
-            index_col=index_col,
-            columns=columns,
-            api_name="read_gbq_table",
-        )
+        return table_id
 
    def read_gbq_query(
        self,