feat: allow multiple columns input for llm models #998

Merged: 12 commits, Sep 25, 2024
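
This PR lets predict() on the PaLM 2, Gemini, Claude 3, and text-embedding models accept input with extra columns: only the "prompt" (or "content") column feeds the model, and the remaining columns are carried through to the prediction output. A minimal usage sketch of the new behavior (the model choice and column names are illustrative, and a default session plus BigQuery connection are assumed to be configured):

    import bigframes.pandas as bpd
    from bigframes.ml import llm

    # Multi-column input: "prompt" is required, other columns pass through.
    df = bpd.DataFrame(
        {
            "prompt": ["What is BigQuery?", "What is BigQuery DataFrames?"],
            "doc_id": [1, 2],  # illustrative extra column, kept in the output
        }
    )

    model = llm.GeminiTextGenerator(model_name="gemini-pro")
    result = model.predict(df)  # output has ml_generate_text_* columns plus "doc_id"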
70 changes: 25 additions & 45 deletions bigframes/ml/llm.py
@@ -244,7 +244,7 @@ def predict(

Args:
X (bigframes.dataframe.DataFrame or bigframes.series.Series):
-Input DataFrame or Series, which contains only one column of prompts.
+Input DataFrame or Series, which can contain one or more columns. If the DataFrame has multiple columns, it must contain a "prompt" column for prediction.
Prompts can include preamble, questions, suggestions, instructions, or examples.

temperature (float, default 0.0):
@@ -307,14 +307,10 @@ def predict(

(X,) = utils.convert_to_dataframe(X)

-if len(X.columns) != 1:
-    raise ValueError(
-        f"Only support one column as input. {constants.FEEDBACK_LINK}"
-    )
-
-# BQML identified the column by name
-col_label = cast(blocks.Label, X.columns[0])
-X = X.rename(columns={col_label: "prompt"})
+if len(X.columns) == 1:
+    # BQML identified the column by name
+    col_label = cast(blocks.Label, X.columns[0])
+    X = X.rename(columns={col_label: "prompt"})

Review comment from @shobsi (Contributor), Sep 23, 2024, on the new branch above:

    I think we should add another check in the else clause, verifying that the
    multi-column input does have a "prompt" column, and also add a negative test
    for that scenario.

Reply from the PR author (Contributor):

    @tswast suggested that we shouldn't do many client-side checks. The rule I'm
    trying to follow: if the server's error message is meaningful to the user,
    rely on server-side checks; otherwise we would have to wrap server error
    messages or return our own client-side errors.
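For reference, a minimal sketch of the check shobsi proposed, which this PR deliberately does not adopt; the error message is illustrative and reuses constants.FEEDBACK_LINK from the removed check:

    if len(X.columns) == 1:
        # Single column: BQML identifies the input column by name, so rename it.
        col_label = cast(blocks.Label, X.columns[0])
        X = X.rename(columns={col_label: "prompt"})
    elif "prompt" not in X.columns:
        # Proposed (not adopted) client-side guard for multi-column input.
        raise ValueError(
            f"Multi-column input must contain a 'prompt' column. {constants.FEEDBACK_LINK}"
        )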


options = {
"temperature": temperature,
@@ -522,7 +518,7 @@ def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:

Args:
X (bigframes.dataframe.DataFrame or bigframes.series.Series):
-Input DataFrame, which needs to contain a column with name "content". Only the column will be used as input. Content can include preamble, questions, suggestions, instructions, or examples.
+Input DataFrame or Series, which can contain one or more columns. If the DataFrame has multiple columns, it must contain a "content" column for prediction.

Returns:
bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values.
@@ -531,14 +527,10 @@ def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
# Params reference: https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models
(X,) = utils.convert_to_dataframe(X)

-if len(X.columns) != 1:
-    raise ValueError(
-        f"Only support one column as input. {constants.FEEDBACK_LINK}"
-    )
-
-# BQML identified the column by name
-col_label = cast(blocks.Label, X.columns[0])
-X = X.rename(columns={col_label: "content"})
+if len(X.columns) == 1:
+    # BQML identified the column by name
+    col_label = cast(blocks.Label, X.columns[0])
+    X = X.rename(columns={col_label: "content"})

options = {
"flatten_json_output": True,
@@ -679,7 +671,7 @@ def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:

Args:
X (bigframes.dataframe.DataFrame or bigframes.series.Series):
-Input DataFrame, which needs to contain a column with name "content". Only the column will be used as input. Content can include preamble, questions, suggestions, instructions, or examples.
+Input DataFrame or Series, which can contain one or more columns. If the DataFrame has multiple columns, it must contain a "content" column for prediction.

Returns:
bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values.
@@ -688,14 +680,10 @@ def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
# Params reference: https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models
(X,) = utils.convert_to_dataframe(X)

-if len(X.columns) != 1:
-    raise ValueError(
-        f"Only support one column as input. {constants.FEEDBACK_LINK}"
-    )
-
-# BQML identified the column by name
-col_label = cast(blocks.Label, X.columns[0])
-X = X.rename(columns={col_label: "content"})
+if len(X.columns) == 1:
+    # BQML identified the column by name
+    col_label = cast(blocks.Label, X.columns[0])
+    X = X.rename(columns={col_label: "content"})

options = {
"flatten_json_output": True,
@@ -893,7 +881,7 @@ def predict(

Args:
X (bigframes.dataframe.DataFrame or bigframes.series.Series):
-Input DataFrame or Series, which contains only one column of prompts.
+Input DataFrame or Series, which can contain one or more columns. If the DataFrame has multiple columns, it must contain a "prompt" column for prediction.
Prompts can include preamble, questions, suggestions, instructions, or examples.

temperature (float, default 0.9):
@@ -938,14 +926,10 @@ def predict(

(X,) = utils.convert_to_dataframe(X)

-if len(X.columns) != 1:
-    raise ValueError(
-        f"Only support one column as input. {constants.FEEDBACK_LINK}"
-    )
-
-# BQML identified the column by name
-col_label = cast(blocks.Label, X.columns[0])
-X = X.rename(columns={col_label: "prompt"})
+if len(X.columns) == 1:
+    # BQML identified the column by name
+    col_label = cast(blocks.Label, X.columns[0])
+    X = X.rename(columns={col_label: "prompt"})

options = {
"temperature": temperature,
@@ -1181,7 +1165,7 @@ def predict(

Args:
X (bigframes.dataframe.DataFrame or bigframes.series.Series):
-Input DataFrame or Series, which contains only one column of prompts.
+Input DataFrame or Series, which can contain one or more columns. If the DataFrame has multiple columns, it must contain a "prompt" column for prediction.
Prompts can include preamble, questions, suggestions, instructions, or examples.

max_output_tokens (int, default 128):
@@ -1222,14 +1206,10 @@ def predict(

(X,) = utils.convert_to_dataframe(X)

-if len(X.columns) != 1:
-    raise ValueError(
-        f"Only support one column as input. {constants.FEEDBACK_LINK}"
-    )
-
-# BQML identified the column by name
-col_label = cast(blocks.Label, X.columns[0])
-X = X.rename(columns={col_label: "prompt"})
+if len(X.columns) == 1:
+    # BQML identified the column by name
+    col_label = cast(blocks.Label, X.columns[0])
+    X = X.rename(columns={col_label: "prompt"})

options = {
"max_output_tokens": max_output_tokens,
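Taken together, the llm.py changes replace the hard one-column requirement with a rename-only-when-single-column rule. A short sketch of the two input shapes (column names illustrative):

    import bigframes.pandas as bpd

    # One column with an arbitrary label: renamed to the expected name
    # ("prompt" for text generation, "content" for embeddings).
    df_single = bpd.DataFrame({"my_questions": ["What is BigQuery ML?"]})

    # Several columns: passed through unchanged, so the expected column must
    # already be present; BQML resolves it by name on the server side.
    df_multi = bpd.DataFrame({"prompt": ["What is BigQuery ML?"], "doc_id": [1]})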
24 changes: 24 additions & 0 deletions tests/system/load/test_llm.py
@@ -156,3 +156,27 @@ def test_claude3_text_generator_predict_with_params_success(
    utils.check_pandas_df_schema_and_index(
        df, columns=utils.ML_GENERATE_TEXT_OUTPUT, index=3, col_exact=False
    )


@pytest.mark.parametrize(
    "model_name",
    ("claude-3-sonnet", "claude-3-haiku", "claude-3-5-sonnet", "claude-3-opus"),
)
@pytest.mark.flaky(retries=3, delay=120)
def test_claude3_text_generator_predict_multi_col_success(
    llm_text_df, model_name, session, session_us_east5, bq_connection
):
    if model_name in ("claude-3-5-sonnet", "claude-3-opus"):
        session = session_us_east5

    llm_text_df["additional_col"] = 1
    claude3_text_generator_model = llm.Claude3TextGenerator(
        model_name=model_name, connection_name=bq_connection, session=session
    )
    df = claude3_text_generator_model.predict(llm_text_df).to_pandas()
    utils.check_pandas_df_schema_and_index(
        df,
        columns=utils.ML_GENERATE_TEXT_OUTPUT + ["additional_col"],
        index=3,
        col_exact=False,
    )
74 changes: 69 additions & 5 deletions tests/system/small/ml/test_llm.py
@@ -15,6 +15,7 @@
import pytest

from bigframes.ml import llm
import bigframes.pandas as bpd
from tests.system import utils


@@ -166,6 +167,20 @@ def test_text_generator_predict_arbitrary_col_label_success(
)


@pytest.mark.flaky(retries=2)
def test_text_generator_predict_multiple_cols_success(
    palm2_text_generator_model, llm_text_df: bpd.DataFrame
):
    df = llm_text_df.assign(additional_col=1)
    pd_df = palm2_text_generator_model.predict(df).to_pandas()
    utils.check_pandas_df_schema_and_index(
        pd_df,
        columns=utils.ML_GENERATE_TEXT_OUTPUT + ["additional_col"],
        index=3,
        col_exact=False,
    )


@pytest.mark.flaky(retries=2)
def test_text_generator_predict_with_params_success(
palm2_text_generator_model, llm_text_df
@@ -212,11 +227,33 @@ def test_text_embedding_generator_predict_default_params_success(
model_name=model_name, connection_name=bq_connection, session=session
)
df = text_embedding_model.predict(llm_text_df).to_pandas()
-    assert df.shape == (3, 4)
-    assert "ml_generate_embedding_result" in df.columns
-    series = df["ml_generate_embedding_result"]
-    value = series[0]
-    assert len(value) == 768
+    utils.check_pandas_df_schema_and_index(
+        df, columns=utils.ML_GENERATE_EMBEDDING_OUTPUT, index=3, col_exact=False
+    )
+    assert len(df["ml_generate_embedding_result"][0]) == 768


@pytest.mark.parametrize(
    "model_name",
    ("text-embedding-004", "text-multilingual-embedding-002"),
)
@pytest.mark.flaky(retries=2)
def test_text_embedding_generator_multi_cols_predict_success(
    llm_text_df: bpd.DataFrame, model_name, session, bq_connection
):
    df = llm_text_df.assign(additional_col=1)
    df = df.rename(columns={"prompt": "content"})
    text_embedding_model = llm.TextEmbeddingGenerator(
        model_name=model_name, connection_name=bq_connection, session=session
    )
    pd_df = text_embedding_model.predict(df).to_pandas()
    utils.check_pandas_df_schema_and_index(
        pd_df,
        columns=utils.ML_GENERATE_EMBEDDING_OUTPUT + ["additional_col"],
        index=3,
        col_exact=False,
    )
    assert len(pd_df["ml_generate_embedding_result"][0]) == 768


@pytest.mark.parametrize(
Expand Down Expand Up @@ -295,6 +332,33 @@ def test_gemini_text_generator_predict_with_params_success(
)


@pytest.mark.parametrize(
    "model_name",
    (
        "gemini-pro",
        "gemini-1.5-pro-preview-0514",
        "gemini-1.5-flash-preview-0514",
        "gemini-1.5-pro-001",
        "gemini-1.5-flash-001",
    ),
)
@pytest.mark.flaky(retries=2)
def test_gemini_text_generator_multi_cols_predict_success(
    llm_text_df: bpd.DataFrame, model_name, session, bq_connection
):
    df = llm_text_df.assign(additional_col=1)
    gemini_text_generator_model = llm.GeminiTextGenerator(
        model_name=model_name, connection_name=bq_connection, session=session
    )
    pd_df = gemini_text_generator_model.predict(df).to_pandas()
    utils.check_pandas_df_schema_and_index(
        pd_df,
        columns=utils.ML_GENERATE_TEXT_OUTPUT + ["additional_col"],
        index=3,
        col_exact=False,
    )


@pytest.mark.flaky(retries=2)
def test_llm_palm_score(llm_fine_tune_df_default_index):
model = llm.PaLM2TextGenerator(model_name="text-bison")
6 changes: 6 additions & 0 deletions tests/system/utils.py
@@ -50,6 +50,12 @@
"ml_generate_text_status",
"prompt",
]
ML_GENERATE_EMBEDDING_OUTPUT = [
"ml_generate_embedding_result",
"ml_generate_embedding_statistics",
"ml_generate_embedding_status",
"content",
]


def skip_legacy_pandas(test):
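For context, check_pandas_df_schema_and_index is the shared assertion helper the tests above call. A hedged sketch of what it plausibly verifies, inferred only from the call sites in this PR; the real implementation in tests/system/utils.py may differ:

    import pandas as pd

    def check_pandas_df_schema_and_index(pd_df: pd.DataFrame, columns, index, col_exact=True):
        # Hypothetical re-implementation for illustration only.
        if col_exact:
            # Exact column set and order expected.
            assert list(pd_df.columns) == list(columns)
        else:
            # At least these columns; extra output columns are tolerated.
            assert set(columns) <= set(pd_df.columns)
        # An integer index argument is read here as the expected row count.
        assert len(pd_df.index) == index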