Skip to content

feat: Support DataFrame.astype(dict) #1262

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jan 7, 2025
27 changes: 24 additions & 3 deletions bigframes/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,14 +367,35 @@ def __iter__(self):

def astype(
self,
dtype: Union[bigframes.dtypes.DtypeString, bigframes.dtypes.Dtype],
dtype: Union[
bigframes.dtypes.DtypeString,
bigframes.dtypes.Dtype,
dict[str, Union[bigframes.dtypes.DtypeString, bigframes.dtypes.Dtype]],
],
*,
errors: Literal["raise", "null"] = "raise",
) -> DataFrame:
if errors not in ["raise", "null"]:
raise ValueError("Arg 'error' must be one of 'raise' or 'null'")
return self._apply_unary_op(
ops.AsTypeOp(to_type=dtype, safe=(errors == "null"))

safe_cast = errors == "null"

# Type strings check
if dtype in bigframes.dtypes.DTYPE_STRINGS:
return self._apply_unary_op(ops.AsTypeOp(dtype, safe_cast))

# Type instances check
if type(dtype) in bigframes.dtypes.DTYPES:
return self._apply_unary_op(ops.AsTypeOp(dtype, safe_cast))

if isinstance(dtype, dict):
result = self.copy()
for col, to_type in dtype.items():
result[col] = result[col].astype(to_type)
return result

raise TypeError(
f"Invalid type {type(dtype)} for dtype input. {constants.FEEDBACK_LINK}"
)

def _to_sql_query(
Expand Down
4 changes: 4 additions & 0 deletions bigframes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
pd.ArrowDtype,
gpd.array.GeometryDtype,
]

DTYPES = typing.get_args(Dtype)
# Represents both column types (dtypes) and local-only types
# None represents the type of a None scalar.
ExpressionType = typing.Optional[Dtype]
Expand Down Expand Up @@ -238,6 +240,8 @@ class SimpleDtypeInfo:
"binary[pyarrow]",
]

DTYPE_STRINGS = typing.get_args(DtypeString)

BOOL_BIGFRAMES_TYPES = [BOOL_DTYPE]

# Corresponds to the pandas concept of numeric type (such as when 'numeric_only' is specified in an operation)
Expand Down
30 changes: 30 additions & 0 deletions tests/system/small/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -5199,3 +5199,33 @@ def test__resample_start_time(rule, origin, data):
pd.testing.assert_frame_equal(
bf_result, pd_result, check_dtype=False, check_index_type=False
)


@pytest.mark.parametrize(
"dtype",
[
pytest.param("string[pyarrow]", id="type-string"),
pytest.param(pd.StringDtype(storage="pyarrow"), id="type-literal"),
pytest.param(
{"bool_col": "string[pyarrow]", "int64_col": pd.Float64Dtype()},
id="multiple-types",
),
],
)
def test_astype(scalars_dfs, dtype):
bf_df, pd_df = scalars_dfs
target_cols = ["bool_col", "int64_col"]
bf_df = bf_df[target_cols]
pd_df = pd_df[target_cols]

bf_result = bf_df.astype(dtype).to_pandas()
pd_result = pd_df.astype(dtype)

pd.testing.assert_frame_equal(bf_result, pd_result, check_index_type=False)


def test_astype_invalid_type_fail(scalars_dfs):
bf_df, _ = scalars_dfs

with pytest.raises(TypeError, match=r".*Share your usecase with.*"):
bf_df.astype(123)
Loading