Skip to content

Commit 0785cf8

Browse files
committed
fix: read_csv with both index_col and use_cols inconsistent with pandas
1 parent 38d9b73 commit 0785cf8

File tree

3 files changed

+84
-29
lines changed

3 files changed

+84
-29
lines changed

‎bigframes/session/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1166,7 +1166,11 @@ def _read_csv_w_bigquery_engine(
11661166

11671167
table_id = self._loader.load_file(filepath_or_buffer, job_config=job_config)
11681168
df = self._loader.read_gbq_table(
1169-
table_id, index_col=index_col, columns=columns, names=names
1169+
table_id,
1170+
index_col=index_col,
1171+
columns=columns,
1172+
names=names,
1173+
is_index_in_columns=True,
11701174
)
11711175

11721176
if dtype is not None:

‎bigframes/session/loader.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,9 @@ def _to_index_cols(
9393
return index_cols
9494

9595

96-
def _check_column_duplicates(index_cols: Iterable[str], columns: Iterable[str]):
96+
def _check_column_duplicates(
97+
index_cols: Iterable[str], columns: Iterable[str], is_index_in_columns: bool
98+
) -> List[str]:
9799
index_cols_list = list(index_cols) if index_cols is not None else []
98100
columns_list = list(columns) if columns is not None else []
99101
set_index = set(index_cols_list)
@@ -105,17 +107,29 @@ def _check_column_duplicates(index_cols: Iterable[str], columns: Iterable[str]):
105107
"All column names specified in 'index_col' must be unique."
106108
)
107109

110+
if len(columns_list) == 0:
111+
return columns_list
112+
108113
if len(columns_list) > len(set_columns):
109114
raise ValueError(
110115
"The 'columns' argument contains duplicate names. "
111116
"All column names specified in 'columns' must be unique."
112117
)
113118

114-
if not set_index.isdisjoint(set_columns):
115-
raise ValueError(
116-
"Found column names that exist in both 'index_col' and 'columns' arguments. "
117-
"These arguments must specify distinct sets of columns."
118-
)
119+
if is_index_in_columns:
120+
if not set_index.issubset(set_columns):
121+
raise ValueError(
122+
f"The specified index column(s) were not found: {set_index - set_columns}. "
123+
f"Available columns are: {set_columns}"
124+
)
125+
return list(set_columns - set_index)
126+
else:
127+
if not set_index.isdisjoint(set_columns):
128+
raise ValueError(
129+
"Found column names that exist in both 'index_col' and 'columns' arguments. "
130+
"These arguments must specify distinct sets of columns."
131+
)
132+
return columns_list
119133

120134

121135
@dataclasses.dataclass
@@ -388,6 +402,7 @@ def read_gbq_table( # type: ignore[overload-overlap]
388402
dry_run: Literal[False] = ...,
389403
force_total_order: Optional[bool] = ...,
390404
n_rows: Optional[int] = None,
405+
is_index_in_columns: bool = False,
391406
) -> dataframe.DataFrame:
392407
...
393408

@@ -410,6 +425,7 @@ def read_gbq_table(
410425
dry_run: Literal[True] = ...,
411426
force_total_order: Optional[bool] = ...,
412427
n_rows: Optional[int] = None,
428+
is_index_in_columns: bool = False,
413429
) -> pandas.Series:
414430
...
415431

@@ -431,6 +447,7 @@ def read_gbq_table(
431447
dry_run: bool = False,
432448
force_total_order: Optional[bool] = None,
433449
n_rows: Optional[int] = None,
450+
is_index_in_columns: bool = False,
434451
) -> dataframe.DataFrame | pandas.Series:
435452
import bigframes._tools.strings
436453
import bigframes.dataframe as dataframe
@@ -513,7 +530,7 @@ def read_gbq_table(
513530
index_col=index_col,
514531
names=names,
515532
)
516-
_check_column_duplicates(index_cols, columns)
533+
columns = _check_column_duplicates(index_cols, columns, is_index_in_columns)
517534

518535
for key in index_cols:
519536
if key not in table_column_names:
@@ -794,7 +811,9 @@ def read_gbq_query(
794811
)
795812

796813
index_cols = _to_index_cols(index_col)
797-
_check_column_duplicates(index_cols, columns)
814+
columns = _check_column_duplicates(
815+
index_cols, columns, is_index_in_columns=False
816+
)
798817

799818
filters_copy1, filters_copy2 = itertools.tee(filters)
800819
has_filters = len(list(filters_copy1)) != 0

‎tests/system/small/test_session.py

Lines changed: 52 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1479,41 +1479,73 @@ def test_read_csv_for_gcs_file_w_header(session, df_and_gcs_csv, header):
14791479
def test_read_csv_w_usecols(session, df_and_local_csv):
14801480
# Compares results for pandas and bigframes engines
14811481
scalars_df, path = df_and_local_csv
1482+
usecols = ["rowindex", "bool_col"]
14821483
with open(path, "rb") as buffer:
14831484
bf_df = session.read_csv(
14841485
buffer,
14851486
engine="bigquery",
1486-
usecols=["bool_col"],
1487+
usecols=usecols,
14871488
)
14881489
with open(path, "rb") as buffer:
14891490
# Convert default pandas dtypes to match BigQuery DataFrames dtypes.
14901491
pd_df = session.read_csv(
14911492
buffer,
1492-
usecols=["bool_col"],
1493+
usecols=usecols,
14931494
dtype=scalars_df[["bool_col"]].dtypes.to_dict(),
14941495
)
14951496

1496-
# Cannot compare two dataframe due to b/408499371.
1497-
assert len(bf_df.columns) == 1
1498-
assert len(pd_df.columns) == 1
1497+
assert bf_df.shape == pd_df.shape
1498+
assert bf_df.columns.tolist() == pd_df.columns.tolist()
14991499

1500+
# BigFrames requires `sort_index()` because BigQuery doesn't preserve row IDs
1501+
# (b/280889935) or guarantee row ordering.
1502+
bf_df = bf_df.set_index("rowindex").sort_index()
1503+
pd_df = pd_df.set_index("rowindex")
1504+
pd.testing.assert_frame_equal(bf_df.to_pandas(), pd_df.to_pandas())
15001505

1501-
@pytest.mark.parametrize(
1502-
"engine",
1503-
[
1504-
pytest.param("bigquery", id="bq_engine"),
1505-
pytest.param(None, id="default_engine"),
1506-
],
1507-
)
1508-
def test_read_csv_local_w_usecols(session, scalars_pandas_df_index, engine):
1509-
with tempfile.TemporaryDirectory() as dir:
1510-
path = dir + "/test_read_csv_local_w_usecols.csv"
1511-
# Using the pandas to_csv method because the BQ one does not support local write.
1512-
scalars_pandas_df_index.to_csv(path, index=False)
15131506

1514-
# df should only have 1 column which is bool_col.
1515-
df = session.read_csv(path, usecols=["bool_col"], engine=engine)
1516-
assert len(df.columns) == 1
1507+
def test_read_csv_w_usecols_and_indexcol(session, df_and_local_csv):
1508+
# Compares results for pandas and bigframes engines
1509+
scalars_df, path = df_and_local_csv
1510+
usecols = ["rowindex", "bool_col"]
1511+
with open(path, "rb") as buffer:
1512+
bf_df = session.read_csv(
1513+
buffer,
1514+
engine="bigquery",
1515+
usecols=usecols,
1516+
index_col="rowindex",
1517+
)
1518+
with open(path, "rb") as buffer:
1519+
# Convert default pandas dtypes to match BigQuery DataFrames dtypes.
1520+
pd_df = session.read_csv(
1521+
buffer,
1522+
usecols=usecols,
1523+
index_col="rowindex",
1524+
dtype=scalars_df[["bool_col"]].dtypes.to_dict(),
1525+
)
1526+
1527+
assert bf_df.shape == pd_df.shape
1528+
assert bf_df.columns.tolist() == pd_df.columns.tolist()
1529+
1530+
# BigFrames requires `sort_index()` because BigQuery doesn't preserve row IDs
1531+
# (b/280889935) or guarantee row ordering.
1532+
bf_df = bf_df.sort_index()
1533+
pd.testing.assert_frame_equal(bf_df.to_pandas(), pd_df.to_pandas())
1534+
1535+
1536+
def test_read_csv_w_indexcol_not_in_usecols(session, df_and_local_csv):
1537+
_, path = df_and_local_csv
1538+
with open(path, "rb") as buffer:
1539+
with pytest.raises(
1540+
ValueError,
1541+
match=re.escape("The specified index column(s) were not found"),
1542+
):
1543+
session.read_csv(
1544+
buffer,
1545+
engine="bigquery",
1546+
usecols=["bool_col"],
1547+
index_col="rowindex",
1548+
)
15171549

15181550

15191551
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)