@@ -55,25 +55,28 @@ def test_create_text_generator_model_default_session(
5555):
5656 import bigframes .pandas as bpd
5757
58- bpd .close_session ()
59- bpd .options .bigquery .bq_connection = bq_connection
60- bpd .options .bigquery .location = "us"
61-
62- model = llm .PaLM2TextGenerator ()
63- assert model is not None
64- assert model ._bqml_model is not None
65- assert (
66- model .connection_name .casefold ()
67- == f"{ bigquery_client .project } .us.bigframes-rf-conn"
68- )
69-
70- llm_text_df = bpd .read_pandas (llm_text_pandas_df )
71-
72- df = model .predict (llm_text_df ).to_pandas ()
73- assert df .shape == (3 , 4 )
74- assert "ml_generate_text_llm_result" in df .columns
75- series = df ["ml_generate_text_llm_result" ]
76- assert all (series .str .len () > 20 )
58+ # Note: This starts a thread-local session.
59+ with bpd .option_context (
60+ "bigquery.bq_connection" ,
61+ bq_connection ,
62+ "bigquery.location" ,
63+ "US" ,
64+ ):
65+ model = llm .PaLM2TextGenerator ()
66+ assert model is not None
67+ assert model ._bqml_model is not None
68+ assert (
69+ model .connection_name .casefold ()
70+ == f"{ bigquery_client .project } .us.bigframes-rf-conn"
71+ )
72+
73+ llm_text_df = bpd .read_pandas (llm_text_pandas_df )
74+
75+ df = model .predict (llm_text_df ).to_pandas ()
76+ assert df .shape == (3 , 4 )
77+ assert "ml_generate_text_llm_result" in df .columns
78+ series = df ["ml_generate_text_llm_result" ]
79+ assert all (series .str .len () > 20 )
7780
7881
7982@pytest .mark .flaky (retries = 2 )
@@ -82,25 +85,28 @@ def test_create_text_generator_32k_model_default_session(
8285):
8386 import bigframes .pandas as bpd
8487
85- bpd .close_session ()
86- bpd .options .bigquery .bq_connection = bq_connection
87- bpd .options .bigquery .location = "us"
88-
89- model = llm .PaLM2TextGenerator (model_name = "text-bison-32k" )
90- assert model is not None
91- assert model ._bqml_model is not None
92- assert (
93- model .connection_name .casefold ()
94- == f"{ bigquery_client .project } .us.bigframes-rf-conn"
95- )
96-
97- llm_text_df = bpd .read_pandas (llm_text_pandas_df )
98-
99- df = model .predict (llm_text_df ).to_pandas ()
100- assert df .shape == (3 , 4 )
101- assert "ml_generate_text_llm_result" in df .columns
102- series = df ["ml_generate_text_llm_result" ]
103- assert all (series .str .len () > 20 )
88+ # Note: This starts a thread-local session.
89+ with bpd .option_context (
90+ "bigquery.bq_connection" ,
91+ bq_connection ,
92+ "bigquery.location" ,
93+ "US" ,
94+ ):
95+ model = llm .PaLM2TextGenerator (model_name = "text-bison-32k" )
96+ assert model is not None
97+ assert model ._bqml_model is not None
98+ assert (
99+ model .connection_name .casefold ()
100+ == f"{ bigquery_client .project } .us.bigframes-rf-conn"
101+ )
102+
103+ llm_text_df = bpd .read_pandas (llm_text_pandas_df )
104+
105+ df = model .predict (llm_text_df ).to_pandas ()
106+ assert df .shape == (3 , 4 )
107+ assert "ml_generate_text_llm_result" in df .columns
108+ series = df ["ml_generate_text_llm_result" ]
109+ assert all (series .str .len () > 20 )
104110
105111
106112@pytest .mark .flaky (retries = 2 )
@@ -232,27 +238,33 @@ def test_create_embedding_generator_multilingual_model(
232238def test_create_text_embedding_generator_model_defaults (bq_connection ):
233239 import bigframes .pandas as bpd
234240
235- bpd .close_session ()
236- bpd .options .bigquery .bq_connection = bq_connection
237- bpd .options .bigquery .location = "us"
238-
239- model = llm .PaLM2TextEmbeddingGenerator ()
240- assert model is not None
241- assert model ._bqml_model is not None
241+ # Note: This starts a thread-local session.
242+ with bpd .option_context (
243+ "bigquery.bq_connection" ,
244+ bq_connection ,
245+ "bigquery.location" ,
246+ "US" ,
247+ ):
248+ model = llm .PaLM2TextEmbeddingGenerator ()
249+ assert model is not None
250+ assert model ._bqml_model is not None
242251
243252
244253def test_create_text_embedding_generator_multilingual_model_defaults (bq_connection ):
245254 import bigframes .pandas as bpd
246255
247- bpd .close_session ()
248- bpd .options .bigquery .bq_connection = bq_connection
249- bpd .options .bigquery .location = "us"
250-
251- model = llm .PaLM2TextEmbeddingGenerator (
252- model_name = "textembedding-gecko-multilingual"
253- )
254- assert model is not None
255- assert model ._bqml_model is not None
256+ # Note: This starts a thread-local session.
257+ with bpd .option_context (
258+ "bigquery.bq_connection" ,
259+ bq_connection ,
260+ "bigquery.location" ,
261+ "US" ,
262+ ):
263+ model = llm .PaLM2TextEmbeddingGenerator (
264+ model_name = "textembedding-gecko-multilingual"
265+ )
266+ assert model is not None
267+ assert model ._bqml_model is not None
256268
257269
258270@pytest .mark .flaky (retries = 2 )
0 commit comments