Skip to content

Commit b503355

Browse files
feat: Add DataFrame.corrwith method (#1315)
1 parent dad522d commit b503355

File tree

5 files changed

+160
-8
lines changed

5 files changed

+160
-8
lines changed

‎bigframes/core/blocks.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2152,7 +2152,7 @@ def merge(
21522152

21532153
def _align_both_axes(
21542154
self, other: Block, how: str
2155-
) -> Tuple[Block, pd.Index, Sequence[Tuple[ex.Expression, ex.Expression]]]:
2155+
) -> Tuple[Block, pd.Index, Sequence[Tuple[ex.RefOrConstant, ex.RefOrConstant]]]:
21562156
# Join rows
21572157
aligned_block, (get_column_left, get_column_right) = self.join(other, how=how)
21582158
# join columns schema
@@ -2161,7 +2161,7 @@ def _align_both_axes(
21612161
columns, lcol_indexer, rcol_indexer = self.column_labels, None, None
21622162
else:
21632163
columns, lcol_indexer, rcol_indexer = self.column_labels.join(
2164-
other.column_labels, how="outer", return_indexers=True
2164+
other.column_labels, how=how, return_indexers=True
21652165
)
21662166
lcol_indexer = (
21672167
lcol_indexer if (lcol_indexer is not None) else range(len(columns))
@@ -2183,11 +2183,11 @@ def _align_both_axes(
21832183

21842184
left_inputs = [left_input_lookup(i) for i in lcol_indexer]
21852185
right_inputs = [righ_input_lookup(i) for i in rcol_indexer]
2186-
return aligned_block, columns, tuple(zip(left_inputs, right_inputs))
2186+
return aligned_block, columns, tuple(zip(left_inputs, right_inputs)) # type: ignore
21872187

21882188
def _align_axis_0(
21892189
self, other: Block, how: str
2190-
) -> Tuple[Block, pd.Index, Sequence[Tuple[ex.Expression, ex.Expression]]]:
2190+
) -> Tuple[Block, pd.Index, Sequence[Tuple[ex.DerefOp, ex.DerefOp]]]:
21912191
assert len(other.value_columns) == 1
21922192
aligned_block, (get_column_left, get_column_right) = self.join(other, how=how)
21932193

@@ -2203,7 +2203,7 @@ def _align_axis_0(
22032203

22042204
def _align_series_block_axis_1(
22052205
self, other: Block, how: str
2206-
) -> Tuple[Block, pd.Index, Sequence[Tuple[ex.Expression, ex.Expression]]]:
2206+
) -> Tuple[Block, pd.Index, Sequence[Tuple[ex.RefOrConstant, ex.RefOrConstant]]]:
22072207
assert len(other.value_columns) == 1
22082208
if other._transpose_cache is None:
22092209
raise ValueError(
@@ -2244,11 +2244,11 @@ def _align_series_block_axis_1(
22442244

22452245
left_inputs = [left_input_lookup(i) for i in lcol_indexer]
22462246
right_inputs = [righ_input_lookup(i) for i in rcol_indexer]
2247-
return aligned_block, columns, tuple(zip(left_inputs, right_inputs))
2247+
return aligned_block, columns, tuple(zip(left_inputs, right_inputs)) # type: ignore
22482248

22492249
def _align_pd_series_axis_1(
22502250
self, other: pd.Series, how: str
2251-
) -> Tuple[Block, pd.Index, Sequence[Tuple[ex.Expression, ex.Expression]]]:
2251+
) -> Tuple[Block, pd.Index, Sequence[Tuple[ex.RefOrConstant, ex.RefOrConstant]]]:
22522252
if self.column_labels.equals(other.index):
22532253
columns, lcol_indexer, rcol_indexer = self.column_labels, None, None
22542254
else:
@@ -2275,7 +2275,7 @@ def _align_pd_series_axis_1(
22752275

22762276
left_inputs = [left_input_lookup(i) for i in lcol_indexer]
22772277
right_inputs = [righ_input_lookup(i) for i in rcol_indexer]
2278-
return self, columns, tuple(zip(left_inputs, right_inputs))
2278+
return self, columns, tuple(zip(left_inputs, right_inputs)) # type: ignore
22792279

22802280
def _apply_binop(
22812281
self,

‎bigframes/core/expression.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,3 +420,6 @@ def deterministic(self) -> bool:
420420
return (
421421
all(input.deterministic for input in self.inputs) and self.op.deterministic
422422
)
423+
424+
425+
RefOrConstant = Union[DerefOp, ScalarConstantExpression]

‎bigframes/dataframe.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1473,6 +1473,48 @@ def cov(self, *, numeric_only: bool = False) -> DataFrame:
14731473

14741474
return result
14751475

1476+
def corrwith(
1477+
self,
1478+
other: typing.Union[DataFrame, bigframes.series.Series],
1479+
*,
1480+
numeric_only: bool = False,
1481+
):
1482+
other_frame = other if isinstance(other, DataFrame) else other.to_frame()
1483+
if numeric_only:
1484+
l_frame = self._drop_non_numeric()
1485+
r_frame = other_frame._drop_non_numeric()
1486+
else:
1487+
l_frame = self._raise_on_non_numeric("corrwith")
1488+
r_frame = other_frame._raise_on_non_numeric("corrwith")
1489+
1490+
l_block = l_frame.astype(bigframes.dtypes.FLOAT_DTYPE)._block
1491+
r_block = r_frame.astype(bigframes.dtypes.FLOAT_DTYPE)._block
1492+
1493+
if isinstance(other, DataFrame):
1494+
block, labels, expr_pairs = l_block._align_both_axes(r_block, how="inner")
1495+
else:
1496+
assert isinstance(other, bigframes.series.Series)
1497+
block, labels, expr_pairs = l_block._align_axis_0(r_block, how="inner")
1498+
1499+
na_cols = l_block.column_labels.join(
1500+
r_block.column_labels, how="outer"
1501+
).difference(labels)
1502+
1503+
block, _ = block.aggregate(
1504+
aggregations=tuple(
1505+
ex.BinaryAggregation(agg_ops.CorrOp(), left_ex, right_ex)
1506+
for left_ex, right_ex in expr_pairs
1507+
),
1508+
column_labels=labels,
1509+
)
1510+
block = block.project_exprs(
1511+
(ex.const(float("nan")),) * len(na_cols), labels=na_cols
1512+
)
1513+
block = block.transpose(
1514+
original_row_index=pandas.Index([None]), single_row_mode=True
1515+
)
1516+
return bigframes.pandas.Series(block)
1517+
14761518
def to_arrow(
14771519
self,
14781520
*,

‎tests/system/small/test_dataframe.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2246,6 +2246,72 @@ def test_cov_w_numeric_only(scalars_dfs_maybe_ordered, columns, numeric_only):
22462246
)
22472247

22482248

2249+
def test_df_corrwith_df(scalars_dfs_maybe_ordered):
2250+
scalars_df, scalars_pandas_df = scalars_dfs_maybe_ordered
2251+
2252+
l_cols = ["int64_col", "float64_col", "int64_too"]
2253+
r_cols = ["int64_too", "float64_col"]
2254+
2255+
bf_result = scalars_df[l_cols].corrwith(scalars_df[r_cols]).to_pandas()
2256+
pd_result = scalars_pandas_df[l_cols].corrwith(scalars_pandas_df[r_cols])
2257+
2258+
# BigFrames and Pandas differ in their data type handling:
2259+
# - Column types: BigFrames uses Float64, Pandas uses float64.
2260+
# - Index types: BigFrames uses string, Pandas uses object.
2261+
pd.testing.assert_series_equal(
2262+
bf_result, pd_result, check_dtype=False, check_index_type=False
2263+
)
2264+
2265+
2266+
def test_df_corrwith_df_numeric_only(scalars_dfs):
2267+
scalars_df, scalars_pandas_df = scalars_dfs
2268+
2269+
l_cols = ["int64_col", "float64_col", "int64_too", "string_col"]
2270+
r_cols = ["int64_too", "float64_col", "bool_col"]
2271+
2272+
bf_result = (
2273+
scalars_df[l_cols].corrwith(scalars_df[r_cols], numeric_only=True).to_pandas()
2274+
)
2275+
pd_result = scalars_pandas_df[l_cols].corrwith(
2276+
scalars_pandas_df[r_cols], numeric_only=True
2277+
)
2278+
2279+
# BigFrames and Pandas differ in their data type handling:
2280+
# - Column types: BigFrames uses Float64, Pandas uses float64.
2281+
# - Index types: BigFrames uses string, Pandas uses object.
2282+
pd.testing.assert_series_equal(
2283+
bf_result, pd_result, check_dtype=False, check_index_type=False
2284+
)
2285+
2286+
2287+
def test_df_corrwith_df_non_numeric_error(scalars_dfs):
2288+
scalars_df, _ = scalars_dfs
2289+
2290+
l_cols = ["int64_col", "float64_col", "int64_too", "string_col"]
2291+
r_cols = ["int64_too", "float64_col", "bool_col"]
2292+
2293+
with pytest.raises(NotImplementedError):
2294+
scalars_df[l_cols].corrwith(scalars_df[r_cols], numeric_only=False)
2295+
2296+
2297+
@skip_legacy_pandas
2298+
def test_df_corrwith_series(scalars_dfs_maybe_ordered):
2299+
scalars_df, scalars_pandas_df = scalars_dfs_maybe_ordered
2300+
2301+
l_cols = ["int64_col", "float64_col", "int64_too"]
2302+
r_col = "float64_col"
2303+
2304+
bf_result = scalars_df[l_cols].corrwith(scalars_df[r_col]).to_pandas()
2305+
pd_result = scalars_pandas_df[l_cols].corrwith(scalars_pandas_df[r_col])
2306+
2307+
# BigFrames and Pandas differ in their data type handling:
2308+
# - Column types: BigFrames uses Float64, Pandas uses float64.
2309+
# - Index types: BigFrames uses string, Pandas uses object.
2310+
pd.testing.assert_series_equal(
2311+
bf_result, pd_result, check_dtype=False, check_index_type=False
2312+
)
2313+
2314+
22492315
@pytest.mark.parametrize(
22502316
("op"),
22512317
[

‎third_party/bigframes_vendored/pandas/core/frame.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4146,6 +4146,47 @@ def cov(self, *, numeric_only) -> DataFrame:
41464146
"""
41474147
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
41484148

4149+
def corrwith(
4150+
self,
4151+
other,
4152+
*,
4153+
numeric_only: bool = False,
4154+
):
4155+
"""
4156+
Compute pairwise correlation.
4157+
4158+
Pairwise correlation is computed between rows or columns of
4159+
DataFrame with rows or columns of Series or DataFrame. DataFrames
4160+
are first aligned along both axes before computing the
4161+
correlations.
4162+
4163+
**Examples:**
4164+
>>> import bigframes.pandas as bpd
4165+
>>> bpd.options.display.progress_bar = None
4166+
4167+
>>> index = ["a", "b", "c", "d", "e"]
4168+
>>> columns = ["one", "two", "three", "four"]
4169+
>>> df1 = bpd.DataFrame(np.arange(20).reshape(5, 4), index=index, columns=columns)
4170+
>>> df2 = bpd.DataFrame(np.arange(16).reshape(4, 4), index=index[:4], columns=columns)
4171+
>>> df1.corrwith(df2)
4172+
one 1.0
4173+
two 1.0
4174+
three 1.0
4175+
four 1.0
4176+
dtype: Float64
4177+
4178+
Args:
4179+
other (DataFrame, Series):
4180+
Object with which to compute correlations.
4181+
4182+
numeric_only (bool, default False):
4183+
Include only `float`, `int` or `boolean` data.
4184+
4185+
Returns:
4186+
bigframes.pandas.Series: Pairwise correlations.
4187+
"""
4188+
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
4189+
41494190
def update(
41504191
self, other, join: str = "left", overwrite: bool = True, filter_func=None
41514192
) -> DataFrame:

0 commit comments

Comments
 (0)