Skip to content

Commit de1e0a4

Browse files
authored
feat: add ml.llm.GeminiTextGenerator model (#370)
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [ ] Make sure to open an issue as a [bug/issue](https://togithub.com/googleapis/python-bigquery-dataframes/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [ ] Ensure the tests and linter pass - [ ] Code coverage does not decrease (if any source code was changed) - [ ] Appropriate docs were updated (if necessary) Fixes #<issue_number_goes_here> 🦕
1 parent 99a9e6e commit de1e0a4

File tree

6 files changed

+223
-0
lines changed

6 files changed

+223
-0
lines changed

README.rst

+2
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,8 @@ Create estimators for linear models by using the `bigframes.ml.linear_model modu
194194

195195
Create estimators for LLMs by using the `bigframes.ml.llm module <https://cloud.google.com/python/docs/reference/bigframes/latest/bigframes.ml.llm>`_.
196196

197+
* Use the `GeminiTextGenerator class <https://cloud.google.com/python/docs/reference/bigframes/latest/bigframes.ml.llm.GeminiTextGenerator>`_ to create Gemini text generator models. Use these models
198+
for text generation tasks.
197199
* Use the `PaLM2TextGenerator class <https://cloud.google.com/python/docs/reference/bigframes/latest/bigframes.ml.llm.PaLM2TextGenerator>`_ to create PaLM2 text generator models. Use these models
198200
for text generation tasks.
199201
* Use the `PaLM2TextEmbeddingGenerator class <https://cloud.google.com/python/docs/reference/bigframes/latest/bigframes.ml.llm.PaLM2TextEmbeddingGenerator>`_ to create PaLM2 text embedding generator models.

bigframes/ml/llm.py

+174
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141
_EMBEDDING_GENERATOR_GECKO_MULTILINGUAL_ENDPOINT,
4242
)
4343

44+
_GEMINI_PRO_ENDPOINT = "gemini-pro"
45+
4446
_ML_GENERATE_TEXT_STATUS = "ml_generate_text_status"
4547
_ML_EMBED_TEXT_STATUS = "ml_embed_text_status"
4648

@@ -396,3 +398,175 @@ def to_gbq(
396398

397399
new_model = self._bqml_model.copy(model_name, replace)
398400
return new_model.session.read_gbq_model(model_name)
401+
402+
403+
@log_adapter.class_logger
404+
class GeminiTextGenerator(base.Predictor):
405+
"""Gemini text generator LLM model.
406+
407+
Args:
408+
session (bigframes.Session or None):
409+
BQ session to create the model. If None, use the global default session.
410+
connection_name (str or None):
411+
Connection to connect with remote service. str of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.
412+
if None, use default connection in session context. BigQuery DataFrame will try to create the connection and attach
413+
permission if the connection isn't fully setup.
414+
"""
415+
416+
def __init__(
417+
self,
418+
session: Optional[bigframes.Session] = None,
419+
connection_name: Optional[str] = None,
420+
):
421+
self.session = session or bpd.get_global_session()
422+
self._bq_connection_manager = clients.BqConnectionManager(
423+
self.session.bqconnectionclient, self.session.resourcemanagerclient
424+
)
425+
426+
connection_name = connection_name or self.session._bq_connection
427+
self.connection_name = self._bq_connection_manager.resolve_full_connection_name(
428+
connection_name,
429+
default_project=self.session._project,
430+
default_location=self.session._location,
431+
)
432+
433+
self._bqml_model_factory = globals.bqml_model_factory()
434+
self._bqml_model: core.BqmlModel = self._create_bqml_model()
435+
436+
def _create_bqml_model(self):
437+
# Parse and create connection if needed.
438+
if not self.connection_name:
439+
raise ValueError(
440+
"Must provide connection_name, either in constructor or through session options."
441+
)
442+
connection_name_parts = self.connection_name.split(".")
443+
if len(connection_name_parts) != 3:
444+
raise ValueError(
445+
f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got {self.connection_name}."
446+
)
447+
self._bq_connection_manager.create_bq_connection(
448+
project_id=connection_name_parts[0],
449+
location=connection_name_parts[1],
450+
connection_id=connection_name_parts[2],
451+
iam_role="aiplatform.user",
452+
)
453+
454+
options = {"endpoint": _GEMINI_PRO_ENDPOINT}
455+
456+
return self._bqml_model_factory.create_remote_model(
457+
session=self.session, connection_name=self.connection_name, options=options
458+
)
459+
460+
@classmethod
461+
def _from_bq(
462+
cls, session: bigframes.Session, model: bigquery.Model
463+
) -> GeminiTextGenerator:
464+
assert model.model_type == "MODEL_TYPE_UNSPECIFIED"
465+
assert "remoteModelInfo" in model._properties
466+
assert "connection" in model._properties["remoteModelInfo"]
467+
468+
# Parse the remote model endpoint
469+
model_connection = model._properties["remoteModelInfo"]["connection"]
470+
471+
text_generator_model = cls(session=session, connection_name=model_connection)
472+
text_generator_model._bqml_model = core.BqmlModel(session, model)
473+
return text_generator_model
474+
475+
def predict(
476+
self,
477+
X: Union[bpd.DataFrame, bpd.Series],
478+
temperature: float = 0.9,
479+
max_output_tokens: int = 8192,
480+
top_k: int = 40,
481+
top_p: float = 1.0,
482+
) -> bpd.DataFrame:
483+
"""Predict the result from input DataFrame.
484+
485+
Args:
486+
X (bigframes.dataframe.DataFrame or bigframes.series.Series):
487+
Input DataFrame or Series, which contains only one column of prompts.
488+
Prompts can include preamble, questions, suggestions, instructions, or examples.
489+
490+
temperature (float, default 0.9):
491+
The temperature is used for sampling during the response generation, which occurs when topP and topK are applied. Temperature controls the degree of randomness in token selection. Lower temperatures are good for prompts that require a more deterministic and less open-ended or creative response, while higher temperatures can lead to more diverse or creative results. A temperature of 0 is deterministic: the highest probability response is always selected.
492+
Default 0.9. Possible values [0.0, 1.0].
493+
494+
max_output_tokens (int, default 8192):
495+
Maximum number of tokens that can be generated in the response. A token is approximately four characters. 100 tokens correspond to roughly 60-80 words.
496+
Specify a lower value for shorter responses and a higher value for potentially longer responses.
497+
Default 8192. Possible values are in the range [1, 8192].
498+
499+
top_k (int, default 40):
500+
Top-K changes how the model selects tokens for output. A top-K of 1 means the next selected token is the most probable among all tokens in the model's vocabulary (also called greedy decoding), while a top-K of 3 means that the next token is selected from among the three most probable tokens by using temperature.
501+
For each token selection step, the top-K tokens with the highest probabilities are sampled. Then tokens are further filtered based on top-P with the final token selected using temperature sampling.
502+
Specify a lower value for less random responses and a higher value for more random responses.
503+
Default 40. Possible values [1, 40].
504+
505+
top_p (float, default 0.95)::
506+
Top-P changes how the model selects tokens for output. Tokens are selected from the most (see top-K) to least probable until the sum of their probabilities equals the top-P value. For example, if tokens A, B, and C have a probability of 0.3, 0.2, and 0.1 and the top-P value is 0.5, then the model will select either A or B as the next token by using temperature and excludes C as a candidate.
507+
Specify a lower value for less random responses and a higher value for more random responses.
508+
Default 1.0. Possible values [0.0, 1.0].
509+
510+
511+
Returns:
512+
bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values.
513+
"""
514+
515+
# Params reference: https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models
516+
if temperature < 0.0 or temperature > 1.0:
517+
raise ValueError(f"temperature must be [0.0, 1.0], but is {temperature}.")
518+
519+
if max_output_tokens not in range(1, 8193):
520+
raise ValueError(
521+
f"max_output_token must be [1, 8192] for Gemini model, but is {max_output_tokens}."
522+
)
523+
524+
if top_k not in range(1, 41):
525+
raise ValueError(f"top_k must be [1, 40], but is {top_k}.")
526+
527+
if top_p < 0.0 or top_p > 1.0:
528+
raise ValueError(f"top_p must be [0.0, 1.0], but is {top_p}.")
529+
530+
(X,) = utils.convert_to_dataframe(X)
531+
532+
if len(X.columns) != 1:
533+
raise ValueError(
534+
f"Only support one column as input. {constants.FEEDBACK_LINK}"
535+
)
536+
537+
# BQML identified the column by name
538+
col_label = cast(blocks.Label, X.columns[0])
539+
X = X.rename(columns={col_label: "prompt"})
540+
541+
options = {
542+
"temperature": temperature,
543+
"max_output_tokens": max_output_tokens,
544+
"top_k": top_k,
545+
"top_p": top_p,
546+
"flatten_json_output": True,
547+
}
548+
549+
df = self._bqml_model.generate_text(X, options)
550+
551+
if (df[_ML_GENERATE_TEXT_STATUS] != "").any():
552+
warnings.warn(
553+
f"Some predictions failed. Check column {_ML_GENERATE_TEXT_STATUS} for detailed status. You may want to filter the failed rows and retry.",
554+
RuntimeWarning,
555+
)
556+
557+
return df
558+
559+
def to_gbq(self, model_name: str, replace: bool = False) -> GeminiTextGenerator:
560+
"""Save the model to BigQuery.
561+
562+
Args:
563+
model_name (str):
564+
the name of the model.
565+
replace (bool, default False):
566+
whether to replace if the model already exists. Default to False.
567+
568+
Returns:
569+
GeminiTextGenerator: saved model."""
570+
571+
new_model = self._bqml_model.copy(model_name, replace)
572+
return new_model.session.read_gbq_model(model_name)

bigframes/ml/loader.py

+1
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
llm._TEXT_GENERATOR_BISON_32K_ENDPOINT: llm.PaLM2TextGenerator,
5656
llm._EMBEDDING_GENERATOR_GECKO_ENDPOINT: llm.PaLM2TextEmbeddingGenerator,
5757
llm._EMBEDDING_GENERATOR_GECKO_MULTILINGUAL_ENDPOINT: llm.PaLM2TextEmbeddingGenerator,
58+
llm._GEMINI_PRO_ENDPOINT: llm.GeminiTextGenerator,
5859
}
5960
)
6061

docs/templates/toc.yml

+2
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@
105105
- items:
106106
- name: Overview
107107
uid: bigframes.ml.llm
108+
- name: GeminiTextGenerator
109+
uid: bigframes.ml.llm.GeminiTextGenerator
108110
- name: PaLM2TextGenerator
109111
uid: bigframes.ml.llm.PaLM2TextGenerator
110112
- name: PaLM2TextEmbeddingGenerator

tests/system/small/ml/conftest.py

+5
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,11 @@ def palm2_embedding_generator_multilingual_model(
267267
)
268268

269269

270+
@pytest.fixture(scope="session")
271+
def gemini_text_generator_model(session, bq_connection) -> llm.GeminiTextGenerator:
272+
return llm.GeminiTextGenerator(session=session, connection_name=bq_connection)
273+
274+
270275
@pytest.fixture(scope="session")
271276
def linear_remote_model_params() -> dict:
272277
# Pre-deployed endpoint of linear reg model in Vertex.

tests/system/small/ml/test_llm.py

+39
Original file line numberDiff line numberDiff line change
@@ -272,3 +272,42 @@ def test_embedding_generator_predict_series_success(
272272
series = df["text_embedding"]
273273
value = series[0]
274274
assert len(value) == 768
275+
276+
277+
def test_create_gemini_text_generator_model(
278+
gemini_text_generator_model, dataset_id, bq_connection
279+
):
280+
# Model creation doesn't return error
281+
assert gemini_text_generator_model is not None
282+
assert gemini_text_generator_model._bqml_model is not None
283+
284+
# save, load to ensure configuration was kept
285+
reloaded_model = gemini_text_generator_model.to_gbq(
286+
f"{dataset_id}.temp_text_model", replace=True
287+
)
288+
assert f"{dataset_id}.temp_text_model" == reloaded_model._bqml_model.model_name
289+
assert reloaded_model.connection_name == bq_connection
290+
291+
292+
@pytest.mark.flaky(retries=2, delay=120)
293+
def test_gemini_text_generator_predict_default_params_success(
294+
gemini_text_generator_model, llm_text_df
295+
):
296+
df = gemini_text_generator_model.predict(llm_text_df).to_pandas()
297+
assert df.shape == (3, 4)
298+
assert "ml_generate_text_llm_result" in df.columns
299+
series = df["ml_generate_text_llm_result"]
300+
assert all(series.str.len() > 20)
301+
302+
303+
@pytest.mark.flaky(retries=2, delay=120)
304+
def test_gemini_text_generator_predict_with_params_success(
305+
gemini_text_generator_model, llm_text_df
306+
):
307+
df = gemini_text_generator_model.predict(
308+
llm_text_df, temperature=0.5, max_output_tokens=100, top_k=20, top_p=0.5
309+
).to_pandas()
310+
assert df.shape == (3, 4)
311+
assert "ml_generate_text_llm_result" in df.columns
312+
series = df["ml_generate_text_llm_result"]
313+
assert all(series.str.len() > 20)

0 commit comments

Comments
 (0)