Skip to content

Commit e19f7ac

Browse files
committed
add predict tests
1 parent 7f89428 commit e19f7ac

File tree

3 files changed

+26
-22
lines changed

3 files changed

+26
-22
lines changed

‎tests/system/conftest.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,20 @@ def llm_fine_tune_df_default_index(
550550
return session.read_gbq(sql)
551551

552552

553+
@pytest.fixture(scope="session")
554+
def llm_remote_text_pandas_df():
555+
"""Additional data matching the penguins dataset, with a new index"""
556+
return pd.DataFrame(
557+
{
558+
"prompt": [
559+
"Please do sentiment analysis on the following text and only output a number from 0 to 5where 0 means sadness, 1 means joy, 2 means love, 3 means anger, 4 means fear, and 5 means surprise. Text: i feel beautifully emotional knowing that these women of whom i knew just a handful were holding me and my baba on our journey",
560+
"Please do sentiment analysis on the following text and only output a number from 0 to 5 where 0 means sadness, 1 means joy, 2 means love, 3 means anger, 4 means fear, and 5 means surprise. Text: i was feeling a little vain when i did this one",
561+
"Please do sentiment analysis on the following text and only output a number from 0 to 5 where 0 means sadness, 1 means joy, 2 means love, 3 means anger, 4 means fear, and 5 means surprise. Text: a father of children killed in an accident",
562+
],
563+
}
564+
)
565+
566+
553567
@pytest.fixture(scope="session")
554568
def time_series_df_default_index(
555569
time_series_table_id: str, session: bigframes.Session

‎tests/system/load/test_llm.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
import bigframes.ml.llm
1616

1717

18-
def test_llm_palm_configure_fit(llm_fine_tune_df_default_index, dataset_id):
18+
def test_llm_palm_configure_fit(
19+
llm_fine_tune_df_default_index, llm_remote_text_pandas_df
20+
):
1921
model = bigframes.ml.llm.PaLM2TextGenerator(
2022
model_name="text-bison", max_iterations=1, evaluation_task="CLASSIFICATION"
2123
)
@@ -25,12 +27,12 @@ def test_llm_palm_configure_fit(llm_fine_tune_df_default_index, dataset_id):
2527
y_train = df[["label"]]
2628
model.fit(X_train, y_train)
2729

28-
# save, load, check parameters to ensure configuration was kept
29-
reloaded_model = model.to_gbq(
30-
f"{dataset_id}.temp_configured_palm_model", replace=True
31-
)
32-
assert (
33-
f"{dataset_id}.temp_configured_palm_model"
34-
in reloaded_model._bqml_model.model_name
35-
)
36-
assert reloaded_model.evaluation_task == "CLASSIFICATION"
30+
assert model is not None
31+
32+
df = model.predict(llm_remote_text_pandas_df).to_pandas()
33+
assert df.shape == (3, 4)
34+
assert "ml_generate_text_llm_result" in df.columns
35+
series = df["ml_generate_text_llm_result"]
36+
assert all(series.str.len() == 1)
37+
38+
# TODO(ashleyxu): After bqml rolled out version control: save, load, check parameters to ensure configuration was kept

‎tests/system/small/ml/conftest.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -232,18 +232,6 @@ def palm2_text_generator_model(session, bq_connection) -> llm.PaLM2TextGenerator
232232
return llm.PaLM2TextGenerator(session=session, connection_name=bq_connection)
233233

234234

235-
@pytest.fixture(scope="session")
236-
def palm2_text_generator_fine_tune_model(
237-
session, bq_connection
238-
) -> llm.PaLM2TextGenerator:
239-
return llm.PaLM2TextGenerator(
240-
session=session,
241-
connection_name=bq_connection,
242-
max_iterations=300,
243-
evaluation_task="TEXT_GENERATION",
244-
)
245-
246-
247235
@pytest.fixture(scope="session")
248236
def palm2_text_generator_32k_model(session, bq_connection) -> llm.PaLM2TextGenerator:
249237
return llm.PaLM2TextGenerator(

0 commit comments

Comments
 (0)