Skip to content

Commit 0bfc4fb

Browse files
authored
feat: add remote vertex model support (#237)
b/299356085
1 parent d0d9b84 commit 0bfc4fb

File tree

11 files changed

+319
-4
lines changed

11 files changed

+319
-4
lines changed

bigframes/ml/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,5 @@
2626
"llm",
2727
"forecasting",
2828
"imported",
29+
"remote",
2930
]

bigframes/ml/core.py

+8
Original file line numberDiff line numberDiff line change
@@ -294,13 +294,19 @@ def create_remote_model(
294294
self,
295295
session: bigframes.Session,
296296
connection_name: str,
297+
input: Mapping[str, str] = {},
298+
output: Mapping[str, str] = {},
297299
options: Mapping[str, Union[str, int, float, Iterable[str]]] = {},
298300
) -> BqmlModel:
299301
"""Create a session-temporary BQML remote model with the CREATE OR REPLACE MODEL statement
300302
301303
Args:
302304
connection_name:
303305
a BQ connection to talk with Vertex AI, of the format <PROJECT_NUMBER>.<REGION>.<CONNECTION_NAME>. https://cloud.google.com/bigquery/docs/create-cloud-resource-connection
306+
input:
307+
input schema for general remote models
308+
output:
309+
output schema for general remote models
304310
options:
305311
a dict of options to configure the model. Generates a BQML OPTIONS clause
306312
@@ -311,6 +317,8 @@ def create_remote_model(
311317
sql = self._model_creation_sql_generator.create_remote_model(
312318
connection_name=connection_name,
313319
model_ref=model_ref,
320+
input=input,
321+
output=output,
314322
options=options,
315323
)
316324

bigframes/ml/remote.py

+157
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""BigFrames general remote models."""
16+
17+
from __future__ import annotations
18+
19+
from typing import Mapping, Optional, Union
20+
import warnings
21+
22+
import bigframes
23+
from bigframes import clients
24+
from bigframes.core import log_adapter
25+
from bigframes.ml import base, core, globals, utils
26+
import bigframes.pandas as bpd
27+
28+
# Scalar type names accepted in a remote model's input/output schema.
_SCALAR_DTYPES = ("bool", "string", "int64", "float64")

# Every scalar type, followed by the array form of each one (same order).
_SUPPORTED_DTYPES = _SCALAR_DTYPES + tuple(
    f"array<{dtype}>" for dtype in _SCALAR_DTYPES
)

# Name of the prediction-output column that carries the per-row status.
_REMOTE_MODEL_STATUS = "remote_model_status"
40+
41+
42+
def _standardize_type(v: str) -> str:
    """Normalise a schema type name and check that it is supported.

    The name is lower-cased and "boolean" is rewritten to "bool" before it is
    looked up in ``_SUPPORTED_DTYPES``.

    Args:
        v: type name as given by the caller, e.g. "FLOAT64" or "array<boolean>".

    Returns:
        The normalised type name.

    Raises:
        ValueError: if the normalised name is not in ``_SUPPORTED_DTYPES``.
    """
    v = v.lower().replace("boolean", "bool")

    if v not in _SUPPORTED_DTYPES:
        raise ValueError(
            f"Data type {v} is not supported. We only support {', '.join(_SUPPORTED_DTYPES)}."
        )

    return v


@log_adapter.class_logger
class VertexAIModel(base.BaseEstimator):
    """Remote model from a Vertex AI HTTPS endpoint.

    The user must specify the HTTPS endpoint, the input schema and the output schema.
    How to deploy a model in Vertex AI:
    https://cloud.google.com/bigquery/docs/bigquery-ml-remote-model-tutorial#Deploy-Model-on-Vertex-AI.

    Args:
        endpoint (str):
            Vertex AI HTTPS endpoint.
        input ({column_name: column_type}):
            Input schema. Supported types are "bool", "string", "int64", "float64", "array<bool>", "array<string>", "array<int64>", "array<float64>".
        output ({column_name: column_type}):
            Output label schema. Supports the same types as the input.
        session (bigframes.Session or None):
            BQ session to create the model. If None, use the global default session.
        connection_name (str or None):
            Connection to connect with the remote service. A str of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.
            If None, use the default connection in the session context. BigQuery DataFrames will try to create the connection and attach
            permission if the connection isn't fully set up.
    """

    def __init__(
        self,
        endpoint: str,
        input: Mapping[str, str],
        output: Mapping[str, str],
        session: Optional[bigframes.Session] = None,
        connection_name: Optional[str] = None,
    ):
        self.endpoint = endpoint
        self.input = input
        self.output = output
        self.session = session or bpd.get_global_session()

        self._bq_connection_manager = clients.BqConnectionManager(
            self.session.bqconnectionclient, self.session.resourcemanagerclient
        )
        # Fall back to the session's default connection when none is given.
        connection_name = connection_name or self.session._bq_connection
        self.connection_name = self._bq_connection_manager.resolve_full_connection_name(
            connection_name,
            default_project=self.session._project,
            default_location=self.session._location,
        )

        self._bqml_model_factory = globals.bqml_model_factory()
        self._bqml_model: core.BqmlModel = self._create_bqml_model()

    def _create_bqml_model(self):
        """Validate the configuration, set up the connection and create the BQML remote model.

        ``self.input`` and ``self.output`` are replaced by their normalised
        forms (lower-cased type names, "boolean" rewritten to "bool").

        Returns:
            The model returned by the BQML model factory's ``create_remote_model``.

        Raises:
            ValueError: if no connection name is available, if the connection
                name does not consist of three dot-separated parts, or if a
                schema uses an unsupported data type.
        """
        if not self.connection_name:
            raise ValueError(
                "Must provide connection_name, either in constructor or through session options."
            )
        connection_name_parts = self.connection_name.split(".")
        if len(connection_name_parts) != 3:
            raise ValueError(
                f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got {self.connection_name}."
            )

        # Validate the schemas before touching the connection, so that an
        # unsupported dtype fails fast instead of after the connection has
        # already been created and granted an IAM role.
        self.input = {k: _standardize_type(v) for k, v in self.input.items()}
        self.output = {k: _standardize_type(v) for k, v in self.output.items()}

        # Parse and create connection if needed.
        self._bq_connection_manager.create_bq_connection(
            project_id=connection_name_parts[0],
            location=connection_name_parts[1],
            connection_id=connection_name_parts[2],
            iam_role="aiplatform.user",
        )

        options = {
            "endpoint": self.endpoint,
        }

        return self._bqml_model_factory.create_remote_model(
            session=self.session,
            connection_name=self.connection_name,
            input=self.input,
            output=self.output,
            options=options,
        )

    def predict(
        self,
        X: Union[bpd.DataFrame, bpd.Series],
    ) -> bpd.DataFrame:
        """Predict the result from the input DataFrame.

        Emits a ``RuntimeWarning`` when any row of the result has a non-null
        value in the status column.

        Args:
            X (bigframes.dataframe.DataFrame or bigframes.series.Series):
                Input DataFrame or Series, which needs to comply with the input parameter of the model.

        Returns:
            bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values.
        """

        (X,) = utils.convert_to_dataframe(X)

        df = self._bqml_model.predict(X)

        # unlike LLM models, the general remote model status is null for successful runs.
        if (df[_REMOTE_MODEL_STATUS].notna()).any():
            warnings.warn(
                f"Some predictions failed. Check column {_REMOTE_MODEL_STATUS} for detailed status. You may want to filter the failed rows and retry.",
                RuntimeWarning,
            )

        return df

bigframes/ml/sql.py

+22-4
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,12 @@ def build_expressions(self, *expr_sqls: str) -> str:
5757
indent_str = " "
5858
return "\n" + indent_str + f",\n{indent_str}".join(expr_sqls)
5959

60+
def build_schema(self, **kwargs: str) -> str:
    """Render keyword arguments as the items of a SQL schema.

    Each keyword becomes one ``<name> <type>`` item. The result starts with a
    newline, puts every item on its own indented line and separates items
    with commas.
    """
    indent = " "
    columns = (f"{name} {dtype}" for name, dtype in kwargs.items())
    return f"\n{indent}" + f",\n{indent}".join(columns)
65+
6066
def options(self, **kwargs: Union[str, int, float, Iterable[str]]) -> str:
    """Build the BQML OPTIONS clause from the given keyword options."""
    rendered_parameters = self.build_parameters(**kwargs)
    return f"OPTIONS({rendered_parameters})"
@@ -65,6 +71,14 @@ def struct_options(self, **kwargs: Union[int, float]) -> str:
6571
"""Encode a BQ STRUCT as options."""
6672
return f"STRUCT({self.build_structs(**kwargs)})"
6773

74+
def input(self, **kwargs: str) -> str:
    """Build the BQML INPUT clause from ``name=type`` keyword arguments."""
    schema_items = self.build_schema(**kwargs)
    return f"INPUT({schema_items})"
77+
78+
def output(self, **kwargs: str) -> str:
    """Build the BQML OUTPUT clause from ``name=type`` keyword arguments."""
    schema_items = self.build_schema(**kwargs)
    return f"OUTPUT({schema_items})"
81+
6882
# Connection
6983
def connection(self, conn_name: str) -> str:
7084
"""Encode the REMOTE WITH CONNECTION clause for BQML. conn_name is of the format <PROJECT_NUMBER/PROJECT_ID>.<REGION>.<CONNECTION_NAME>."""
@@ -154,15 +168,19 @@ def create_remote_model(
154168
self,
155169
connection_name: str,
156170
model_ref: google.cloud.bigquery.ModelReference,
171+
input: Mapping[str, str] = {},
172+
output: Mapping[str, str] = {},
157173
options: Mapping[str, Union[str, int, float, Iterable[str]]] = {},
158174
) -> str:
159175
"""Encode the CREATE OR REPLACE MODEL statement for BQML remote model."""
160-
options_sql = self.options(**options)
161-
162176
parts = [f"CREATE OR REPLACE MODEL {self._model_id_sql(model_ref)}"]
177+
if input:
178+
parts.append(self.input(**input))
179+
if output:
180+
parts.append(self.output(**output))
163181
parts.append(self.connection(connection_name))
164-
if options_sql:
165-
parts.append(options_sql)
182+
if options:
183+
parts.append(self.options(**options))
166184
return "\n".join(parts)
167185

168186
def create_imported_model(

docs/reference/bigframes.ml/index.rst

+2
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,5 @@ API Reference
3030
pipeline
3131

3232
preprocessing
33+
34+
remote
+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
bigframes.ml.remote
2+
===================
3+
4+
.. automodule:: bigframes.ml.remote
5+
:members:
6+
:inherited-members:
7+
:undoc-members:

docs/templates/toc.yml

+6
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,12 @@
108108
- name: PaLM2TextEmbeddingGenerator
109109
uid: bigframes.ml.llm.PaLM2TextEmbeddingGenerator
110110
name: llm
111+
- items:
112+
- name: Overview
113+
uid: bigframes.ml.remote
114+
- name: VertexAIModel
115+
uid: bigframes.ml.remote.VertexAIModel
116+
name: remote
111117
- items:
112118
- name: metrics
113119
uid: bigframes.ml.metrics

tests/system/small/ml/conftest.py

+41
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
imported,
3030
linear_model,
3131
llm,
32+
remote,
3233
)
3334

3435

@@ -247,6 +248,46 @@ def palm2_embedding_generator_multilingual_model(
247248
)
248249

249250

251+
@pytest.fixture(scope="session")
def linear_remote_model_params() -> dict:
    """Return the endpoint and input/output schemas shared by the remote-model fixtures."""
    # Pre-deployed endpoint of linear reg model in Vertex.
    # bigframes-test-linreg2 -> bigframes-test-linreg-endpoint2
    return {
        "input": {"culmen_length_mm": "float64"},
        "output": {"predicted_body_mass_g": "array<float64>"},
        # Plain Vertex AI endpoint URL; the previous value carried a stray
        # proxy prefix and was not a valid endpoint address.
        "endpoint": "https://us-central1-aiplatform.googleapis.com/v1/projects/1084210331973/locations/us-central1/endpoints/3193318217619603456",
    }
260+
261+
262+
@pytest.fixture(scope="session")
def bqml_linear_remote_model(
    session, bq_connection, linear_remote_model_params
) -> core.BqmlModel:
    """Create the linear remote model directly through the BQML model factory."""
    params = linear_remote_model_params
    model_factory = globals.bqml_model_factory()
    return model_factory.create_remote_model(
        session=session,
        connection_name=bq_connection,
        input=params["input"],
        output=params["output"],
        options={"endpoint": params["endpoint"]},
    )
276+
277+
278+
@pytest.fixture(scope="session")
def linear_remote_vertex_model(
    session, bq_connection, linear_remote_model_params
) -> remote.VertexAIModel:
    """Create the linear remote model through the ``remote.VertexAIModel`` estimator."""
    params = linear_remote_model_params
    vertex_model = remote.VertexAIModel(
        session=session,
        connection_name=bq_connection,
        endpoint=params["endpoint"],
        input=params["input"],
        output=params["output"],
    )
    return vertex_model
289+
290+
250291
@pytest.fixture(scope="session")
251292
def time_series_bqml_arima_plus_model(
252293
session, time_series_arima_plus_model_name

tests/system/small/ml/test_core.py

+16
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,22 @@ def test_model_predict_with_unnamed_index(
289289
)
290290

291291

292+
def test_remote_model_predict(
    bqml_linear_remote_model: core.BqmlModel, new_penguins_df
):
    """Predictions of the BQML remote model match the expected body masses within 10%."""
    label = "predicted_body_mass_g"
    tag_numbers = pd.Index([1633, 1672, 1690], name="tag_number", dtype="Int64")
    expected = pd.DataFrame(
        {label: [[3739.54], [3675.79], [3619.54]]},
        index=tag_numbers,
    )

    actual = bqml_linear_remote_model.predict(new_penguins_df).to_pandas()

    # Only the prediction column is compared, ordered by index.
    pd.testing.assert_frame_equal(
        actual[[label]].sort_index(),
        expected,
        check_exact=False,
        rtol=0.1,
    )
306+
307+
292308
@pytest.mark.flaky(retries=2, delay=120)
293309
def test_model_generate_text(
294310
bqml_palm2_text_generator_model: core.BqmlModel, llm_text_df

tests/system/small/ml/test_remote.py

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pandas as pd
16+
17+
from bigframes.ml import remote
18+
19+
20+
def test_remote_linear_vertex_model_predict(
    linear_remote_vertex_model: remote.VertexAIModel, new_penguins_df
):
    """Predictions of the VertexAIModel estimator match the expected body masses within 10%."""
    label = "predicted_body_mass_g"
    tag_numbers = pd.Index([1633, 1672, 1690], name="tag_number", dtype="Int64")
    expected = pd.DataFrame(
        {label: [[3739.54], [3675.79], [3619.54]]},
        index=tag_numbers,
    )

    actual = linear_remote_vertex_model.predict(new_penguins_df).to_pandas()

    # Only the prediction column is compared, ordered by index.
    pd.testing.assert_frame_equal(
        actual[[label]].sort_index(),
        expected,
        check_exact=False,
        rtol=0.1,
    )

0 commit comments

Comments
 (0)