Skip to content

fix: correct read_csv behaviours with use_cols, names, index_col #1804

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 5 additions & 13 deletions bigframes/session/_io/bigquery/read_gbq_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,25 +243,17 @@ def get_index_cols(
| int
| bigframes.enums.DefaultIndexKind,
*,
names: Optional[Iterable[str]] = None,
rename_to_schema: Optional[Dict[str, str]] = None,
) -> List[str]:
"""
If we can get a total ordering from the table, such as via primary key
column(s), then return those too so that ordering generation can be
avoided.
"""

# Transform index_col -> index_cols so we have a variable that is
# always a list of column names (possibly empty).
schema_len = len(table.schema)

# If the `names` is provided, the index_col provided by the user is the new
# name, so we need to rename it to the original name in the table schema.
renamed_schema: Optional[Dict[str, str]] = None
if names is not None:
assert len(list(names)) == schema_len
renamed_schema = {name: field.name for name, field in zip(names, table.schema)}

index_cols: List[str] = []
if isinstance(index_col, bigframes.enums.DefaultIndexKind):
if index_col == bigframes.enums.DefaultIndexKind.SEQUENTIAL_INT64:
Expand All @@ -278,8 +270,8 @@ def get_index_cols(
f"Got unexpected index_col {repr(index_col)}. {constants.FEEDBACK_LINK}"
)
elif isinstance(index_col, str):
if renamed_schema is not None:
index_col = renamed_schema.get(index_col, index_col)
if rename_to_schema is not None:
index_col = rename_to_schema.get(index_col, index_col)
index_cols = [index_col]
elif isinstance(index_col, int):
if not 0 <= index_col < schema_len:
Expand All @@ -291,8 +283,8 @@ def get_index_cols(
elif isinstance(index_col, Iterable):
for item in index_col:
if isinstance(item, str):
if renamed_schema is not None:
item = renamed_schema.get(item, item)
if rename_to_schema is not None:
item = rename_to_schema.get(item, item)
index_cols.append(item)
elif isinstance(item, int):
if not 0 <= item < schema_len:
Expand Down
245 changes: 151 additions & 94 deletions bigframes/session/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,22 +96,35 @@ def _to_index_cols(
return index_cols


def _check_column_duplicates(
index_cols: Iterable[str], columns: Iterable[str], index_col_in_columns: bool
) -> Iterable[str]:
"""Validates and processes index and data columns for duplicates and overlap.
def _check_duplicates(name: str, columns: Optional[Iterable[str]] = None):
"""Check for duplicate column names in the provided iterable."""
if columns is None:
return
columns_list = list(columns)
set_columns = set(columns_list)
if len(columns_list) > len(set_columns):
raise ValueError(
f"The '{name}' argument contains duplicate names. "
f"All column names specified in '{name}' must be unique."
)

This function performs two main tasks:
1. Ensures there are no duplicate column names within the `index_cols` list
or within the `columns` list.
2. Based on the `index_col_in_columns` flag, it validates the relationship
between `index_cols` and `columns`.

def _check_index_col_param(
index_cols: Iterable[str],
columns: Iterable[str],
*,
table_columns: Optional[Iterable[str]] = None,
index_col_in_columns: Optional[bool] = False,
):
"""Checks for duplicates in `index_cols` and resolves overlap with `columns`.

Args:
index_cols (Iterable[str]):
An iterable of column names designated as the index.
Column names designated as the index columns.
columns (Iterable[str]):
An iterable of column names designated as the data columns.
Used column names from table_columns.
table_columns (Iterable[str]):
A full list of column names in the table schema.
index_col_in_columns (bool):
A flag indicating how to handle overlap between `index_cols` and
`columns`.
Expand All @@ -121,40 +134,97 @@ def _check_column_duplicates(
`columns`. An error is raised if an index column is not found
in the `columns` list.
"""
index_cols_list = list(index_cols) if index_cols is not None else []
columns_list = list(columns) if columns is not None else []
set_index = set(index_cols_list)
set_columns = set(columns_list)
_check_duplicates("index_col", index_cols)

if len(index_cols_list) > len(set_index):
raise ValueError(
"The 'index_col' argument contains duplicate names. "
"All column names specified in 'index_col' must be unique."
)
if columns is not None and len(list(columns)) > 0:
set_index = set(list(index_cols) if index_cols is not None else [])
set_columns = set(list(columns) if columns is not None else [])

if len(columns_list) == 0:
return columns
if index_col_in_columns:
if not set_index.issubset(set_columns):
raise ValueError(
f"The specified index column(s) were not found: {set_index - set_columns}. "
f"Available columns are: {set_columns}"
)
else:
if not set_index.isdisjoint(set_columns):
raise ValueError(
"Found column names that exist in both 'index_col' and 'columns' arguments. "
"These arguments must specify distinct sets of columns."
)

if len(columns_list) > len(set_columns):
raise ValueError(
"The 'columns' argument contains duplicate names. "
"All column names specified in 'columns' must be unique."
)
if not index_col_in_columns and table_columns is not None:
for key in index_cols:
if key not in table_columns:
possibility = min(
table_columns,
key=lambda item: bigframes._tools.strings.levenshtein_distance(
key, item
),
)
raise ValueError(
f"Column '{key}' of `index_col` not found in this table. Did you mean '{possibility}'?"
)

if index_col_in_columns:
if not set_index.issubset(set_columns):
raise ValueError(
f"The specified index column(s) were not found: {set_index - set_columns}. "
f"Available columns are: {set_columns}"

def _check_columns_param(columns: Iterable[str], table_columns: Iterable[str]):
"""Validates that the specified columns are present in the table columns.

Args:
columns (Iterable[str]):
Used column names from table_columns.
table_columns (Iterable[str]):
A full list of column names in the table schema.
Raises:
ValueError: If any column in `columns` is not found in the table columns.
"""
for column_name in columns:
if column_name not in table_columns:
possibility = min(
table_columns,
key=lambda item: bigframes._tools.strings.levenshtein_distance(
column_name, item
),
)
return [col for col in columns if col not in set_index]
else:
if not set_index.isdisjoint(set_columns):
raise ValueError(
"Found column names that exist in both 'index_col' and 'columns' arguments. "
"These arguments must specify distinct sets of columns."
f"Column '{column_name}' is not found. Did you mean '{possibility}'?"
)
return columns


def _check_names_param(
    names: Iterable[str],
    index_col: Iterable[str]
    | str
    | Iterable[int]
    | int
    | bigframes.enums.DefaultIndexKind,
    columns: Iterable[str],
    table_columns: Iterable[str],
):
    """Validate `names` against the table schema, `index_col` and `columns`.

    Args:
        names (Iterable[str]):
            Replacement column names supplied by the caller.
        index_col:
            Index specification supplied by the caller. Only an empty tuple
            is accepted when `names` is shorter than `table_columns`.
        columns (Iterable[str]):
            Selected column names. May be empty.
        table_columns (Iterable[str]):
            A full list of column names in the table schema.

    Raises:
        ValueError: If `names` has more entries than `table_columns`; or if
            `names` has fewer entries than `table_columns`, `columns` is
            non-empty, and `columns` differs from `names` in length or in
            its set of names.
        KeyError: If `names` has fewer entries than `table_columns` and
            `index_col` is a `DefaultIndexKind` or anything other than `()`.
    """
    # Materialize each iterable exactly once: a one-shot iterator (such as a
    # generator) would otherwise be exhausted by the length computation and
    # the set comparison below would silently compare empty sets.
    names_list = list(names)
    columns_list = list(columns)
    len_names = len(names_list)
    len_table_columns = len(list(table_columns))
    len_columns = len(columns_list)

    if len_names > len_table_columns:
        raise ValueError(
            f"Too many columns specified: expected {len_table_columns}"
            f" and found {len_names}"
        )
    elif len_names < len_table_columns:
        if isinstance(index_col, bigframes.enums.DefaultIndexKind) or index_col != ():
            raise KeyError(
                "When providing both `index_col` and `names`, ensure the "
                "number of `names` matches the number of columns in your "
                "data."
            )
        if len_columns != 0:
            # The 'columns' must be identical to the 'names'. If not, raise an error.
            if len_columns != len_names:
                raise ValueError(
                    "Number of passed names did not match number of header "
                    "fields in the file"
                )
            if set(names_list) != set(columns_list):
                raise ValueError("Usecols do not match columns")


@dataclasses.dataclass
Expand Down Expand Up @@ -545,11 +615,14 @@ def read_gbq_table(
f"`max_results` should be a positive number, got {max_results}."
)

_check_duplicates("columns", columns)

table_ref = google.cloud.bigquery.table.TableReference.from_string(
table_id, default_project=self._bqclient.project
)

columns = list(columns)
include_all_columns = columns is None or len(columns) == 0
filters = typing.cast(list, list(filters))

# ---------------------------------
Expand All @@ -563,72 +636,58 @@ def read_gbq_table(
cache=self._df_snapshot,
use_cache=use_cache,
)
table_column_names = {field.name for field in table.schema}

if table.location.casefold() != self._storage_manager.location.casefold():
raise ValueError(
f"Current session is in {self._storage_manager.location} but dataset '{table.project}.{table.dataset_id}' is located in {table.location}"
)

for key in columns:
if key not in table_column_names:
possibility = min(
table_column_names,
key=lambda item: bigframes._tools.strings.levenshtein_distance(
key, item
),
)
raise ValueError(
f"Column '{key}' of `columns` not found in this table. Did you mean '{possibility}'?"
)

# TODO(b/408499371): check `names` work with `use_cols` for read_csv method.
table_column_names = [field.name for field in table.schema]
rename_to_schema: Optional[Dict[str, str]] = None
if names is not None:
_check_names_param(names, index_col, columns, table_column_names)

# Any additional unnamed columns are going to be set as index columns
len_names = len(list(names))
len_columns = len(table.schema)
if len_names > len_columns:
raise ValueError(
f"Too many columns specified: expected {len_columns}"
f" and found {len_names}"
)
elif len_names < len_columns:
if (
isinstance(index_col, bigframes.enums.DefaultIndexKind)
or index_col != ()
):
raise KeyError(
"When providing both `index_col` and `names`, ensure the "
"number of `names` matches the number of columns in your "
"data."
)
index_col = range(len_columns - len_names)
len_schema = len(table.schema)
if len(columns) == 0 and len_names < len_schema:
index_col = range(len_schema - len_names)
names = [
field.name for field in table.schema[: len_columns - len_names]
field.name for field in table.schema[: len_schema - len_names]
] + list(names)

assert len_schema >= len_names
assert len_names >= len(columns)

table_column_names = table_column_names[: len(list(names))]
rename_to_schema = dict(zip(list(names), table_column_names))

if len(columns) != 0:
if names is None:
_check_columns_param(columns, table_column_names)
else:
_check_columns_param(columns, names)
names = columns
assert rename_to_schema is not None
columns = [rename_to_schema[renamed_name] for renamed_name in columns]

# Converting index_col into a list of column names requires
# the table metadata because we might use the primary keys
# when constructing the index.
index_cols = bf_read_gbq_table.get_index_cols(
table=table,
index_col=index_col,
names=names,
rename_to_schema=rename_to_schema,
)
columns = list(
_check_column_duplicates(index_cols, columns, index_col_in_columns)
_check_index_col_param(
index_cols,
columns,
table_columns=table_column_names,
index_col_in_columns=index_col_in_columns,
)

for key in index_cols:
if key not in table_column_names:
possibility = min(
table_column_names,
key=lambda item: bigframes._tools.strings.levenshtein_distance(
key, item
),
)
raise ValueError(
f"Column '{key}' of `index_col` not found in this table. Did you mean '{possibility}'?"
)
if index_col_in_columns and not include_all_columns:
set_index = set(list(index_cols) if index_cols is not None else [])
columns = [col for col in columns if col not in set_index]

# -----------------------------
# Optionally, execute the query
Expand Down Expand Up @@ -715,7 +774,7 @@ def read_gbq_table(
metadata_only=not self._scan_index_uniqueness,
)
schema = schemata.ArraySchema.from_bq_table(table)
if columns:
if not include_all_columns:
schema = schema.select(index_cols + columns)
array_value = core.ArrayValue.from_table(
table,
Expand Down Expand Up @@ -767,14 +826,14 @@ def read_gbq_table(

value_columns = [col for col in array_value.column_ids if col not in index_cols]
if names is not None:
renamed_cols: Dict[str, str] = {
col: new_name for col, new_name in zip(array_value.column_ids, names)
}
assert rename_to_schema is not None
schema_to_rename = {value: key for key, value in rename_to_schema.items()}
if index_col != bigframes.enums.DefaultIndexKind.SEQUENTIAL_INT64:
index_names = [
renamed_cols.get(index_col, index_col) for index_col in index_cols
schema_to_rename.get(index_col, index_col)
for index_col in index_cols
]
value_columns = [renamed_cols.get(col, col) for col in value_columns]
value_columns = [schema_to_rename.get(col, col) for col in value_columns]

block = blocks.Block(
array_value,
Expand Down Expand Up @@ -898,9 +957,7 @@ def read_gbq_query(
)

index_cols = _to_index_cols(index_col)
columns = _check_column_duplicates(
index_cols, columns, index_col_in_columns=False
)
_check_index_col_param(index_cols, columns)

filters_copy1, filters_copy2 = itertools.tee(filters)
has_filters = len(list(filters_copy1)) != 0
Expand Down
Loading