@@ -73,12 +73,10 @@ def __init__(
73
73
):
74
74
self .model_name = model_name
75
75
self .session = session or bpd .get_global_session ()
76
- self ._bq_connection_manager = clients .BqConnectionManager (
77
- self .session .bqconnectionclient , self .session .resourcemanagerclient
78
- )
76
+ self ._bq_connection_manager = self .session .bqconnectionmanager
79
77
80
78
connection_name = connection_name or self .session ._bq_connection
81
- self .connection_name = self . _bq_connection_manager . resolve_full_connection_name (
79
+ self .connection_name = clients . resolve_full_bq_connection_name (
82
80
connection_name ,
83
81
default_project = self .session ._project ,
84
82
default_location = self .session ._location ,
@@ -93,17 +91,19 @@ def _create_bqml_model(self):
93
91
raise ValueError (
94
92
"Must provide connection_name, either in constructor or through session options."
95
93
)
96
- connection_name_parts = self .connection_name .split ("." )
97
- if len (connection_name_parts ) != 3 :
98
- raise ValueError (
99
- f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got { self .connection_name } ."
94
+
95
+ if self ._bq_connection_manager :
96
+ connection_name_parts = self .connection_name .split ("." )
97
+ if len (connection_name_parts ) != 3 :
98
+ raise ValueError (
99
+ f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got { self .connection_name } ."
100
+ )
101
+ self ._bq_connection_manager .create_bq_connection (
102
+ project_id = connection_name_parts [0 ],
103
+ location = connection_name_parts [1 ],
104
+ connection_id = connection_name_parts [2 ],
105
+ iam_role = "aiplatform.user" ,
100
106
)
101
- self ._bq_connection_manager .create_bq_connection (
102
- project_id = connection_name_parts [0 ],
103
- location = connection_name_parts [1 ],
104
- connection_id = connection_name_parts [2 ],
105
- iam_role = "aiplatform.user" ,
106
- )
107
107
108
108
if self .model_name not in _TEXT_GENERATOR_ENDPOINTS :
109
109
raise ValueError (
@@ -289,12 +289,10 @@ def __init__(
289
289
self .model_name = model_name
290
290
self .version = version
291
291
self .session = session or bpd .get_global_session ()
292
- self ._bq_connection_manager = clients .BqConnectionManager (
293
- self .session .bqconnectionclient , self .session .resourcemanagerclient
294
- )
292
+ self ._bq_connection_manager = self .session .bqconnectionmanager
295
293
296
294
connection_name = connection_name or self .session ._bq_connection
297
- self .connection_name = self . _bq_connection_manager . resolve_full_connection_name (
295
+ self .connection_name = clients . resolve_full_bq_connection_name (
298
296
connection_name ,
299
297
default_project = self .session ._project ,
300
298
default_location = self .session ._location ,
@@ -309,17 +307,19 @@ def _create_bqml_model(self):
309
307
raise ValueError (
310
308
"Must provide connection_name, either in constructor or through session options."
311
309
)
312
- connection_name_parts = self .connection_name .split ("." )
313
- if len (connection_name_parts ) != 3 :
314
- raise ValueError (
315
- f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got { self .connection_name } ."
310
+
311
+ if self ._bq_connection_manager :
312
+ connection_name_parts = self .connection_name .split ("." )
313
+ if len (connection_name_parts ) != 3 :
314
+ raise ValueError (
315
+ f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got { self .connection_name } ."
316
+ )
317
+ self ._bq_connection_manager .create_bq_connection (
318
+ project_id = connection_name_parts [0 ],
319
+ location = connection_name_parts [1 ],
320
+ connection_id = connection_name_parts [2 ],
321
+ iam_role = "aiplatform.user" ,
316
322
)
317
- self ._bq_connection_manager .create_bq_connection (
318
- project_id = connection_name_parts [0 ],
319
- location = connection_name_parts [1 ],
320
- connection_id = connection_name_parts [2 ],
321
- iam_role = "aiplatform.user" ,
322
- )
323
323
324
324
if self .model_name not in _EMBEDDING_GENERATOR_ENDPOINTS :
325
325
raise ValueError (
@@ -437,12 +437,10 @@ def __init__(
437
437
connection_name : Optional [str ] = None ,
438
438
):
439
439
self .session = session or bpd .get_global_session ()
440
- self ._bq_connection_manager = clients .BqConnectionManager (
441
- self .session .bqconnectionclient , self .session .resourcemanagerclient
442
- )
440
+ self ._bq_connection_manager = self .session .bqconnectionmanager
443
441
444
442
connection_name = connection_name or self .session ._bq_connection
445
- self .connection_name = self . _bq_connection_manager . resolve_full_connection_name (
443
+ self .connection_name = clients . resolve_full_bq_connection_name (
446
444
connection_name ,
447
445
default_project = self .session ._project ,
448
446
default_location = self .session ._location ,
@@ -457,17 +455,19 @@ def _create_bqml_model(self):
457
455
raise ValueError (
458
456
"Must provide connection_name, either in constructor or through session options."
459
457
)
460
- connection_name_parts = self .connection_name .split ("." )
461
- if len (connection_name_parts ) != 3 :
462
- raise ValueError (
463
- f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got { self .connection_name } ."
458
+
459
+ if self ._bq_connection_manager :
460
+ connection_name_parts = self .connection_name .split ("." )
461
+ if len (connection_name_parts ) != 3 :
462
+ raise ValueError (
463
+ f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got { self .connection_name } ."
464
+ )
465
+ self ._bq_connection_manager .create_bq_connection (
466
+ project_id = connection_name_parts [0 ],
467
+ location = connection_name_parts [1 ],
468
+ connection_id = connection_name_parts [2 ],
469
+ iam_role = "aiplatform.user" ,
464
470
)
465
- self ._bq_connection_manager .create_bq_connection (
466
- project_id = connection_name_parts [0 ],
467
- location = connection_name_parts [1 ],
468
- connection_id = connection_name_parts [2 ],
469
- iam_role = "aiplatform.user" ,
470
- )
471
471
472
472
options = {"endpoint" : _GEMINI_PRO_ENDPOINT }
473
473
0 commit comments