
Commit 1930b4e

Authored by ferenc-hechler (Ferenc Hechler), with tswast and gcf-owl-bot[bot]
feat: add bigframes.ml.compose.SQLScalarColumnTransformer to create custom SQL-based transformations (#955)
* Add support for custom transformers (not ML.) in ColumnTransformer.
* Allow numbers in custom transformer IDs.
* Move the comment to the end of the SQL.
* Do not offer the feedback link for missing custom transformers.
* Clean up typing hints.
* Add unit tests for CustomTransformer.
* Add unit tests for _extract_output_names() and _compile_to_sql().
* Run black and flake8 linters.
* Fix wrong @classmethod annotation.
* On the way to SQLScalarColumnTransformer.
* Remove pytest.main call.
* Remove CustomTransformer class and implementations.
* Fix typing.
* Fix mock typing.
* Replace _NameClass.
* Black formatting.
* Add target_column as input_column with a "?" prefix when parsing SQLScalarColumnTransformer from SQL.
* Reformat with black version 22.3.0.
* 🦉 Updates from OwlBot post-processor. See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md
* Remove Eclipse project files.
* SQLScalarColumnTransformer need not inherit from base.BaseTransformer.
* Remove the filter for "ML." SQLs in _extract_output_names() of BaseTransformer.
* Introduce type hint SingleColTransformer for transformers contained in ColumnTransformer.
* Make sql and target_column private in SQLScalarColumnTransformer.
* Add documentation for SQLScalarColumnTransformer.
* Add first system test for SQLScalarColumnTransformer.
* Add SQLScalarColumnTransformer system tests for fit-transform and save-load.
* Make SQLScalarColumnTransformer comparable (equals) for comparing sets in tests.
* Implement __hash__ and __eq__ (copied from BaseTransformer).
* Undo accidentally checked-in files.
* Remove Eclipse settings accidentally checked in.
* Fix docs.
* Update bigframes/ml/compose.py.
* Add support for flexible column names.
* Remove main.
* Add system test for output column with flexible column name.
* System tests: add the new flexible output column to check-df-schema.
* Apply suggestions from code review.

---------

Co-authored-by: Ferenc Hechler <ferenc.hechler@telekom.de>
Co-authored-by: Tim Sweña (Swast) <swast@google.com>
Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
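A minimal usage sketch of the new API, adapted from the docstring examples added to bigframes/ml/compose.py in this commit (the DataFrame contents are illustrative; running it requires an active BigQuery DataFrames session):

import bigframes.pandas as bpd
from bigframes.ml.compose import ColumnTransformer, SQLScalarColumnTransformer

# "{0}" is the placeholder for the input column; the output column defaults to
# "transformed_<column>" and can be overridden via target_column.
df = bpd.DataFrame({"name": ["James", None, "Mary"], "city": ["New York", "Boston", None]})
col_trans = ColumnTransformer([
    (
        "strlen",
        SQLScalarColumnTransformer("CASE WHEN {0} IS NULL THEN 15 ELSE LENGTH({0}) END"),
        ["name", "city"],
    ),
])
col_trans = col_trans.fit(df)
df_transformed = col_trans.transform(df)  # columns: transformed_name, transformed_city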
1 parent 3c54399 commit 1930b4e

File tree

4 files changed (+633, -14 lines)


bigframes/ml/base.py

Lines changed: 0 additions & 4 deletions
@@ -198,10 +198,6 @@ def _extract_output_names(self):
             # pass the columns that are not transformed
             if "transformSql" not in transform_col_dict:
                 continue
-            transform_sql: str = transform_col_dict["transformSql"]
-            if not transform_sql.startswith("ML."):
-                continue
-
             output_names.append(transform_col_dict["name"])
 
         self._output_names = output_names

bigframes/ml/compose.py

Lines changed: 124 additions & 10 deletions
@@ -46,6 +46,101 @@
 )
 
 
+class SQLScalarColumnTransformer:
+    r"""
+    Wrapper for plain SQL code contained in a ColumnTransformer.
+
+    Create a single column transformer in plain sql.
+    This transformer can only be used inside ColumnTransformer.
+
+    When creating an instance '{0}' can be used as placeholder
+    for the column to transform:
+
+        SQLScalarColumnTransformer("{0}+1")
+
+    The default target column gets the prefix 'transformed\_'
+    but can also be changed when creating an instance:
+
+        SQLScalarColumnTransformer("{0}+1", "inc_{0}")
+
+    **Examples:**
+
+        >>> from bigframes.ml.compose import ColumnTransformer, SQLScalarColumnTransformer
+        >>> import bigframes.pandas as bpd
+        <BLANKLINE>
+        >>> df = bpd.DataFrame({'name': ["James", None, "Mary"], 'city': ["New York", "Boston", None]})
+        >>> col_trans = ColumnTransformer([
+        ...     ("strlen",
+        ...      SQLScalarColumnTransformer("CASE WHEN {0} IS NULL THEN 15 ELSE LENGTH({0}) END"),
+        ...      ['name', 'city']),
+        ... ])
+        >>> col_trans = col_trans.fit(df)
+        >>> df_transformed = col_trans.transform(df)
+        >>> df_transformed
+           transformed_name  transformed_city
+        0                 5                 8
+        1                15                 6
+        2                 4                15
+        <BLANKLINE>
+        [3 rows x 2 columns]
+
+    SQLScalarColumnTransformer can be combined with other transformers, like StandardScaler:
+
+        >>> col_trans = ColumnTransformer([
+        ...     ("identity", SQLScalarColumnTransformer("{0}", target_column="{0}"), ["col1", "col5"]),
+        ...     ("increment", SQLScalarColumnTransformer("{0}+1", target_column="inc_{0}"), "col2"),
+        ...     ("stdscale", preprocessing.StandardScaler(), "col3"),
+        ...     # ...
+        ... ])
+
+    """
+
+    def __init__(self, sql: str, target_column: str = "transformed_{0}"):
+        super().__init__()
+        self._sql = sql
+        self._target_column = target_column.replace("`", "")
+
+    PLAIN_COLNAME_RX = re.compile("^[a-z][a-z0-9_]*$", re.IGNORECASE)
+
+    def escape(self, colname: str):
+        colname = colname.replace("`", "")
+        if self.PLAIN_COLNAME_RX.match(colname):
+            return colname
+        return f"`{colname}`"
+
+    def _compile_to_sql(
+        self, X: bpd.DataFrame, columns: Optional[Iterable[str]] = None
+    ) -> List[str]:
+        if columns is None:
+            columns = X.columns
+        result = []
+        for column in columns:
+            current_sql = self._sql.format(self.escape(column))
+            current_target_column = self.escape(self._target_column.format(column))
+            result.append(f"{current_sql} AS {current_target_column}")
+        return result
+
+    def __repr__(self):
+        return f"SQLScalarColumnTransformer(sql='{self._sql}', target_column='{self._target_column}')"
+
+    def __eq__(self, other) -> bool:
+        return type(self) is type(other) and self._keys() == other._keys()
+
+    def __hash__(self) -> int:
+        return hash(self._keys())
+
+    def _keys(self):
+        return (self._sql, self._target_column)
+
+
+# Type hints for transformers contained in ColumnTransformer
+SingleColTransformer = Union[
+    preprocessing.PreprocessingType,
+    impute.SimpleImputer,
+    SQLScalarColumnTransformer,
+]
+
+
 @log_adapter.class_logger
 class ColumnTransformer(
     base.Transformer,
@@ -60,7 +155,7 @@ def __init__(
         transformers: Iterable[
             Tuple[
                 str,
-                Union[preprocessing.PreprocessingType, impute.SimpleImputer],
+                SingleColTransformer,
                 Union[str, Iterable[str]],
             ]
         ],
@@ -78,14 +173,12 @@ def _keys(self):
     @property
     def transformers_(
         self,
-    ) -> List[
-        Tuple[str, Union[preprocessing.PreprocessingType, impute.SimpleImputer], str]
-    ]:
+    ) -> List[Tuple[str, SingleColTransformer, str,]]:
         """The collection of transformers as tuples of (name, transformer, column)."""
         result: List[
             Tuple[
                 str,
-                Union[preprocessing.PreprocessingType, impute.SimpleImputer],
+                SingleColTransformer,
                 str,
             ]
         ] = []
@@ -103,6 +196,8 @@ def transformers_(
 
         return result
 
+    AS_FLEXNAME_SUFFIX_RX = re.compile("^(.*)\\bAS\\s*`[^`]+`\\s*$", re.IGNORECASE)
+
     @classmethod
     def _extract_from_bq_model(
         cls,
@@ -114,7 +209,7 @@ def _extract_from_bq_model(
         transformers_set: Set[
             Tuple[
                 str,
-                Union[preprocessing.PreprocessingType, impute.SimpleImputer],
+                SingleColTransformer,
                 Union[str, List[str]],
             ]
         ] = set()
@@ -130,8 +225,11 @@ def camel_to_snake(name):
             if "transformSql" not in transform_col_dict:
                 continue
             transform_sql: str = transform_col_dict["transformSql"]
-            if not transform_sql.startswith("ML."):
-                continue
+
+            # workaround for bug in bq_model returning " AS `...`" suffix for flexible names
+            flex_name_match = cls.AS_FLEXNAME_SUFFIX_RX.match(transform_sql)
+            if flex_name_match:
+                transform_sql = flex_name_match.group(1)
 
             output_names.append(transform_col_dict["name"])
             found_transformer = False
@@ -148,8 +246,22 @@ def camel_to_snake(name):
                     found_transformer = True
                     break
             if not found_transformer:
-                raise NotImplementedError(
-                    f"Unsupported transformer type. {constants.FEEDBACK_LINK}"
+                if transform_sql.startswith("ML."):
+                    raise NotImplementedError(
+                        f"Unsupported transformer type. {constants.FEEDBACK_LINK}"
+                    )
+
+                target_column = transform_col_dict["name"]
+                sql_transformer = SQLScalarColumnTransformer(
+                    transform_sql, target_column=target_column
+                )
+                input_column_name = f"?{target_column}"
+                transformers_set.add(
+                    (
+                        camel_to_snake(sql_transformer.__class__.__name__),
+                        sql_transformer,
+                        input_column_name,
+                    )
                 )
 
         transformer = cls(transformers=list(transformers_set))
@@ -167,6 +279,8 @@ def _merge(
 
         assert len(transformers) > 0
         _, transformer_0, column_0 = transformers[0]
+        if isinstance(transformer_0, SQLScalarColumnTransformer):
+            return self  # SQLScalarColumnTransformer only works inside ColumnTransformer
         feature_columns_sorted = sorted(
             [
                 cast(str, feature_column.name)
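To make the compilation step concrete: _compile_to_sql formats the SQL template once per input column and appends an AS clause derived from target_column, with escape() backtick-quoting flexible column names. A small sketch with made-up column names follows (the DataFrame argument is only consulted when no explicit column list is passed, so it is omitted here):

from bigframes.ml.compose import SQLScalarColumnTransformer

# target_column may also contain the "{0}" placeholder.
transformer = SQLScalarColumnTransformer("{0}+1", target_column="inc_{0}")

# X is ignored when an explicit column list is given, so None is passed
# purely for illustration.
print(transformer._compile_to_sql(None, ["culmen_length_mm", "Flex Name"]))
# ['culmen_length_mm+1 AS inc_culmen_length_mm', '`Flex Name`+1 AS `inc_Flex Name`']

Note that when a saved model is loaded back, _extract_from_bq_model reports the parsed transformer's input column as the target column prefixed with "?" (for example "?len_species"), which the save/load test below asserts.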

tests/system/large/ml/test_compose.py

Lines changed: 103 additions & 0 deletions
@@ -36,6 +36,32 @@ def test_columntransformer_standalone_fit_and_transform(
                 preprocessing.MinMaxScaler(),
                 ["culmen_length_mm"],
             ),
+            (
+                "increment",
+                compose.SQLScalarColumnTransformer("{0}+1"),
+                ["culmen_length_mm", "flipper_length_mm"],
+            ),
+            (
+                "length",
+                compose.SQLScalarColumnTransformer(
+                    "CASE WHEN {0} IS NULL THEN -1 ELSE LENGTH({0}) END",
+                    target_column="len_{0}",
+                ),
+                "species",
+            ),
+            (
+                "ohe",
+                compose.SQLScalarColumnTransformer(
+                    "CASE WHEN {0}='Adelie Penguin (Pygoscelis adeliae)' THEN 1 ELSE 0 END",
+                    target_column="ohe_adelie",
+                ),
+                "species",
+            ),
+            (
+                "identity",
+                compose.SQLScalarColumnTransformer("{0}", target_column="{0}"),
+                ["culmen_length_mm", "flipper_length_mm"],
+            ),
         ]
     )
 
@@ -51,6 +77,12 @@ def test_columntransformer_standalone_fit_and_transform(
             "standard_scaled_culmen_length_mm",
             "min_max_scaled_culmen_length_mm",
             "standard_scaled_flipper_length_mm",
+            "transformed_culmen_length_mm",
+            "transformed_flipper_length_mm",
+            "len_species",
+            "ohe_adelie",
+            "culmen_length_mm",
+            "flipper_length_mm",
         ],
         index=[1633, 1672, 1690],
         col_exact=False,
@@ -70,6 +102,19 @@ def test_columntransformer_standalone_fit_transform(new_penguins_df):
                 preprocessing.StandardScaler(),
                 ["culmen_length_mm", "flipper_length_mm"],
             ),
+            (
+                "length",
+                compose.SQLScalarColumnTransformer(
+                    "CASE WHEN {0} IS NULL THEN -1 ELSE LENGTH({0}) END",
+                    target_column="len_{0}",
+                ),
+                "species",
+            ),
+            (
+                "identity",
+                compose.SQLScalarColumnTransformer("{0}", target_column="{0}"),
+                ["culmen_length_mm", "flipper_length_mm"],
+            ),
         ]
     )
 
@@ -83,6 +128,9 @@ def test_columntransformer_standalone_fit_transform(new_penguins_df):
             "onehotencoded_species",
             "standard_scaled_culmen_length_mm",
             "standard_scaled_flipper_length_mm",
+            "len_species",
+            "culmen_length_mm",
+            "flipper_length_mm",
         ],
         index=[1633, 1672, 1690],
         col_exact=False,
@@ -102,6 +150,27 @@ def test_columntransformer_save_load(new_penguins_df, dataset_id):
                 preprocessing.StandardScaler(),
                 ["culmen_length_mm", "flipper_length_mm"],
             ),
+            (
+                "length",
+                compose.SQLScalarColumnTransformer(
+                    "CASE WHEN {0} IS NULL THEN -1 ELSE LENGTH({0}) END",
+                    target_column="len_{0}",
+                ),
+                "species",
+            ),
+            (
+                "identity",
+                compose.SQLScalarColumnTransformer("{0}", target_column="{0}"),
+                ["culmen_length_mm", "flipper_length_mm"],
+            ),
+            (
+                "flexname",
+                compose.SQLScalarColumnTransformer(
+                    "CASE WHEN {0} IS NULL THEN -1 ELSE LENGTH({0}) END",
+                    target_column="Flex {0} Name",
+                ),
+                "species",
+            ),
         ]
     )
     transformer.fit(
@@ -122,6 +191,36 @@ def test_columntransformer_save_load(new_penguins_df, dataset_id):
         ),
         ("standard_scaler", preprocessing.StandardScaler(), "culmen_length_mm"),
         ("standard_scaler", preprocessing.StandardScaler(), "flipper_length_mm"),
+        (
+            "sql_scalar_column_transformer",
+            compose.SQLScalarColumnTransformer(
+                "CASE WHEN species IS NULL THEN -1 ELSE LENGTH(species) END",
+                target_column="len_species",
+            ),
+            "?len_species",
+        ),
+        (
+            "sql_scalar_column_transformer",
+            compose.SQLScalarColumnTransformer(
+                "flipper_length_mm", target_column="flipper_length_mm"
+            ),
+            "?flipper_length_mm",
+        ),
+        (
+            "sql_scalar_column_transformer",
+            compose.SQLScalarColumnTransformer(
+                "culmen_length_mm", target_column="culmen_length_mm"
+            ),
+            "?culmen_length_mm",
+        ),
+        (
+            "sql_scalar_column_transformer",
+            compose.SQLScalarColumnTransformer(
+                "CASE WHEN species IS NULL THEN -1 ELSE LENGTH(species) END ",
+                target_column="Flex species Name",
+            ),
+            "?Flex species Name",
+        ),
     ]
     assert set(reloaded_transformer.transformers) == set(expected)
     assert reloaded_transformer._bqml_model is not None
@@ -136,6 +235,10 @@ def test_columntransformer_save_load(new_penguins_df, dataset_id):
             "onehotencoded_species",
             "standard_scaled_culmen_length_mm",
             "standard_scaled_flipper_length_mm",
+            "len_species",
+            "culmen_length_mm",
+            "flipper_length_mm",
+            "Flex species Name",
         ],
         index=[1633, 1672, 1690],
         col_exact=False,
