feat: add ml.preprocessing.MaxAbsScaler #56

Merged: 6 commits, Sep 25, 2023
3 changes: 2 additions & 1 deletion bigframes/clients.py
@@ -18,7 +18,7 @@

import logging
import time
-from typing import Optional
+from typing import cast, Optional

import google.api_core.exceptions
from google.cloud import bigquery_connection_v1, resourcemanager_v3
@@ -80,6 +80,7 @@ def create_bq_connection(
        logger.info(
            f"Created BQ connection {connection_name} with service account id: {service_account_id}"
        )
+        service_account_id = cast(str, service_account_id)
        # Ensure IAM role on the BQ connection
        # https://cloud.google.com/bigquery/docs/reference/standard-sql/remote-functions#grant_permission_on_function
        self._ensure_iam_binding(project_id, service_account_id, iam_role)
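Note on the `cast` added above: `typing.cast` does nothing at runtime; it only tells the type checker that `service_account_id`, declared `Optional[str]`, is known to be a `str` once the connection exists. A minimal sketch of the pattern, with a hypothetical `lookup_service_account` standing in for the real lookup:

```python
from typing import Optional, cast

def lookup_service_account() -> Optional[str]:
    # Hypothetical stand-in: typed Optional even though a value is guaranteed here.
    return "bqcx-123@example.iam.gserviceaccount.com"

service_account_id = lookup_service_account()
# The surrounding logic guarantees a non-None value at this point.
service_account_id = cast(str, service_account_id)  # no-op at runtime, narrows for mypy
print(service_account_id.split("@")[0])  # safe: type is now plain str
```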
1 change: 1 addition & 0 deletions bigframes/ml/compose.py
@@ -29,6 +29,7 @@
 CompilablePreprocessorType = Union[
     preprocessing.OneHotEncoder,
     preprocessing.StandardScaler,
+    preprocessing.MaxAbsScaler,
     preprocessing.LabelEncoder,
 ]

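Adding `MaxAbsScaler` to `CompilablePreprocessorType` is what lets `ColumnTransformer` carry it as a step. A sketch of the resulting usage, assuming the `(name, transformer, columns)` tuple form that the pipeline tests below exercise:

```python
from bigframes.ml import compose, preprocessing

# Max-abs scale two numeric columns and one-hot encode a categorical one.
ct = compose.ColumnTransformer(
    [
        ("onehot", preprocessing.OneHotEncoder(), "species"),
        (
            "max_abs_scale",
            preprocessing.MaxAbsScaler(),
            ["culmen_length_mm", "flipper_length_mm"],
        ),
    ]
)
```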
10 changes: 10 additions & 0 deletions bigframes/ml/pipeline.py
@@ -50,6 +50,7 @@ def __init__(self, steps: List[Tuple[str, base.BaseEstimator]]):
                compose.ColumnTransformer,
                preprocessing.StandardScaler,
                preprocessing.OneHotEncoder,
+                preprocessing.MaxAbsScaler,
                preprocessing.LabelEncoder,
            ),
        ):
@@ -147,6 +148,7 @@ def _extract_as_column_transformer(
            Union[
                preprocessing.OneHotEncoder,
                preprocessing.StandardScaler,
+                preprocessing.MaxAbsScaler,
                preprocessing.LabelEncoder,
            ],
            Union[str, List[str]],
@@ -172,6 +174,13 @@
                    *preprocessing.OneHotEncoder._parse_from_sql(transform_sql),
                )
            )
+        elif transform_sql.startswith("ML.MAX_ABS_SCALER"):
+            transformers.append(
+                (
+                    "max_abs_encoder",
+                    *preprocessing.MaxAbsScaler._parse_from_sql(transform_sql),
+                )
+            )
        elif transform_sql.startswith("ML.LABEL_ENCODER"):
            transformers.append(
                (
@@ -193,6 +202,7 @@ def _merge_column_transformer(
        compose.ColumnTransformer,
        preprocessing.StandardScaler,
        preprocessing.OneHotEncoder,
+        preprocessing.MaxAbsScaler,
        preprocessing.LabelEncoder,
    ]:
        """Try to merge the column transformer to a simple transformer."""
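The new `elif` branch mirrors the existing ones: when a saved pipeline is loaded back, each expression in the model's TRANSFORM clause is dispatched on its SQL prefix and handed to the matching `_parse_from_sql`. A self-contained sketch of that dispatch, mirroring the parsing logic added in preprocessing.py below:

```python
# The column label is whatever sits between the first pair of parentheses.
def parse_col_label(sql: str) -> str:
    return sql[sql.find("(") + 1 : sql.find(")")]

transform_sql = "ML.MAX_ABS_SCALER(culmen_length_mm) OVER()"
if transform_sql.startswith("ML.MAX_ABS_SCALER"):
    print(parse_col_label(transform_sql))  # -> culmen_length_mm
```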
86 changes: 84 additions & 2 deletions bigframes/ml/preprocessing.py
@@ -54,8 +54,10 @@ def _compile_to_sql(self, columns: List[str]) -> List[Tuple[str, str]]:
        Returns: a list of tuples of (sql_expression, output_name)"""
        return [
            (
-                self._base_sql_generator.ml_standard_scaler(column, f"scaled_{column}"),
-                f"scaled_{column}",
+                self._base_sql_generator.ml_standard_scaler(
+                    column, f"standard_scaled_{column}"
+                ),
+                f"standard_scaled_{column}",
            )
            for column in columns
        ]
@@ -105,6 +107,86 @@ def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
        )


+class MaxAbsScaler(
+    base.Transformer,
+    third_party.bigframes_vendored.sklearn.preprocessing._data.MaxAbsScaler,
+):
+    __doc__ = (
+        third_party.bigframes_vendored.sklearn.preprocessing._data.MaxAbsScaler.__doc__
+    )
+
+    def __init__(self):
+        self._bqml_model: Optional[core.BqmlModel] = None
+        self._bqml_model_factory = globals.bqml_model_factory()
+        self._base_sql_generator = globals.base_sql_generator()
+
+    # TODO(garrettwu): implement __hash__
+    def __eq__(self, other: Any) -> bool:
+        return type(other) is MaxAbsScaler and self._bqml_model == other._bqml_model
+
+    def _compile_to_sql(self, columns: List[str]) -> List[Tuple[str, str]]:
+        """Compile this transformer to a list of SQL expressions that can be included in
+        a BQML TRANSFORM clause
+
+        Args:
+            columns: a list of column names to transform
+
+        Returns: a list of tuples of (sql_expression, output_name)"""
+        return [
+            (
+                self._base_sql_generator.ml_max_abs_scaler(
+                    column, f"max_abs_scaled_{column}"
+                ),
+                f"max_abs_scaled_{column}",
+            )
+            for column in columns
+        ]
+
+    @classmethod
+    def _parse_from_sql(cls, sql: str) -> tuple[MaxAbsScaler, str]:
+        """Parse SQL to tuple(MaxAbsScaler, column_label).
+
+        Args:
+            sql: SQL string of format "ML.MAX_ABS_SCALER({col_label}) OVER()"
+
+        Returns:
+            tuple(MaxAbsScaler, column_label)"""
+        col_label = sql[sql.find("(") + 1 : sql.find(")")]
+        return cls(), col_label
+
+    def fit(
+        self,
+        X: Union[bpd.DataFrame, bpd.Series],
+        y=None,  # ignored
+    ) -> MaxAbsScaler:
+        (X,) = utils.convert_to_dataframe(X)
+
+        compiled_transforms = self._compile_to_sql(X.columns.tolist())
+        transform_sqls = [transform_sql for transform_sql, _ in compiled_transforms]
+
+        self._bqml_model = self._bqml_model_factory.create_model(
+            X,
+            options={"model_type": "transform_only"},
+            transforms=transform_sqls,
+        )
+
+        # The schema of TRANSFORM output is not available in the model API, so save it during fitting
+        self._output_names = [name for _, name in compiled_transforms]
+        return self
+
+    def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
+        if not self._bqml_model:
+            raise RuntimeError("Must be fitted before transform")
+
+        (X,) = utils.convert_to_dataframe(X)
+
+        df = self._bqml_model.transform(X)
+        return typing.cast(
+            bpd.DataFrame,
+            df[self._output_names],
+        )
+
+
class OneHotEncoder(
    base.Transformer,
    third_party.bigframes_vendored.sklearn.preprocessing._encoder.OneHotEncoder,
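`MaxAbsScaler` divides each column by its maximum absolute value, mapping values into [-1, 1]; for example, a column [1, -2, 4] becomes [0.25, -0.5, 1.0]. The `__eq__` defined above is also what lets the pipeline tests compare reconstructed transformers against expected tuples directly. A usage sketch, assuming a configured BigQuery session and the public penguins table that the tests mirror:

```python
import bigframes.pandas as bpd
from bigframes.ml.preprocessing import MaxAbsScaler

df = bpd.read_gbq("bigquery-public-data.ml_datasets.penguins")
X = df[["culmen_length_mm", "flipper_length_mm"]].dropna()

scaler = MaxAbsScaler()
scaler.fit(X)              # creates a TRANSFORM-only BQML model behind the scenes
out = scaler.transform(X)  # output columns are renamed max_abs_scaled_<column>
```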
4 changes: 4 additions & 0 deletions bigframes/ml/sql.py
@@ -76,6 +76,10 @@ def ml_standard_scaler(self, numeric_expr_sql: str, name: str) -> str:
"""Encode ML.STANDARD_SCALER for BQML"""
return f"""ML.STANDARD_SCALER({numeric_expr_sql}) OVER() AS {name}"""

def ml_max_abs_scaler(self, numeric_expr_sql: str, name: str) -> str:
"""Encode ML.MAX_ABS_SCALER for BQML"""
return f"""ML.MAX_ABS_SCALER({numeric_expr_sql}) OVER() AS {name}"""

def ml_one_hot_encoder(
self,
numeric_expr_sql: str,
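The generator emits one window expression per column; these strings end up in the TRANSFORM clause of the CREATE MODEL statement. A quick check, assuming the generator is obtained the same way preprocessing.py does above:

```python
from bigframes.ml import globals

gen = globals.base_sql_generator()
print(gen.ml_max_abs_scaler("culmen_length_mm", "max_abs_scaled_culmen_length_mm"))
# ML.MAX_ABS_SCALER(culmen_length_mm) OVER() AS max_abs_scaled_culmen_length_mm
```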
24 changes: 12 additions & 12 deletions tests/system/large/ml/test_compose.py
@@ -56,20 +56,20 @@ def test_columntransformer_standalone_fit_and_transform(
[{"index": 1, "value": 1.0}],
[{"index": 2, "value": 1.0}],
],
"scaled_culmen_length_mm": [
"standard_scaled_culmen_length_mm": [
-0.811119671289163,
-0.9945520581113803,
-1.104611490204711,
],
"scaled_flipper_length_mm": [-0.350044, -1.418336, -0.9198],
"standard_scaled_flipper_length_mm": [-0.350044, -1.418336, -0.9198],
},
index=pandas.Index([1633, 1672, 1690], dtype="Int64", name="tag_number"),
)
expected.scaled_culmen_length_mm = expected.scaled_culmen_length_mm.astype(
"Float64"
expected.standard_scaled_culmen_length_mm = (
expected.standard_scaled_culmen_length_mm.astype("Float64")
)
expected.scaled_flipper_length_mm = expected.scaled_flipper_length_mm.astype(
"Float64"
expected.standard_scaled_flipper_length_mm = (
expected.standard_scaled_flipper_length_mm.astype("Float64")
)

pandas.testing.assert_frame_equal(result, expected, rtol=1e-3)
@@ -107,20 +107,20 @@ def test_columntransformer_standalone_fit_transform(new_penguins_df):
[{"index": 1, "value": 1.0}],
[{"index": 2, "value": 1.0}],
],
"scaled_culmen_length_mm": [
"standard_scaled_culmen_length_mm": [
1.313249,
-0.20198,
-1.111118,
],
"scaled_flipper_length_mm": [1.251098, -1.196588, -0.054338],
"standard_scaled_flipper_length_mm": [1.251098, -1.196588, -0.054338],
},
index=pandas.Index([1633, 1672, 1690], dtype="Int64", name="tag_number"),
)
expected.scaled_culmen_length_mm = expected.scaled_culmen_length_mm.astype(
"Float64"
expected.standard_scaled_culmen_length_mm = (
expected.standard_scaled_culmen_length_mm.astype("Float64")
)
expected.scaled_flipper_length_mm = expected.scaled_flipper_length_mm.astype(
"Float64"
expected.standard_scaled_flipper_length_mm = (
expected.standard_scaled_flipper_length_mm.astype("Float64")
)

pandas.testing.assert_frame_equal(result, expected, rtol=1e-3)
82 changes: 58 additions & 24 deletions tests/system/large/ml/test_pipeline.py
@@ -566,10 +566,15 @@ def test_pipeline_columntransformer_fit_predict(session, penguins_df_default_index):
"species",
),
(
"scale",
"standard_scale",
preprocessing.StandardScaler(),
["culmen_length_mm", "flipper_length_mm"],
),
(
"max_abs_scale",
preprocessing.MaxAbsScaler(),
["culmen_length_mm", "flipper_length_mm"],
),
(
"label",
preprocessing.LabelEncoder(),
@@ -637,6 +642,11 @@ def test_pipeline_columntransformer_to_gbq(penguins_df_default_index, dataset_id):
                preprocessing.StandardScaler(),
                ["culmen_length_mm", "flipper_length_mm"],
            ),
+            (
+                "max_abs_scale",
+                preprocessing.MaxAbsScaler(),
+                ["culmen_length_mm", "flipper_length_mm"],
+            ),
            (
                "label",
                preprocessing.LabelEncoder(),
@@ -660,30 +670,26 @@

    assert isinstance(pl_loaded._transform, compose.ColumnTransformer)
    transformers = pl_loaded._transform.transformers_
-    assert len(transformers) == 4
-
-    assert transformers[0][0] == "ont_hot_encoder"
-    assert isinstance(transformers[0][1], preprocessing.OneHotEncoder)
-    one_hot_encoder = transformers[0][1]
-    assert one_hot_encoder.drop == "most_frequent"
-    assert one_hot_encoder.min_frequency == 5
-    assert one_hot_encoder.max_categories == 100
-    assert transformers[0][2] == "species"
-
-    assert transformers[1][0] == "label_encoder"
-    assert isinstance(transformers[1][1], preprocessing.LabelEncoder)
-    one_hot_encoder = transformers[1][1]
-    assert one_hot_encoder.min_frequency == 0
-    assert one_hot_encoder.max_categories == 1000001
-    assert transformers[1][2] == "species"
-
-    assert transformers[2][0] == "standard_scaler"
-    assert isinstance(transformers[2][1], preprocessing.StandardScaler)
-    assert transformers[2][2] == "culmen_length_mm"
-
-    assert transformers[3][0] == "standard_scaler"
-    assert isinstance(transformers[2][1], preprocessing.StandardScaler)
-    assert transformers[3][2] == "flipper_length_mm"
+    expected = [
+        (
+            "ont_hot_encoder",
+            preprocessing.OneHotEncoder(
+                drop="most_frequent", max_categories=100, min_frequency=5
+            ),
+            "species",
+        ),
+        (
+            "label_encoder",
+            preprocessing.LabelEncoder(max_categories=1000001, min_frequency=0),
+            "species",
+        ),
+        ("standard_scaler", preprocessing.StandardScaler(), "culmen_length_mm"),
+        ("max_abs_encoder", preprocessing.MaxAbsScaler(), "culmen_length_mm"),
+        ("standard_scaler", preprocessing.StandardScaler(), "flipper_length_mm"),
+        ("max_abs_encoder", preprocessing.MaxAbsScaler(), "flipper_length_mm"),
+    ]
+
+    assert transformers == expected

    assert isinstance(pl_loaded._estimator, linear_model.LinearRegression)
    assert pl_loaded._estimator.fit_intercept is False
@@ -717,6 +723,34 @@ def test_pipeline_standard_scaler_to_gbq(penguins_df_default_index, dataset_id):
    assert pl_loaded._estimator.fit_intercept is False


+def test_pipeline_max_abs_scaler_to_gbq(penguins_df_default_index, dataset_id):
+    pl = pipeline.Pipeline(
+        [
+            ("transform", preprocessing.MaxAbsScaler()),
+            ("estimator", linear_model.LinearRegression(fit_intercept=False)),
+        ]
+    )
+
+    df = penguins_df_default_index.dropna()
+    X_train = df[
+        [
+            "culmen_length_mm",
+            "culmen_depth_mm",
+            "flipper_length_mm",
+        ]
+    ]
+    y_train = df[["body_mass_g"]]
+    pl.fit(X_train, y_train)
+
+    pl_loaded = pl.to_gbq(
+        f"{dataset_id}.test_penguins_pipeline_max_abs_scaler", replace=True
+    )
+    assert isinstance(pl_loaded._transform, preprocessing.MaxAbsScaler)
+
+    assert isinstance(pl_loaded._estimator, linear_model.LinearRegression)
+    assert pl_loaded._estimator.fit_intercept is False

def test_pipeline_one_hot_encoder_to_gbq(penguins_df_default_index, dataset_id):
    pl = pipeline.Pipeline(
        [