Skip to content

fix: Fix bug with DataFrame.agg for string values #1870

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions bigframes/core/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1999,7 +1999,7 @@ def _generate_resample_label(
return block.set_index([resample_label_id])

def _create_stack_column(self, col_label: typing.Tuple, stack_labels: pd.Index):
dtype = None
input_dtypes = []
input_columns: list[Optional[str]] = []
for uvalue in utils.index_as_tuples(stack_labels):
label_to_match = (*col_label, *uvalue)
Expand All @@ -2009,15 +2009,18 @@ def _create_stack_column(self, col_label: typing.Tuple, stack_labels: pd.Index):
matching_ids = self.label_to_col_id.get(label_to_match, [])
input_id = matching_ids[0] if len(matching_ids) > 0 else None
if input_id:
if dtype and dtype != self._column_type(input_id):
raise NotImplementedError(
"Cannot stack columns with non-matching dtypes."
)
else:
dtype = self._column_type(input_id)
input_dtypes.append(self._column_type(input_id))
input_columns.append(input_id)
# Input column i is the first one that
return tuple(input_columns), dtype or pd.Float64Dtype()
if len(input_dtypes) > 0:
output_dtype = bigframes.dtypes.lcd_type(*input_dtypes)
if output_dtype is None:
raise NotImplementedError(
"Cannot stack columns with non-matching dtypes."
)
else:
output_dtype = pd.Float64Dtype()
return tuple(input_columns), output_dtype

def _column_type(self, col_id: str) -> bigframes.dtypes.Dtype:
col_offset = self.value_columns.index(col_id)
Expand Down
46 changes: 38 additions & 8 deletions bigframes/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3004,14 +3004,44 @@ def agg(
if utils.is_dict_like(func):
# Must check dict-like first because dictionaries are list-like
# according to Pandas.
agg_cols = []
for col_label, agg_func in func.items():
agg_cols.append(self[col_label].agg(agg_func))

from bigframes.core.reshape import api as reshape

return reshape.concat(agg_cols, axis=1)

aggs = []
labels = []
funcnames = []
for col_label, agg_func in func.items():
agg_func_list = agg_func if utils.is_list_like(agg_func) else [agg_func]
col_id = self._block.resolve_label_exact(col_label)
if col_id is None:
raise KeyError(f"Column {col_label} does not exist")
for agg_func in agg_func_list:
agg_op = agg_ops.lookup_agg_func(typing.cast(str, agg_func))
agg_expr = (
ex.UnaryAggregation(agg_op, ex.deref(col_id))
if isinstance(agg_op, agg_ops.UnaryAggregateOp)
else ex.NullaryAggregation(agg_op)
)
aggs.append(agg_expr)
labels.append(col_label)
funcnames.append(agg_func)

# if any list in dict values, format output differently
if any(utils.is_list_like(v) for v in func.values()):
new_index, _ = self.columns.reindex(labels)
new_index = utils.combine_indices(new_index, pandas.Index(funcnames))
agg_block, _ = self._block.aggregate(
aggregations=aggs, column_labels=new_index
)
return DataFrame(agg_block).stack().droplevel(0, axis="index")
else:
new_index, _ = self.columns.reindex(labels)
agg_block, _ = self._block.aggregate(
aggregations=aggs, column_labels=new_index
)
return bigframes.series.Series(
agg_block.transpose(
single_row_mode=True, original_row_index=pandas.Index([None])
)
)
elif utils.is_list_like(func):
aggregations = [agg_ops.lookup_agg_func(f) for f in func]

Expand All @@ -3027,7 +3057,7 @@ def agg(
)
)

else:
else: # function name string
return bigframes.series.Series(
self._block.aggregate_all_and_stack(
agg_ops.lookup_agg_func(typing.cast(str, func))
Expand Down
34 changes: 33 additions & 1 deletion tests/system/small/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -5538,7 +5538,7 @@ def test_astype_invalid_type_fail(scalars_dfs):
bf_df.astype(123)


def test_agg_with_dict(scalars_dfs):
def test_agg_with_dict_lists(scalars_dfs):
bf_df, pd_df = scalars_dfs
agg_funcs = {
"int64_too": ["min", "max"],
Expand All @@ -5553,6 +5553,38 @@ def test_agg_with_dict(scalars_dfs):
)


def test_agg_with_dict_list_and_str(scalars_dfs):
    """Check ``agg`` with a dict whose values mix a list and a plain string.

    One column maps to a list of aggregation names and another to a single
    name; the materialized result is compared with what pandas returns for
    the same spec, ignoring dtype and index-type differences.
    """
    bigframes_df, pandas_df = scalars_dfs
    spec = dict(
        int64_too=["min", "max"],
        int64_col="sum",
    )

    actual = bigframes_df.agg(spec).to_pandas()
    expected = pandas_df.agg(spec)

    pd.testing.assert_frame_equal(
        actual,
        expected,
        check_dtype=False,
        check_index_type=False,
    )


def test_agg_with_dict_strs(scalars_dfs):
    """Check ``agg`` with a dict whose values are all single function names.

    The materialized result is compared as a Series against what pandas
    returns for the same spec, ignoring dtype and index-type differences.
    """
    bigframes_df, pandas_df = scalars_dfs
    spec = dict(
        int64_too="min",
        int64_col="sum",
        float64_col="max",
    )

    actual = bigframes_df.agg(spec).to_pandas()
    expected = pandas_df.agg(spec)
    # Cast the pandas index to a pyarrow-backed string index before comparing.
    # NOTE(review): presumably this mirrors the index dtype of the
    # materialized result -- confirm.
    expected.index = expected.index.astype("string[pyarrow]")

    pd.testing.assert_series_equal(
        actual,
        expected,
        check_dtype=False,
        check_index_type=False,
    )


def test_agg_with_dict_containing_non_existing_col_raise_key_error(scalars_dfs):
bf_df, _ = scalars_dfs
agg_funcs = {
Expand Down