Skip to content

Commit 5934f8e

Browse files
feat: Support DataFrame.astype(dict) (#1262)
* [WIP] Support dict dtypes for df.astype()
* feat: Support DataFrame.astype(dict)
* add test cases for failure
* remove test notebook
* re-write type-checking logic to make mypy happy
* fix format
* 🦉 Updates from OwlBot post-processor

  See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md
* use reflection in dtypes module, and update error type

---------

Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
1 parent 1b40a11 commit 5934f8e

File tree

3 files changed

+58
-3
lines changed

3 files changed

+58
-3
lines changed

‎bigframes/dataframe.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -367,14 +367,35 @@ def __iter__(self):
367367

368368
def astype(
369369
self,
370-
dtype: Union[bigframes.dtypes.DtypeString, bigframes.dtypes.Dtype],
370+
dtype: Union[
371+
bigframes.dtypes.DtypeString,
372+
bigframes.dtypes.Dtype,
373+
dict[str, Union[bigframes.dtypes.DtypeString, bigframes.dtypes.Dtype]],
374+
],
371375
*,
372376
errors: Literal["raise", "null"] = "raise",
373377
) -> DataFrame:
374378
if errors not in ["raise", "null"]:
375379
raise ValueError("Arg 'error' must be one of 'raise' or 'null'")
376-
return self._apply_unary_op(
377-
ops.AsTypeOp(to_type=dtype, safe=(errors == "null"))
380+
381+
safe_cast = errors == "null"
382+
383+
# Type strings check
384+
if dtype in bigframes.dtypes.DTYPE_STRINGS:
385+
return self._apply_unary_op(ops.AsTypeOp(dtype, safe_cast))
386+
387+
# Type instances check
388+
if type(dtype) in bigframes.dtypes.DTYPES:
389+
return self._apply_unary_op(ops.AsTypeOp(dtype, safe_cast))
390+
391+
if isinstance(dtype, dict):
392+
result = self.copy()
393+
for col, to_type in dtype.items():
394+
result[col] = result[col].astype(to_type)
395+
return result
396+
397+
raise TypeError(
398+
f"Invalid type {type(dtype)} for dtype input. {constants.FEEDBACK_LINK}"
378399
)
379400

380401
def _to_sql_query(

‎bigframes/dtypes.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
pd.ArrowDtype,
3737
gpd.array.GeometryDtype,
3838
]
39+
40+
DTYPES = typing.get_args(Dtype)
3941
# Represents both column types (dtypes) and local-only types
4042
# None represents the type of a None scalar.
4143
ExpressionType = typing.Optional[Dtype]
@@ -238,6 +240,8 @@ class SimpleDtypeInfo:
238240
"binary[pyarrow]",
239241
]
240242

243+
DTYPE_STRINGS = typing.get_args(DtypeString)
244+
241245
BOOL_BIGFRAMES_TYPES = [BOOL_DTYPE]
242246

243247
# Corresponds to the pandas concept of numeric type (such as when 'numeric_only' is specified in an operation)

‎tests/system/small/test_dataframe.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5199,3 +5199,33 @@ def test__resample_start_time(rule, origin, data):
51995199
pd.testing.assert_frame_equal(
52005200
bf_result, pd_result, check_dtype=False, check_index_type=False
52015201
)
5202+
5203+
5204+
@pytest.mark.parametrize(
5205+
"dtype",
5206+
[
5207+
pytest.param("string[pyarrow]", id="type-string"),
5208+
pytest.param(pd.StringDtype(storage="pyarrow"), id="type-literal"),
5209+
pytest.param(
5210+
{"bool_col": "string[pyarrow]", "int64_col": pd.Float64Dtype()},
5211+
id="multiple-types",
5212+
),
5213+
],
5214+
)
5215+
def test_astype(scalars_dfs, dtype):
5216+
bf_df, pd_df = scalars_dfs
5217+
target_cols = ["bool_col", "int64_col"]
5218+
bf_df = bf_df[target_cols]
5219+
pd_df = pd_df[target_cols]
5220+
5221+
bf_result = bf_df.astype(dtype).to_pandas()
5222+
pd_result = pd_df.astype(dtype)
5223+
5224+
pd.testing.assert_frame_equal(bf_result, pd_result, check_index_type=False)
5225+
5226+
5227+
def test_astype_invalid_type_fail(scalars_dfs):
5228+
bf_df, _ = scalars_dfs
5229+
5230+
with pytest.raises(TypeError, match=r".*Share your usecase with.*"):
5231+
bf_df.astype(123)

0 commit comments

Comments
 (0)