
feat: Support axis=1 in df.apply for scalar outputs #629


Merged May 10, 2024 (47 commits)

Commits
4d3200c
feat: Support `axis=1` in `df.apply` for scalar outputs
shobsi Apr 22, 2024
ceb7dbc
Merge remote-tracking branch 'refs/remotes/github/main' into shobs-rf…
shobsi Apr 22, 2024
22b7f32
avoid mixing other changes in the input_types param
shobsi Apr 22, 2024
5049170
use guid instead of hard coded column name
shobsi Apr 23, 2024
4f28f56
check_exact=False to avoid failing system_prerelease
shobsi Apr 23, 2024
fb92f34
Merge remote-tracking branch 'refs/remotes/github/main' into shobs-rf…
shobsi Apr 23, 2024
2e2f32c
Merge remote-tracking branch 'refs/remotes/github/main' into shobs-rf…
shobsi Apr 24, 2024
99e3e7b
handle index in remote function, add large system tests
shobsi Apr 24, 2024
85efbb7
Merge remote-tracking branch 'refs/remotes/github/main' into shobs-rf…
shobsi Apr 25, 2024
7153db8
make the test case more robust
shobsi Apr 25, 2024
13c2e62
Merge remote-tracking branch 'refs/remotes/github/main' into shobs-rf…
shobsi Apr 26, 2024
c3dddd8
Merge remote-tracking branch 'refs/remotes/github/main' into shobs-rf…
shobsi Apr 27, 2024
5fb8148
handle non-string column names, add unsupported dtype tests
shobsi Apr 29, 2024
edbac1b
fix import
shobsi Apr 29, 2024
74aeaea
Merge remote-tracking branch 'refs/remotes/github/main' into shobs-rf…
shobsi Apr 29, 2024
d3c07e9
use `_cached` in df.apply to catch any rf execution errors early
shobsi Apr 29, 2024
ed03d28
Merge remote-tracking branch 'refs/remotes/github/main' into shobs-rf…
shobsi Apr 29, 2024
7122b8a
add test for row aggregates
shobsi Apr 29, 2024
8957e10
Merge remote-tracking branch 'refs/remotes/github/main' into shobs-rf…
shobsi Apr 30, 2024
9f9b61e
add row dtype information, also test
shobsi Apr 30, 2024
6fdd282
preserve the order of input in the output
shobsi May 1, 2024
37906ee
Merge remote-tracking branch 'refs/remotes/github/main' into shobs-rf…
shobsi May 1, 2024
2d137ca
absorb to_numpy() disparity in prerelease tests
shobsi May 1, 2024
3e45f78
add tests for column multiindex and non remote function
shobsi May 2, 2024
e31a09d
add preview note for row processing
shobsi May 2, 2024
b828860
add warning for input_types="row" and axis=1
shobsi May 2, 2024
eb383f3
introduce early check on the supported dtypes
shobsi May 2, 2024
d520337
Merge remote-tracking branch 'refs/remotes/github/main' into shobs-rf…
shobsi May 2, 2024
7a3aa5f
adjust test after early dtype handling
shobsi May 2, 2024
4d39204
Merge remote-tracking branch 'refs/remotes/github/main' into shobs-rf…
shobsi May 2, 2024
7383faf
address review comments
shobsi May 4, 2024
a8f036a
Merge remote-tracking branch 'refs/remotes/github/main'
shobsi May 4, 2024
84d719c
use NameError for column name parsing issue, address test coverage f…
shobsi May 4, 2024
4e96b96
address nan return handling in the gcf code
shobsi May 7, 2024
4ce3cc9
Merge remote-tracking branch 'refs/remotes/github/main' into shobs-rf…
shobsi May 7, 2024
612055d
handle (nan, inf, -inf)
shobsi May 7, 2024
1c58ded
replace "row" by bpd.Series for input types
shobsi May 7, 2024
bede078
make the bq parity assert more readable
shobsi May 7, 2024
56a8236
Merge remote-tracking branch 'refs/remotes/github/main' into shobs-rf…
shobsi May 7, 2024
3409bc3
fix the series name before assert
shobsi May 8, 2024
ea7e28e
fix docstring for args
shobsi May 8, 2024
14602f8
move more low level string logic in sql module
shobsi May 8, 2024
7f5f2a3
Merge remote-tracking branch 'refs/remotes/github/main' into shobs-rf…
shobsi May 8, 2024
3bf5bee
raise explicit error when a column name cannot be supported
shobsi May 10, 2024
0149d59
Merge remote-tracking branch 'refs/remotes/github/main' into shobs-rf…
shobsi May 10, 2024
b5f3232
keep literal_eval check on the serialization side to match
shobsi May 10, 2024
bad6df6
Merge remote-tracking branch 'refs/remotes/github/main' into shobs-rf…
shobsi May 10, 2024
101 changes: 94 additions & 7 deletions bigframes/core/blocks.py
@@ -21,11 +21,13 @@

from __future__ import annotations

import ast
import dataclasses
import functools
import itertools
import os
import random
import textwrap
import typing
from typing import Iterable, List, Literal, Mapping, Optional, Sequence, Tuple, Union
import warnings
@@ -44,8 +46,8 @@
import bigframes.core.join_def as join_defs
import bigframes.core.ordering as ordering
import bigframes.core.schema as bf_schema
import bigframes.core.sql as sql
import bigframes.core.tree_properties as tree_properties
import bigframes.core.utils
import bigframes.core.utils as utils
import bigframes.core.window_spec as window_specs
import bigframes.dtypes
@@ -1437,9 +1439,7 @@ def promote_offsets(self, label: Label = None) -> typing.Tuple[Block, str]:
)

def add_prefix(self, prefix: str, axis: str | int | None = None) -> Block:
axis_number = bigframes.core.utils.get_axis_number(
"rows" if (axis is None) else axis
)
axis_number = utils.get_axis_number("rows" if (axis is None) else axis)
if axis_number == 0:
expr = self._expr
for index_col in self._index_columns:
@@ -1460,9 +1460,7 @@ def add_prefix(self, prefix: str, axis: str | int | None = None) -> Block:
return self.rename(columns=lambda label: f"{prefix}{label}")

def add_suffix(self, suffix: str, axis: str | int | None = None) -> Block:
axis_number = bigframes.core.utils.get_axis_number(
"rows" if (axis is None) else axis
)
axis_number = utils.get_axis_number("rows" if (axis is None) else axis)
if axis_number == 0:
expr = self._expr
for index_col in self._index_columns:
@@ -2072,6 +2070,95 @@ def _is_monotonic(
self._stats_cache[column_name].update({op_name: result})
return result

def _get_rows_as_json_values(self) -> Block:
# We want to preserve any ordering currently present before turning to
# direct SQL manipulation. We will restore the ordering when we rebuild
# expression.
# TODO(shobs): Replace direct SQL manipulation by structured expression
# manipulation
ordering_column_name = guid.generate_guid()
expr = self.session._cache_with_offsets(self.expr)
expr = expr.promote_offsets(ordering_column_name)
expr_sql = self.session._to_sql(expr)

# Names of the columns to serialize for the row.
# We will use the repr-eval pattern to serialize a value here and
# deserialize in the cloud function. Let's make sure that would work.
column_names = []
for col in list(self.index_columns) + [col for col in self.column_labels]:
serialized_column_name = repr(col)
try:
ast.literal_eval(serialized_column_name)
except Exception:
raise NameError(
f"Column name type '{type(col).__name__}' is not supported for row serialization."
" Please consider using a name for which literal_eval(repr(name)) works."
)

column_names.append(serialized_column_name)
column_names_csv = sql.csv(column_names, quoted=True)

# index columns count
index_columns_count = len(self.index_columns)

# column references to form the array of values for the row
column_references_csv = sql.csv(
[sql.cast_as_string(col) for col in self.expr.column_ids]
)

# types of the columns to serialize for the row
column_types = list(self.index.dtypes) + list(self.dtypes)
column_types_csv = sql.csv([str(typ) for typ in column_types], quoted=True)

# row dtype to use for deserializing the row as pandas series
pandas_row_dtype = bigframes.dtypes.lcd_type(*column_types)
if pandas_row_dtype is None:
pandas_row_dtype = "object"
pandas_row_dtype = sql.quote(str(pandas_row_dtype))

# create a json column representing row through SQL manipulation
row_json_column_name = guid.generate_guid()
select_columns = (
[ordering_column_name] + list(self.index_columns) + [row_json_column_name]
)
select_columns_csv = sql.csv(
[sql.column_reference(col) for col in select_columns]
)
json_sql = f"""\
With T0 AS (
{textwrap.indent(expr_sql, " ")}
),
T1 AS (
SELECT *,
JSON_OBJECT(
"names", [{column_names_csv}],
"types", [{column_types_csv}],
"values", [{column_references_csv}],
"indexlength", {index_columns_count},
"dtype", {pandas_row_dtype}
) AS {row_json_column_name} FROM T0
)
SELECT {select_columns_csv} FROM T1
"""
ibis_table = self.session.ibis_client.sql(json_sql)
order_for_ibis_table = ordering.ExpressionOrdering.from_offset_col(
ordering_column_name
)
expr = core.ArrayValue.from_ibis(
self.session,
ibis_table,
[ibis_table[col] for col in select_columns if col != ordering_column_name],
hidden_ordering_columns=[ibis_table[ordering_column_name]],
ordering=order_for_ibis_table,
)
block = Block(
expr,
index_columns=self.index_columns,
column_labels=[row_json_column_name],
index_labels=self._index_labels,
)
return block


class BlockIndexProperties:
"""Accessor for the index-related block properties."""
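The serialization contract in `_get_rows_as_json_values` can be sketched in isolation: column names go through a `repr`/`literal_eval` round-trip, and the row itself travels as a JSON object with `names`, `types`, `values`, `indexlength`, and `dtype` keys. The payload below is illustrative only, not the exact wire format produced by the `JSON_OBJECT` SQL:

```python
import ast
import json

# Column names are serialized with repr(); the receiving side recovers them
# with ast.literal_eval(). Names failing that round-trip are rejected early.
def serialize_name(col):
    serialized = repr(col)
    try:
        ast.literal_eval(serialized)
    except Exception:
        raise NameError(
            f"Column name type '{type(col).__name__}' is not supported"
            " for row serialization."
        )
    return serialized

# Illustrative payload mirroring the JSON_OBJECT built in the SQL above:
row_json = json.dumps({
    "names": [serialize_name("idx"), serialize_name("a"), serialize_name(1)],
    "types": ["Int64", "Int64", "Float64"],
    "values": ["0", "10", "2.5"],  # each value is CAST(... AS STRING)
    "indexlength": 1,
    "dtype": "Float64",
})

# Deserialization side (a sketch of what the generated cloud function might do):
payload = json.loads(row_json)
names = [ast.literal_eval(name) for name in payload["names"]]
n_index = payload["indexlength"]
data_names = names[n_index:]                  # non-index labels
data_values = payload["values"][n_index:]     # non-index values, as strings
```

Note how a non-literal name (e.g. an arbitrary object) fails `literal_eval` and is refused up front, matching the explicit `NameError` raised in the block code.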
59 changes: 59 additions & 0 deletions bigframes/core/sql.py
@@ -0,0 +1,59 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Utility functions for SQL construction.
"""

from typing import Iterable


def quote(value: str):
"""Return quoted input string."""

# Let's use repr which also escapes any special characters
#
# >>> for val in [
# ... "123",
# ... "str with no special chars",
# ... "str with special chars.,'\"/\\"
# ... ]:
# ... print(f"{val} -> {repr(val)}")
# ...
# 123 -> '123'
# str with no special chars -> 'str with no special chars'
# str with special chars.,'"/\ -> 'str with special chars.,\'"/\\'

return repr(value)


def column_reference(column_name: str):
"""Return a string representing column reference in a SQL."""

return f"`{column_name}`"


def cast_as_string(column_name: str):
"""Return a string representing string casting of a column."""

return f"CAST({column_reference(column_name)} AS STRING)"


def csv(values: Iterable[str], quoted=False):
"""Return a string of comma separated values."""

if quoted:
values = [quote(val) for val in values]

return ", ".join(values)
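For illustration, the four helpers compose as follows; they are re-declared here verbatim from the diff so the snippet is self-contained:

```python
from typing import Iterable

def quote(value: str):
    """Return the input string quoted (and escaped) via repr."""
    return repr(value)

def column_reference(column_name: str):
    """Return a backtick-quoted SQL column reference."""
    return f"`{column_name}`"

def cast_as_string(column_name: str):
    """Return SQL casting a column to STRING."""
    return f"CAST({column_reference(column_name)} AS STRING)"

def csv(values: Iterable[str], quoted=False):
    """Return a comma-separated string, optionally quoting each value."""
    if quoted:
        values = [quote(val) for val in values]
    return ", ".join(values)

# Quoted names feed the "names"/"types" arrays; cast references feed "values".
names_csv = csv(["a", "b"], quoted=True)
refs_csv = csv(cast_as_string(c) for c in ("x", "y"))
```

Using `repr` for quoting means special characters (quotes, backslashes) are escaped for free, at the cost of producing Python-style rather than SQL-standard string literals.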
56 changes: 55 additions & 1 deletion bigframes/dataframe.py
@@ -34,6 +34,7 @@
Tuple,
Union,
)
import warnings

import bigframes_vendored.pandas.core.frame as vendored_pandas_frame
import bigframes_vendored.pandas.pandas._typing as vendored_pandas_typing
Expand Down Expand Up @@ -61,6 +62,7 @@
import bigframes.core.window
import bigframes.core.window_spec as window_spec
import bigframes.dtypes
import bigframes.exceptions
import bigframes.formatting_helpers as formatter
import bigframes.operations as ops
import bigframes.operations.aggregations as agg_ops
@@ -3308,7 +3310,59 @@ def map(self, func, na_action: Optional[str] = None) -> DataFrame:
ops.RemoteFunctionOp(func=func, apply_on_null=(na_action is None))
)

def apply(self, func, *, args: typing.Tuple = (), **kwargs):
def apply(self, func, *, axis=0, args: typing.Tuple = (), **kwargs):
if utils.get_axis_number(axis) == 1:
warnings.warn(
"axis=1 scenario is in preview.",
category=bigframes.exceptions.PreviewWarning,
)

# Early check whether the dataframe dtypes are currently supported
# in the remote function
# NOTE: Keep in sync with the value converters used in the gcf code
# generated in generate_cloud_function_main_code in remote_function.py
remote_function_supported_dtypes = (
Collaborator: Ideally we'd document these limitations, too, right?

Contributor Author: Adding in #800


bigframes.dtypes.INT_DTYPE,
bigframes.dtypes.FLOAT_DTYPE,
bigframes.dtypes.BOOL_DTYPE,
bigframes.dtypes.STRING_DTYPE,
)
supported_dtypes_types = tuple(
type(dtype) for dtype in remote_function_supported_dtypes
)
supported_dtypes_hints = tuple(
str(dtype) for dtype in remote_function_supported_dtypes
)

for dtype in self.dtypes:
if not isinstance(dtype, supported_dtypes_types):
raise NotImplementedError(
f"DataFrame has a column of dtype '{dtype}' which is not supported with axis=1."
f" Supported dtypes are {supported_dtypes_hints}."
)

# Check if the function is a remote function
if not hasattr(func, "bigframes_remote_function"):
raise ValueError("For axis=1 a remote function must be used.")

# Serialize the rows as json values
block = self._get_block()
rows_as_json_series = bigframes.series.Series(
block._get_rows_as_json_values()
)

# Apply the function
result_series = rows_as_json_series._apply_unary_op(
ops.RemoteFunctionOp(func=func, apply_on_null=True)
)
result_series.name = None

# Return Series with materialized result so that any error in the remote
# function is caught early
materialized_series = result_series.cache()
return materialized_series

# Per-column apply
results = {name: func(col, *args, **kwargs) for name, col in self.items()}
if all(
[
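The early dtype gate added to `apply` can be sketched independently of BigQuery DataFrames. The sketch below compares dtype names as strings rather than with `isinstance` as the real code does, and the dtype spellings are assumptions based on the `Int64`/`Float64`/`boolean`/`string` extension dtypes:

```python
# Hypothetical stand-in for the early dtype check in DataFrame.apply(axis=1).
SUPPORTED_DTYPES = ("Int64", "Float64", "boolean", "string")

def check_axis1_dtypes(dtypes):
    """Raise early if any column dtype is unsupported for axis=1 apply."""
    for dtype in dtypes:
        if str(dtype) not in SUPPORTED_DTYPES:
            raise NotImplementedError(
                f"DataFrame has a column of dtype '{dtype}' which is not"
                f" supported with axis=1. Supported dtypes are"
                f" {SUPPORTED_DTYPES}."
            )

check_axis1_dtypes(["Int64", "string"])  # passes silently
```

Failing fast here avoids deploying a cloud function that would only error at execution time; the note in the diff says this list must stay in sync with the value converters generated in `remote_function.py`.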
4 changes: 4 additions & 0 deletions bigframes/exceptions.py
@@ -33,3 +33,7 @@ class CleanupFailedWarning(Warning):

class DefaultIndexWarning(Warning):
"""Default index may cause unexpected costs."""


class PreviewWarning(Warning):
"""The feature is in preview."""
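Because `PreviewWarning` subclasses `Warning`, the standard `warnings` filters apply to it; a small self-contained sketch (the class is re-declared here so the snippet runs standalone):

```python
import warnings

# Stand-in matching the PreviewWarning added in bigframes/exceptions.py.
class PreviewWarning(Warning):
    """The feature is in preview."""

# Capture (or, via filterwarnings, silence) the preview notice by category:
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warnings.warn(
        "axis=1 scenario is in preview.", category=PreviewWarning
    )
```

Users who accept the preview status can suppress just this category with `warnings.filterwarnings("ignore", category=PreviewWarning)` without muting other warnings.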