Skip to content

Commit 0b3f8e5

Browse files
authored
feat: option to use bq connection without check (#460)
* feat: option to use bq connection without check * revert breaking signature change, centralize connection manager skipping * fix bad referencing * use public property from session * revert unintended test_iam_permissions change * fix couple of more unwanted changes
1 parent a5345fe commit 0b3f8e5

File tree

9 files changed

+164
-100
lines changed

9 files changed

+164
-100
lines changed

bigframes/_config/bigquery_options.py

+26-2
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def __init__(
4040
use_regional_endpoints: bool = False,
4141
application_name: Optional[str] = None,
4242
kms_key_name: Optional[str] = None,
43+
skip_bq_connection_check: bool = False,
4344
):
4445
self._credentials = credentials
4546
self._project = project
@@ -48,6 +49,7 @@ def __init__(
4849
self._use_regional_endpoints = use_regional_endpoints
4950
self._application_name = application_name
5051
self._kms_key_name = kms_key_name
52+
self._skip_bq_connection_check = skip_bq_connection_check
5153
self._session_started = False
5254

5355
@property
@@ -105,14 +107,16 @@ def project(self, value: Optional[str]):
105107

106108
@property
107109
def bq_connection(self) -> Optional[str]:
108-
"""Name of the BigQuery connection to use. Should be of the form <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.
110+
"""Name of the BigQuery connection to use. Should be of the form
111+
<PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.
109112
110113
You should either have the connection already created in the
111114
<code>location</code> you have chosen, or you should have the Project IAM
112115
Admin role to enable the service to create the connection for you if you
113116
need it.
114117
115-
If this option isn't provided, or project or location aren't provided, session will use its default project/location/connection_id as default connection.
118+
If this option isn't provided, or project or location aren't provided,
119+
session will use its default project/location/connection_id as default connection.
116120
"""
117121
return self._bq_connection
118122

@@ -122,6 +126,26 @@ def bq_connection(self, value: Optional[str]):
122126
raise ValueError(SESSION_STARTED_MESSAGE.format(attribute="bq_connection"))
123127
self._bq_connection = value
124128

129+
@property
130+
def skip_bq_connection_check(self) -> bool:
131+
"""Forcibly use the BigQuery connection.
132+
133+
Setting this flag to True would avoid creating the BigQuery connection
134+
and checking or setting IAM permissions on it. So if the BigQuery
135+
connection (default or user-provided) does not exist, or it does not have
136+
necessary permissions set up to support BigQuery DataFrames operations,
137+
then a runtime error will be reported.
138+
"""
139+
return self._skip_bq_connection_check
140+
141+
@skip_bq_connection_check.setter
142+
def skip_bq_connection_check(self, value: bool):
143+
if self._session_started and self._skip_bq_connection_check != value:
144+
raise ValueError(
145+
SESSION_STARTED_MESSAGE.format(attribute="skip_bq_connection_check")
146+
)
147+
self._skip_bq_connection_check = value
148+
125149
@property
126150
def use_regional_endpoints(self) -> bool:
127151
"""Flag to connect to regional API endpoints.

bigframes/clients.py

+17-23
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,23 @@
2727
logger = logging.getLogger(__name__)
2828

2929

30+
def resolve_full_bq_connection_name(
31+
connection_name: str, default_project: str, default_location: str
32+
) -> str:
33+
"""Retrieve the full connection name of the form <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.
34+
Use default project, location or connection_id when any of them are missing."""
35+
if connection_name.count(".") == 2:
36+
return connection_name
37+
38+
if connection_name.count(".") == 1:
39+
return f"{default_project}.{connection_name}"
40+
41+
if connection_name.count(".") == 0:
42+
return f"{default_project}.{default_location}.{connection_name}"
43+
44+
raise ValueError(f"Invalid connection name format: {connection_name}.")
45+
46+
3047
class BqConnectionManager:
3148
"""Manager to handle operations with BQ connections."""
3249

@@ -41,23 +58,6 @@ def __init__(
4158
self._bq_connection_client = bq_connection_client
4259
self._cloud_resource_manager_client = cloud_resource_manager_client
4360

44-
@classmethod
45-
def resolve_full_connection_name(
46-
cls, connection_name: str, default_project: str, default_location: str
47-
) -> str:
48-
"""Retrieve the full connection name of the form <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.
49-
Use default project, location or connection_id when any of them are missing."""
50-
if connection_name.count(".") == 2:
51-
return connection_name
52-
53-
if connection_name.count(".") == 1:
54-
return f"{default_project}.{connection_name}"
55-
56-
if connection_name.count(".") == 0:
57-
return f"{default_project}.{default_location}.{connection_name}"
58-
59-
raise ValueError(f"Invalid connection name format: {connection_name}.")
60-
6161
def create_bq_connection(
6262
self, project_id: str, location: str, connection_id: str, iam_role: str
6363
):
@@ -73,12 +73,6 @@ def create_bq_connection(
7373
iam_role:
7474
str of the IAM role that the service account of the created connection needs to aquire. E.g. 'run.invoker', 'aiplatform.user'
7575
"""
76-
# TODO(shobs): The below command to enable BigQuery Connection API needs
77-
# to be automated. Disabling for now since most target users would not
78-
# have the privilege to enable API in a project.
79-
# log("Making sure BigQuery Connection API is enabled")
80-
# if os.system("gcloud services enable bigqueryconnection.googleapis.com"):
81-
# raise ValueError("Failed to enable BigQuery Connection API")
8276
# If the intended connection does not exist then create it
8377
service_account_id = self._get_service_account_if_connection_exists(
8478
project_id, location, connection_id

bigframes/functions/remote_function.py

+13-14
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,8 @@ def __init__(
126126
bq_location,
127127
bq_dataset,
128128
bq_client,
129-
bq_connection_client,
130129
bq_connection_id,
131-
cloud_resource_manager_client,
130+
bq_connection_manager,
132131
cloud_function_service_account,
133132
cloud_function_kms_key_name,
134133
cloud_function_docker_repository,
@@ -140,9 +139,7 @@ def __init__(
140139
self._bq_dataset = bq_dataset
141140
self._bq_client = bq_client
142141
self._bq_connection_id = bq_connection_id
143-
self._bq_connection_manager = clients.BqConnectionManager(
144-
bq_connection_client, cloud_resource_manager_client
145-
)
142+
self._bq_connection_manager = bq_connection_manager
146143
self._cloud_function_service_account = cloud_function_service_account
147144
self._cloud_function_kms_key_name = cloud_function_kms_key_name
148145
self._cloud_function_docker_repository = cloud_function_docker_repository
@@ -152,12 +149,13 @@ def create_bq_remote_function(
152149
):
153150
"""Create a BigQuery remote function given the artifacts of a user defined
154151
function and the http endpoint of a corresponding cloud function."""
155-
self._bq_connection_manager.create_bq_connection(
156-
self._gcp_project_id,
157-
self._bq_location,
158-
self._bq_connection_id,
159-
"run.invoker",
160-
)
152+
if self._bq_connection_manager:
153+
self._bq_connection_manager.create_bq_connection(
154+
self._gcp_project_id,
155+
self._bq_location,
156+
self._bq_connection_id,
157+
"run.invoker",
158+
)
161159

162160
# Create BQ function
163161
# https://cloud.google.com/bigquery/docs/reference/standard-sql/remote-functions#create_a_remote_function_2
@@ -784,7 +782,7 @@ def remote_function(
784782
if not bigquery_connection:
785783
bigquery_connection = session._bq_connection # type: ignore
786784

787-
bigquery_connection = clients.BqConnectionManager.resolve_full_connection_name(
785+
bigquery_connection = clients.resolve_full_bq_connection_name(
788786
bigquery_connection,
789787
default_project=dataset_ref.project,
790788
default_location=bq_location,
@@ -816,6 +814,8 @@ def remote_function(
816814
" For more details see https://cloud.google.com/functions/docs/securing/cmek#before_you_begin"
817815
)
818816

817+
bq_connection_manager = None if session is None else session.bqconnectionmanager
818+
819819
def wrapper(f):
820820
if not callable(f):
821821
raise TypeError("f must be callable, got {}".format(f))
@@ -832,9 +832,8 @@ def wrapper(f):
832832
bq_location,
833833
dataset_ref.dataset_id,
834834
bigquery_client,
835-
bigquery_connection_client,
836835
bq_connection_id,
837-
resource_manager_client,
836+
bq_connection_manager,
838837
cloud_function_service_account,
839838
cloud_function_kms_key_name,
840839
cloud_function_docker_repository,

bigframes/ml/llm.py

+42-42
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,10 @@ def __init__(
7373
):
7474
self.model_name = model_name
7575
self.session = session or bpd.get_global_session()
76-
self._bq_connection_manager = clients.BqConnectionManager(
77-
self.session.bqconnectionclient, self.session.resourcemanagerclient
78-
)
76+
self._bq_connection_manager = self.session.bqconnectionmanager
7977

8078
connection_name = connection_name or self.session._bq_connection
81-
self.connection_name = self._bq_connection_manager.resolve_full_connection_name(
79+
self.connection_name = clients.resolve_full_bq_connection_name(
8280
connection_name,
8381
default_project=self.session._project,
8482
default_location=self.session._location,
@@ -93,17 +91,19 @@ def _create_bqml_model(self):
9391
raise ValueError(
9492
"Must provide connection_name, either in constructor or through session options."
9593
)
96-
connection_name_parts = self.connection_name.split(".")
97-
if len(connection_name_parts) != 3:
98-
raise ValueError(
99-
f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got {self.connection_name}."
94+
95+
if self._bq_connection_manager:
96+
connection_name_parts = self.connection_name.split(".")
97+
if len(connection_name_parts) != 3:
98+
raise ValueError(
99+
f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got {self.connection_name}."
100+
)
101+
self._bq_connection_manager.create_bq_connection(
102+
project_id=connection_name_parts[0],
103+
location=connection_name_parts[1],
104+
connection_id=connection_name_parts[2],
105+
iam_role="aiplatform.user",
100106
)
101-
self._bq_connection_manager.create_bq_connection(
102-
project_id=connection_name_parts[0],
103-
location=connection_name_parts[1],
104-
connection_id=connection_name_parts[2],
105-
iam_role="aiplatform.user",
106-
)
107107

108108
if self.model_name not in _TEXT_GENERATOR_ENDPOINTS:
109109
raise ValueError(
@@ -289,12 +289,10 @@ def __init__(
289289
self.model_name = model_name
290290
self.version = version
291291
self.session = session or bpd.get_global_session()
292-
self._bq_connection_manager = clients.BqConnectionManager(
293-
self.session.bqconnectionclient, self.session.resourcemanagerclient
294-
)
292+
self._bq_connection_manager = self.session.bqconnectionmanager
295293

296294
connection_name = connection_name or self.session._bq_connection
297-
self.connection_name = self._bq_connection_manager.resolve_full_connection_name(
295+
self.connection_name = clients.resolve_full_bq_connection_name(
298296
connection_name,
299297
default_project=self.session._project,
300298
default_location=self.session._location,
@@ -309,17 +307,19 @@ def _create_bqml_model(self):
309307
raise ValueError(
310308
"Must provide connection_name, either in constructor or through session options."
311309
)
312-
connection_name_parts = self.connection_name.split(".")
313-
if len(connection_name_parts) != 3:
314-
raise ValueError(
315-
f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got {self.connection_name}."
310+
311+
if self._bq_connection_manager:
312+
connection_name_parts = self.connection_name.split(".")
313+
if len(connection_name_parts) != 3:
314+
raise ValueError(
315+
f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got {self.connection_name}."
316+
)
317+
self._bq_connection_manager.create_bq_connection(
318+
project_id=connection_name_parts[0],
319+
location=connection_name_parts[1],
320+
connection_id=connection_name_parts[2],
321+
iam_role="aiplatform.user",
316322
)
317-
self._bq_connection_manager.create_bq_connection(
318-
project_id=connection_name_parts[0],
319-
location=connection_name_parts[1],
320-
connection_id=connection_name_parts[2],
321-
iam_role="aiplatform.user",
322-
)
323323

324324
if self.model_name not in _EMBEDDING_GENERATOR_ENDPOINTS:
325325
raise ValueError(
@@ -437,12 +437,10 @@ def __init__(
437437
connection_name: Optional[str] = None,
438438
):
439439
self.session = session or bpd.get_global_session()
440-
self._bq_connection_manager = clients.BqConnectionManager(
441-
self.session.bqconnectionclient, self.session.resourcemanagerclient
442-
)
440+
self._bq_connection_manager = self.session.bqconnectionmanager
443441

444442
connection_name = connection_name or self.session._bq_connection
445-
self.connection_name = self._bq_connection_manager.resolve_full_connection_name(
443+
self.connection_name = clients.resolve_full_bq_connection_name(
446444
connection_name,
447445
default_project=self.session._project,
448446
default_location=self.session._location,
@@ -457,17 +455,19 @@ def _create_bqml_model(self):
457455
raise ValueError(
458456
"Must provide connection_name, either in constructor or through session options."
459457
)
460-
connection_name_parts = self.connection_name.split(".")
461-
if len(connection_name_parts) != 3:
462-
raise ValueError(
463-
f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got {self.connection_name}."
458+
459+
if self._bq_connection_manager:
460+
connection_name_parts = self.connection_name.split(".")
461+
if len(connection_name_parts) != 3:
462+
raise ValueError(
463+
f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got {self.connection_name}."
464+
)
465+
self._bq_connection_manager.create_bq_connection(
466+
project_id=connection_name_parts[0],
467+
location=connection_name_parts[1],
468+
connection_id=connection_name_parts[2],
469+
iam_role="aiplatform.user",
464470
)
465-
self._bq_connection_manager.create_bq_connection(
466-
project_id=connection_name_parts[0],
467-
location=connection_name_parts[1],
468-
connection_id=connection_name_parts[2],
469-
iam_role="aiplatform.user",
470-
)
471471

472472
options = {"endpoint": _GEMINI_PRO_ENDPOINT}
473473

bigframes/ml/remote.py

+14-14
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,9 @@ def __init__(
6262
self.output = output
6363
self.session = session or bpd.get_global_session()
6464

65-
self._bq_connection_manager = clients.BqConnectionManager(
66-
self.session.bqconnectionclient, self.session.resourcemanagerclient
67-
)
65+
self._bq_connection_manager = self.session.bqconnectionmanager
6866
connection_name = connection_name or self.session._bq_connection
69-
self.connection_name = self._bq_connection_manager.resolve_full_connection_name(
67+
self.connection_name = clients.resolve_full_bq_connection_name(
7068
connection_name,
7169
default_project=self.session._project,
7270
default_location=self.session._location,
@@ -81,17 +79,19 @@ def _create_bqml_model(self):
8179
raise ValueError(
8280
"Must provide connection_name, either in constructor or through session options."
8381
)
84-
connection_name_parts = self.connection_name.split(".")
85-
if len(connection_name_parts) != 3:
86-
raise ValueError(
87-
f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got {self.connection_name}."
82+
83+
if self._bq_connection_manager:
84+
connection_name_parts = self.connection_name.split(".")
85+
if len(connection_name_parts) != 3:
86+
raise ValueError(
87+
f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got {self.connection_name}."
88+
)
89+
self._bq_connection_manager.create_bq_connection(
90+
project_id=connection_name_parts[0],
91+
location=connection_name_parts[1],
92+
connection_id=connection_name_parts[2],
93+
iam_role="aiplatform.user",
8894
)
89-
self._bq_connection_manager.create_bq_connection(
90-
project_id=connection_name_parts[0],
91-
location=connection_name_parts[1],
92-
connection_id=connection_name_parts[2],
93-
iam_role="aiplatform.user",
94-
)
9595

9696
options = {
9797
"endpoint": self.endpoint,

0 commit comments

Comments
 (0)