Skip to content

Commit ca02cd4

Browse files
authored
feat: add the ground_with_google_search option for GeminiTextGenerator predict (#1119)
* feat: add option for GeminiTextGenerator predict * add pricing link to the warning message
1 parent b5ca1d9 commit ca02cd4

File tree

3 files changed

+164
-21
lines changed

3 files changed

+164
-21
lines changed

‎bigframes/ml/llm.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -913,6 +913,7 @@ def predict(
913913
max_output_tokens: int = 8192,
914914
top_k: int = 40,
915915
top_p: float = 1.0,
916+
ground_with_google_search: bool = False,
916917
) -> bpd.DataFrame:
917918
"""Predict the result from input DataFrame.
918919
@@ -936,11 +937,20 @@ def predict(
936937
Specify a lower value for less random responses and a higher value for more random responses.
937938
Default 40. Possible values [1, 40].
938939
939-
top_p (float, default 0.95)::
940+
top_p (float, default 0.95):
940941
Top-P changes how the model selects tokens for output. Tokens are selected from the most (see top-K) to least probable until the sum of their probabilities equals the top-P value. For example, if tokens A, B, and C have a probability of 0.3, 0.2, and 0.1 and the top-P value is 0.5, then the model will select either A or B as the next token by using temperature and excludes C as a candidate.
941942
Specify a lower value for less random responses and a higher value for more random responses.
942943
Default 1.0. Possible values [0.0, 1.0].
943944
945+
ground_with_google_search (bool, default False):
946+
Enables Grounding with Google Search for the Vertex AI model. When set
947+
to True, the model incorporates relevant information from Google Search
948+
results into its responses, enhancing their accuracy and factualness.
949+
This feature provides an additional column, `ml_generate_text_grounding_result`,
950+
in the response output, detailing the sources used for grounding.
951+
Note: Using this feature may impact billing costs. Refer to the pricing
952+
page for details: https://cloud.google.com/vertex-ai/generative-ai/pricing#google_models
953+
The default is `False`.
944954
945955
Returns:
946956
bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values.
@@ -974,6 +984,7 @@ def predict(
974984
"top_k": top_k,
975985
"top_p": top_p,
976986
"flatten_json_output": True,
987+
"ground_with_google_search": ground_with_google_search,
977988
}
978989

979990
df = self._bqml_model.generate_text(X, options)

‎bigframes/operations/semantics.py

Lines changed: 132 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import re
1717
import typing
1818
from typing import List, Optional
19+
import warnings
1920

2021
import numpy as np
2122

@@ -39,6 +40,7 @@ def agg(
3940
model,
4041
cluster_column: typing.Optional[str] = None,
4142
max_agg_rows: int = 10,
43+
ground_with_google_search: bool = False,
4244
):
4345
"""
4446
Performs an aggregation over all rows of the table.
@@ -90,6 +92,14 @@ def agg(
9092
max_agg_rows (int, default 10):
9193
The maxinum number of rows to be aggregated at a time.
9294
95+
ground_with_google_search (bool, default False):
96+
Enables Grounding with Google Search for the GeminiTextGenerator model.
97+
When set to True, the model incorporates relevant information from Google
98+
Search results into its responses, enhancing their accuracy and factualness.
99+
Note: Using this feature may impact billing costs. Refer to the pricing
100+
page for details: https://cloud.google.com/vertex-ai/generative-ai/pricing#google_models
101+
The default is `False`.
102+
93103
Returns:
94104
bigframes.dataframe.DataFrame: A new DataFrame with the aggregated answers.
95105
@@ -119,6 +129,12 @@ def agg(
119129
)
120130
column = columns[0]
121131

132+
if ground_with_google_search:
133+
warnings.warn(
134+
"Enables Grounding with Google Search may impact billing cost. See pricing "
135+
"details: https://cloud.google.com/vertex-ai/generative-ai/pricing#google_models"
136+
)
137+
122138
if max_agg_rows <= 1:
123139
raise ValueError(
124140
f"Invalid value for `max_agg_rows`: {max_agg_rows}."
@@ -191,7 +207,12 @@ def agg(
191207

192208
# Run model
193209
predict_df = typing.cast(
194-
bigframes.dataframe.DataFrame, model.predict(prompt_s, temperature=0.0)
210+
bigframes.dataframe.DataFrame,
211+
model.predict(
212+
prompt_s,
213+
temperature=0.0,
214+
ground_with_google_search=ground_with_google_search,
215+
),
195216
)
196217
agg_df[column] = predict_df["ml_generate_text_llm_result"].combine_first(
197218
single_row_df
@@ -284,7 +305,7 @@ def cluster_by(
284305
df[output_column] = clustered_result["CENTROID_ID"]
285306
return df
286307

287-
def filter(self, instruction: str, model):
308+
def filter(self, instruction: str, model, ground_with_google_search: bool = False):
288309
"""
289310
Filters the DataFrame with the semantics of the user instruction.
290311
@@ -305,18 +326,26 @@ def filter(self, instruction: str, model):
305326
[1 rows x 2 columns]
306327
307328
Args:
308-
instruction:
329+
instruction (str):
309330
An instruction on how to filter the data. This value must contain
310331
column references by name, which should be wrapped in a pair of braces.
311332
For example, if you have a column "food", you can refer to this column
312333
in the instructions like:
313334
"The {food} is healthy."
314335
315-
model:
336+
model (bigframes.ml.llm.GeminiTextGenerator):
316337
A GeminiTextGenerator provided by Bigframes ML package.
317338
339+
ground_with_google_search (bool, default False):
340+
Enables Grounding with Google Search for the GeminiTextGenerator model.
341+
When set to True, the model incorporates relevant information from Google
342+
Search results into its responses, enhancing their accuracy and factualness.
343+
Note: Using this feature may impact billing costs. Refer to the pricing
344+
page for details: https://cloud.google.com/vertex-ai/generative-ai/pricing#google_models
345+
The default is `False`.
346+
318347
Returns:
319-
DataFrame filtered by the instruction.
348+
bigframes.pandas.DataFrame: DataFrame filtered by the instruction.
320349
321350
Raises:
322351
NotImplementedError: when the semantic operator experiment is off.
@@ -332,6 +361,12 @@ def filter(self, instruction: str, model):
332361
if column not in self._df.columns:
333362
raise ValueError(f"Column {column} not found.")
334363

364+
if ground_with_google_search:
365+
warnings.warn(
366+
"Enables Grounding with Google Search may impact billing cost. See pricing "
367+
"details: https://cloud.google.com/vertex-ai/generative-ai/pricing#google_models"
368+
)
369+
335370
df: bigframes.dataframe.DataFrame = self._df[columns].copy()
336371
for column in columns:
337372
if df[column].dtype != dtypes.STRING_DTYPE:
@@ -345,14 +380,21 @@ def filter(self, instruction: str, model):
345380
model.predict(
346381
self._make_prompt(df, columns, user_instruction, output_instruction),
347382
temperature=0.0,
383+
ground_with_google_search=ground_with_google_search,
348384
),
349385
)
350386

351387
return self._df[
352388
results["ml_generate_text_llm_result"].str.lower().str.contains("true")
353389
]
354390

355-
def map(self, instruction: str, output_column: str, model):
391+
def map(
392+
self,
393+
instruction: str,
394+
output_column: str,
395+
model,
396+
ground_with_google_search: bool = False,
397+
):
356398
"""
357399
Maps the DataFrame with the semantics of the user instruction.
358400
@@ -376,21 +418,29 @@ def map(self, instruction: str, output_column: str, model):
376418
[2 rows x 3 columns]
377419
378420
Args:
379-
instruction:
421+
instruction (str):
380422
An instruction on how to map the data. This value must contain
381423
column references by name, which should be wrapped in a pair of braces.
382424
For example, if you have a column "food", you can refer to this column
383425
in the instructions like:
384426
"Get the ingredients of {food}."
385427
386-
output_column:
428+
output_column (str):
387429
The column name of the mapping result.
388430
389-
model:
431+
model (bigframes.ml.llm.GeminiTextGenerator):
390432
A GeminiTextGenerator provided by Bigframes ML package.
391433
434+
ground_with_google_search (bool, default False):
435+
Enables Grounding with Google Search for the GeminiTextGenerator model.
436+
When set to True, the model incorporates relevant information from Google
437+
Search results into its responses, enhancing their accuracy and factualness.
438+
Note: Using this feature may impact billing costs. Refer to the pricing
439+
page for details: https://cloud.google.com/vertex-ai/generative-ai/pricing#google_models
440+
The default is `False`.
441+
392442
Returns:
393-
DataFrame with attached mapping results.
443+
bigframes.pandas.DataFrame: DataFrame with attached mapping results.
394444
395445
Raises:
396446
NotImplementedError: when the semantic operator experiment is off.
@@ -406,6 +456,12 @@ def map(self, instruction: str, output_column: str, model):
406456
if column not in self._df.columns:
407457
raise ValueError(f"Column {column} not found.")
408458

459+
if ground_with_google_search:
460+
warnings.warn(
461+
"Enables Grounding with Google Search may impact billing cost. See pricing "
462+
"details: https://cloud.google.com/vertex-ai/generative-ai/pricing#google_models"
463+
)
464+
409465
df: bigframes.dataframe.DataFrame = self._df[columns].copy()
410466
for column in columns:
411467
if df[column].dtype != dtypes.STRING_DTYPE:
@@ -421,14 +477,22 @@ def map(self, instruction: str, output_column: str, model):
421477
model.predict(
422478
self._make_prompt(df, columns, user_instruction, output_instruction),
423479
temperature=0.0,
480+
ground_with_google_search=ground_with_google_search,
424481
)["ml_generate_text_llm_result"],
425482
)
426483

427484
from bigframes.core.reshape import concat
428485

429486
return concat([self._df, results.rename(output_column)], axis=1)
430487

431-
def join(self, other, instruction: str, model, max_rows: int = 1000):
488+
def join(
489+
self,
490+
other,
491+
instruction: str,
492+
model,
493+
max_rows: int = 1000,
494+
ground_with_google_search: bool = False,
495+
):
432496
"""
433497
Joines two dataframes by applying the instruction over each pair of rows from
434498
the left and right table.
@@ -455,10 +519,10 @@ def join(self, other, instruction: str, model, max_rows: int = 1000):
455519
[4 rows x 2 columns]
456520
457521
Args:
458-
other:
522+
other (bigframes.pandas.DataFrame):
459523
The other dataframe.
460524
461-
instruction:
525+
instruction (str):
462526
An instruction on how left and right rows can be joined. This value must contain
463527
column references by name. which should be wrapped in a pair of braces.
464528
For example: "The {city} belongs to the {country}".
@@ -467,22 +531,36 @@ def join(self, other, instruction: str, model, max_rows: int = 1000):
467531
self joins. For example: "The {left.employee_name} reports to {right.employee_name}"
468532
For unique column names, this prefix is optional.
469533
470-
model:
534+
model (bigframes.ml.llm.GeminiTextGenerator):
471535
A GeminiTextGenerator provided by Bigframes ML package.
472536
473-
max_rows:
537+
max_rows (int, default 1000):
474538
The maximum number of rows allowed to be sent to the model per call. If the result is too large, the method
475539
call will end early with an error.
476540
541+
ground_with_google_search (bool, default False):
542+
Enables Grounding with Google Search for the GeminiTextGenerator model.
543+
When set to True, the model incorporates relevant information from Google
544+
Search results into its responses, enhancing their accuracy and factualness.
545+
Note: Using this feature may impact billing costs. Refer to the pricing
546+
page for details: https://cloud.google.com/vertex-ai/generative-ai/pricing#google_models
547+
The default is `False`.
548+
477549
Returns:
478-
The joined dataframe.
550+
bigframes.pandas.DataFrame: The joined dataframe.
479551
480552
Raises:
481553
ValueError if the amount of data that will be sent for LLM processing is larger than max_rows.
482554
"""
483555
self._validate_model(model)
484556
columns = self._parse_columns(instruction)
485557

558+
if ground_with_google_search:
559+
warnings.warn(
560+
"Enables Grounding with Google Search may impact billing cost. See pricing "
561+
"details: https://cloud.google.com/vertex-ai/generative-ai/pricing#google_models"
562+
)
563+
486564
joined_table_rows = len(self._df) * len(other)
487565

488566
if joined_table_rows > max_rows:
@@ -545,7 +623,9 @@ def join(self, other, instruction: str, model, max_rows: int = 1000):
545623

546624
joined_df = self._df.merge(other, how="cross", suffixes=("_left", "_right"))
547625

548-
return joined_df.semantics.filter(instruction, model).reset_index(drop=True)
626+
return joined_df.semantics.filter(
627+
instruction, model, ground_with_google_search=ground_with_google_search
628+
).reset_index(drop=True)
549629

550630
def search(
551631
self,
@@ -644,7 +724,13 @@ def search(
644724

645725
return typing.cast(bigframes.dataframe.DataFrame, search_result)
646726

647-
def top_k(self, instruction: str, model, k=10):
727+
def top_k(
728+
self,
729+
instruction: str,
730+
model,
731+
k: int = 10,
732+
ground_with_google_search: bool = False,
733+
):
648734
"""
649735
Ranks each tuple and returns the k best according to the instruction.
650736
@@ -682,6 +768,14 @@ def top_k(self, instruction: str, model, k=10):
682768
k (int, default 10):
683769
The number of rows to return.
684770
771+
ground_with_google_search (bool, default False):
772+
Enables Grounding with Google Search for the GeminiTextGenerator model.
773+
When set to True, the model incorporates relevant information from Google
774+
Search results into its responses, enhancing their accuracy and factualness.
775+
Note: Using this feature may impact billing costs. Refer to the pricing
776+
page for details: https://cloud.google.com/vertex-ai/generative-ai/pricing#google_models
777+
The default is `False`.
778+
685779
Returns:
686780
bigframes.dataframe.DataFrame: A new DataFrame with the top k rows.
687781
@@ -703,6 +797,12 @@ def top_k(self, instruction: str, model, k=10):
703797
"Semantic aggregations are limited to a single column."
704798
)
705799

800+
if ground_with_google_search:
801+
warnings.warn(
802+
"Enables Grounding with Google Search may impact billing cost. See pricing "
803+
"details: https://cloud.google.com/vertex-ai/generative-ai/pricing#google_models"
804+
)
805+
706806
df: bigframes.dataframe.DataFrame = self._df[columns].copy()
707807
column = columns[0]
708808
if df[column].dtype != dtypes.STRING_DTYPE:
@@ -743,6 +843,7 @@ def top_k(self, instruction: str, model, k=10):
743843
user_instruction,
744844
model,
745845
k - num_selected,
846+
ground_with_google_search,
746847
)
747848
num_selected += num_new_selected
748849

@@ -757,7 +858,13 @@ def top_k(self, instruction: str, model, k=10):
757858

758859
@staticmethod
759860
def _topk_partition(
760-
df, column: str, status_column: str, user_instruction: str, model, k
861+
df,
862+
column: str,
863+
status_column: str,
864+
user_instruction: str,
865+
model,
866+
k: int,
867+
ground_with_google_search: bool,
761868
):
762869
output_instruction = (
763870
"Given a question and two documents, choose the document that best answers "
@@ -784,7 +891,12 @@ def _topk_partition(
784891
import bigframes.dataframe
785892

786893
predict_df = typing.cast(
787-
bigframes.dataframe.DataFrame, model.predict(prompt_s, temperature=0.0)
894+
bigframes.dataframe.DataFrame,
895+
model.predict(
896+
prompt_s,
897+
temperature=0.0,
898+
ground_with_google_search=ground_with_google_search,
899+
),
788900
)
789901

790902
marks = predict_df["ml_generate_text_llm_result"].str.contains("2")

0 commit comments

Comments
 (0)