Skip to content

feat: enhance read_csv index_col parameter support #1631

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 3 additions & 27 deletions bigframes/session/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,37 +961,13 @@ def _read_csv_w_bigquery_engine(
f"{constants.FEEDBACK_LINK}"
)

# TODO(b/338089659): Looks like we can relax this 1 column
# restriction if we check the contents of an iterable are strings
# not integers.
if (
# Empty tuples, None, and False are allowed and falsey.
index_col
and not isinstance(index_col, bigframes.enums.DefaultIndexKind)
and not isinstance(index_col, str)
):
raise NotImplementedError(
"BigQuery engine only supports a single column name for `index_col`, "
f"got: {repr(index_col)}. {constants.FEEDBACK_LINK}"
)
if index_col is True:
raise ValueError("The value of index_col couldn't be 'True'")

# None and False cannot be passed to read_gbq.
# TODO(b/338400133): When index_col is None, we should be using the
# first column of the CSV as the index to be compatible with the
# pandas engine. According to the pandas docs, only "False"
# indicates a default sequential index.
if not index_col:
if index_col is None or index_col is False:
index_col = ()

index_col = typing.cast(
Union[
Sequence[str], # Falsey values
bigframes.enums.DefaultIndexKind,
str,
],
index_col,
)

# usecols should only be an iterable of strings (column names) for use as columns in read_gbq.
columns: Tuple[Any, ...] = tuple()
if usecols is not None:
Expand Down
38 changes: 35 additions & 3 deletions bigframes/session/_io/bigquery/read_gbq_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,11 @@ def _is_table_clustered_or_partitioned(

def get_index_cols(
table: bigquery.table.Table,
index_col: Iterable[str] | str | bigframes.enums.DefaultIndexKind,
index_col: Iterable[str]
| str
| Iterable[int]
| int
| bigframes.enums.DefaultIndexKind,
) -> List[str]:
"""
If we can get a total ordering from the table, such as via primary key
Expand All @@ -240,6 +244,8 @@ def get_index_cols(

# Transform index_col -> index_cols so we have a variable that is
# always a list of column names (possibly empty).
schema_len = len(table.schema)
index_cols: List[str] = []
if isinstance(index_col, bigframes.enums.DefaultIndexKind):
if index_col == bigframes.enums.DefaultIndexKind.SEQUENTIAL_INT64:
# User has explicitly asked for a default, sequential index.
Expand All @@ -255,9 +261,35 @@ def get_index_cols(
f"Got unexpected index_col {repr(index_col)}. {constants.FEEDBACK_LINK}"
)
elif isinstance(index_col, str):
index_cols: List[str] = [index_col]
index_cols = [index_col]
elif isinstance(index_col, int):
if not 0 <= index_col < schema_len:
raise ValueError(
f"Integer index {index_col} is out of bounds "
f"for table with {schema_len} columns (must be >= 0 and < {schema_len})."
)
index_cols = [table.schema[index_col].name]
elif isinstance(index_col, Iterable):
for item in index_col:
if isinstance(item, str):
index_cols.append(item)
elif isinstance(item, int):
if not 0 <= item < schema_len:
raise ValueError(
f"Integer index {item} is out of bounds "
f"for table with {schema_len} columns (must be >= 0 and < {schema_len})."
)
index_cols.append(table.schema[item].name)
else:
raise TypeError(
"If index_col is an iterable, it must contain either strings "
"(column names) or integers (column positions)."
)
else:
index_cols = list(index_col)
raise TypeError(
    f"Unsupported type for index_col: {type(index_col).__name__}. Expected "
    "an integer, a string, an iterable of strings, or an iterable of integers."
)

# If there isn't an index selected, use the primary keys of the table as the
# index. If there are no primary keys, we'll return an empty list.
Expand Down
12 changes: 10 additions & 2 deletions bigframes/session/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,11 @@ def read_gbq_table(
self,
query: str,
*,
index_col: Iterable[str] | str | bigframes.enums.DefaultIndexKind = (),
index_col: Iterable[str]
| str
| Iterable[int]
| int
| bigframes.enums.DefaultIndexKind = (),
columns: Iterable[str] = (),
max_results: Optional[int] = None,
api_name: str = "read_gbq_table",
Expand Down Expand Up @@ -516,7 +520,11 @@ def read_bigquery_load_job(
filepath_or_buffer: str | IO["bytes"],
*,
job_config: bigquery.LoadJobConfig,
index_col: Iterable[str] | str | bigframes.enums.DefaultIndexKind = (),
index_col: Iterable[str]
| str
| Iterable[int]
| int
| bigframes.enums.DefaultIndexKind = (),
columns: Iterable[str] = (),
) -> dataframe.DataFrame:
# Need to create session table beforehand
Expand Down
73 changes: 52 additions & 21 deletions tests/system/small/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1216,55 +1216,86 @@ def test_read_csv_for_local_file_w_sep(session, df_and_local_csv, sep):
pd.testing.assert_frame_equal(bf_df.to_pandas(), pd_df.to_pandas())


def test_read_csv_w_index_col_false(session, df_and_local_csv):
@pytest.mark.parametrize(
"index_col",
[
pytest.param(None, id="none"),
pytest.param(False, id="false"),
pytest.param([], id="empty_list"),
],
)
def test_read_csv_for_index_col_w_false(session, df_and_local_csv, index_col):
# Compares results for pandas and bigframes engines
scalars_df, path = df_and_local_csv
with open(path, "rb") as buffer:
bf_df = session.read_csv(
buffer,
engine="bigquery",
index_col=False,
index_col=index_col,
)
with open(path, "rb") as buffer:
# Convert default pandas dtypes to match BigQuery DataFrames dtypes.
pd_df = session.read_csv(
buffer, index_col=False, dtype=scalars_df.dtypes.to_dict()
buffer, index_col=index_col, dtype=scalars_df.dtypes.to_dict()
)

assert bf_df.shape[0] == scalars_df.shape[0]
assert bf_df.shape[0] == pd_df.shape[0]

# We use a default index because of index_col=False, so the previous index
# column is just loaded as a column.
assert len(bf_df.columns) == len(scalars_df.columns) + 1
assert len(bf_df.columns) == len(pd_df.columns)
assert bf_df.shape == pd_df.shape

# BigFrames requires `sort_index()` because BigQuery doesn't preserve row IDs
# (b/280889935) or guarantee row ordering.
bf_df = bf_df.set_index("rowindex").sort_index()
pd_df = pd_df.set_index("rowindex")

pd.testing.assert_frame_equal(bf_df.to_pandas(), scalars_df.to_pandas())
pd.testing.assert_frame_equal(bf_df.to_pandas(), pd_df.to_pandas())


def test_read_csv_w_index_col_column_label(session, df_and_gcs_csv):
scalars_df, path = df_and_gcs_csv
bf_df = session.read_csv(path, engine="bigquery", index_col="rowindex")
@pytest.mark.parametrize(
"index_col",
[
pytest.param("rowindex", id="single_str"),
pytest.param(["rowindex", "bool_col"], id="multi_str"),
pytest.param(0, id="single_int"),
pytest.param([0, 2], id="multi_int"),
pytest.param([0, "bool_col"], id="mix_types"),
],
)
def test_read_csv_for_index_col(session, df_and_gcs_csv, index_col):
scalars_pandas_df, path = df_and_gcs_csv
bf_df = session.read_csv(path, engine="bigquery", index_col=index_col)

# Convert default pandas dtypes to match BigQuery DataFrames dtypes.
pd_df = session.read_csv(
path, index_col="rowindex", dtype=scalars_df.dtypes.to_dict()
path, index_col=index_col, dtype=scalars_pandas_df.dtypes.to_dict()
)

assert bf_df.shape == scalars_df.shape
assert bf_df.shape == pd_df.shape
pd.testing.assert_frame_equal(bf_df.to_pandas(), pd_df.to_pandas())

assert len(bf_df.columns) == len(scalars_df.columns)
assert len(bf_df.columns) == len(pd_df.columns)

pd.testing.assert_frame_equal(bf_df.to_pandas(), scalars_df.to_pandas())
pd.testing.assert_frame_equal(bf_df.to_pandas(), pd_df.to_pandas())
@pytest.mark.parametrize(
("index_col", "error_type", "error_msg"),
[
pytest.param(
True, ValueError, "The value of index_col couldn't be 'True'", id="true"
),
pytest.param(100, ValueError, "out of bounds", id="single_int"),
pytest.param([0, 200], ValueError, "out of bounds", id="multi_int"),
pytest.param(
[0.1], TypeError, "it must contain either strings", id="invalid_iterable"
),
pytest.param(
3.14, TypeError, "Unsupported type for index_col", id="unsupported_type"
),
],
)
def test_read_csv_raises_error_for_invalid_index_col(
session, df_and_gcs_csv, index_col, error_type, error_msg
):
_, path = df_and_gcs_csv
with pytest.raises(
error_type,
match=error_msg,
):
session.read_csv(path, engine="bigquery", index_col=index_col)


@pytest.mark.parametrize(
Expand Down
5 changes: 0 additions & 5 deletions tests/unit/session/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,6 @@
"BigQuery engine does not support these arguments",
id="with_dtype",
),
pytest.param(
{"engine": "bigquery", "index_col": 5},
"BigQuery engine only supports a single column name for `index_col`.",
id="with_index_col_not_str",
),
pytest.param(
{"engine": "bigquery", "usecols": [1, 2]},
"BigQuery engine only supports an iterable of strings for `usecols`.",
Expand Down