Skip to content

Commit 1156c1e

Browse files
authored
feat: support ML.GENERATE_EMBEDDING in PaLM2TextEmbeddingGenerator (#539)
* feat: support ML.GENERATE_EMBEDDING in PaLM2TextEmbeddingGenerator
1 parent 54e49cf commit 1156c1e

File tree

5 files changed

+16
-16
lines changed

5 files changed

+16
-16
lines changed

bigframes/ml/core.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -152,14 +152,14 @@ def generate_text(
152152
),
153153
)
154154

155-
def generate_text_embedding(
155+
def generate_embedding(
156156
self,
157157
input_data: bpd.DataFrame,
158158
options: Mapping[str, int | float],
159159
) -> bpd.DataFrame:
160160
return self._apply_sql(
161161
input_data,
162-
lambda source_df: self._model_manipulation_sql_generator.ml_generate_text_embedding(
162+
lambda source_df: self._model_manipulation_sql_generator.ml_generate_embedding(
163163
source_df=source_df,
164164
struct_options=options,
165165
),

bigframes/ml/llm.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,7 @@
4444
_GEMINI_PRO_ENDPOINT = "gemini-pro"
4545

4646
_ML_GENERATE_TEXT_STATUS = "ml_generate_text_status"
47-
_ML_EMBED_TEXT_STATUS = "ml_embed_text_status"
47+
_ML_EMBED_TEXT_STATUS = "ml_generate_embedding_status"
4848

4949

5050
@log_adapter.class_logger
@@ -389,7 +389,7 @@ def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
389389
"flatten_json_output": True,
390390
}
391391

392-
df = self._bqml_model.generate_text_embedding(X, options)
392+
df = self._bqml_model.generate_embedding(X, options)
393393

394394
if (df[_ML_EMBED_TEXT_STATUS] != "").any():
395395
warnings.warn(

bigframes/ml/sql.py

+3 -3
Original file line number | Diff line number | Diff line change
@@ -270,12 +270,12 @@ def ml_generate_text(
270270
return f"""SELECT * FROM ML.GENERATE_TEXT(MODEL `{self._model_name}`,
271271
({self._source_sql(source_df)}), {struct_options_sql})"""
272272

273-
def ml_generate_text_embedding(
273+
def ml_generate_embedding(
274274
self, source_df: bpd.DataFrame, struct_options: Mapping[str, Union[int, float]]
275275
) -> str:
276-
"""Encode ML.GENERATE_TEXT_EMBEDDING for BQML"""
276+
"""Encode ML.GENERATE_EMBEDDING for BQML"""
277277
struct_options_sql = self.struct_options(**struct_options)
278-
return f"""SELECT * FROM ML.GENERATE_TEXT_EMBEDDING(MODEL `{self._model_name}`,
278+
return f"""SELECT * FROM ML.GENERATE_EMBEDDING(MODEL `{self._model_name}`,
279279
({self._source_sql(source_df)}), {struct_options_sql})"""
280280

281281
def ml_detect_anomalies(

tests/system/small/ml/test_llm.py

+6 -6
Original file line number | Diff line number | Diff line change
@@ -261,8 +261,8 @@ def test_embedding_generator_predict_success(
261261
):
262262
df = palm2_embedding_generator_model.predict(llm_text_df).to_pandas()
263263
assert df.shape == (3, 4)
264-
assert "text_embedding" in df.columns
265-
series = df["text_embedding"]
264+
assert "ml_generate_embedding_result" in df.columns
265+
series = df["ml_generate_embedding_result"]
266266
value = series[0]
267267
assert len(value) == 768
268268

@@ -273,8 +273,8 @@ def test_embedding_generator_multilingual_predict_success(
273273
):
274274
df = palm2_embedding_generator_multilingual_model.predict(llm_text_df).to_pandas()
275275
assert df.shape == (3, 4)
276-
assert "text_embedding" in df.columns
277-
series = df["text_embedding"]
276+
assert "ml_generate_embedding_result" in df.columns
277+
series = df["ml_generate_embedding_result"]
278278
value = series[0]
279279
assert len(value) == 768
280280

@@ -285,8 +285,8 @@ def test_embedding_generator_predict_series_success(
285285
):
286286
df = palm2_embedding_generator_model.predict(llm_text_df["prompt"]).to_pandas()
287287
assert df.shape == (3, 4)
288-
assert "text_embedding" in df.columns
289-
series = df["text_embedding"]
288+
assert "ml_generate_embedding_result" in df.columns
289+
series = df["ml_generate_embedding_result"]
290290
value = series[0]
291291
assert len(value) == 768
292292

tests/unit/ml/test_sql.py

+3 -3
Original file line number | Diff line number | Diff line change
@@ -373,17 +373,17 @@ def test_ml_generate_text_correct(
373373
)
374374

375375

376-
def test_ml_generate_text_embedding_correct(
376+
def test_ml_generate_embedding_correct(
377377
model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
378378
mock_df: bpd.DataFrame,
379379
):
380-
sql = model_manipulation_sql_generator.ml_generate_text_embedding(
380+
sql = model_manipulation_sql_generator.ml_generate_embedding(
381381
source_df=mock_df,
382382
struct_options={"option_key1": 1, "option_key2": 2.2},
383383
)
384384
assert (
385385
sql
386-
== """SELECT * FROM ML.GENERATE_TEXT_EMBEDDING(MODEL `my_project_id.my_dataset_id.my_model_id`,
386+
== """SELECT * FROM ML.GENERATE_EMBEDDING(MODEL `my_project_id.my_dataset_id.my_model_id`,
387387
(input_X_sql), STRUCT(
388388
1 AS option_key1,
389389
2.2 AS option_key2))"""

0 commit comments

Comments
 (0)