Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
more test, mypy fixes
  • Loading branch information
TrevorBergeron committed Nov 10, 2023
commit e614482339bf7e9745e34da84cbab7981e80d0dc
53 changes: 16 additions & 37 deletions bigframes/session/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
Tuple,
Union,
)
import uuid
import warnings

import google.api_core.client_info
Expand Down Expand Up @@ -505,7 +504,7 @@ def read_gbq_table(
api_name="read_gbq_table",
)

def _read_gbq_table_to_ibis_with_total_ordering(
def _get_snapshot_sql_and_primary_key(
self,
table_ref: bigquery.table.TableReference,
*,
Expand All @@ -518,15 +517,6 @@ def _read_gbq_table_to_ibis_with_total_ordering(
column(s), then return those too so that ordering generation can be
avoided.
"""
if table_ref.dataset_id.upper() == "_SESSION":
# _SESSION tables aren't supported by the tables.get REST API.
return (
self.ibis_client.sql(
f"SELECT * FROM `_SESSION`.`{table_ref.table_id}`"
),
None,
)

table_expression = self.ibis_client.table(
table_ref.table_id,
database=f"{table_ref.project}.{table_ref.dataset_id}",
Expand All @@ -551,22 +541,18 @@ def _read_gbq_table_to_ibis_with_total_ordering(
.get("columns")
)

if not primary_keys:
return table_expression, None
else:
# Read from a snapshot since we won't have to copy the table data to create a total ordering.
job_config = bigquery.QueryJobConfig()
job_config.labels["bigframes-api"] = api_name
current_timestamp = list(
self.bqclient.query(
"SELECT CURRENT_TIMESTAMP() AS `current_timestamp`",
job_config=job_config,
).result()
)[0][0]
table_expression = self.ibis_client.sql(
bigframes_io.create_snapshot_sql(table_ref, current_timestamp)
)
return table_expression, primary_keys
job_config = bigquery.QueryJobConfig()
job_config.labels["bigframes-api"] = api_name
current_timestamp = list(
self.bqclient.query(
"SELECT CURRENT_TIMESTAMP() AS `current_timestamp`",
job_config=job_config,
).result()
)[0][0]
table_expression = self.ibis_client.sql(
bigframes_io.create_snapshot_sql(table_ref, current_timestamp)
)
return table_expression, primary_keys

def _read_gbq_table(
self,
Expand All @@ -589,7 +575,7 @@ def _read_gbq_table(
(
table_expression,
total_ordering_cols,
) = self._read_gbq_table_to_ibis_with_total_ordering(
) = self._get_snapshot_sql_and_primary_key(
table_ref, api_name=api_name, enforce_region=True
)

Expand Down Expand Up @@ -843,7 +829,7 @@ def _read_pandas(
job_config.clustering_fields = cluster_cols
job_config.labels = {"bigframes-api": api_name}

load_table_destination = self._create_session_table()
load_table_destination = bigframes_io.random_table(self._anonymous_dataset)
load_job = self.bqclient.load_table_from_dataframe(
pandas_dataframe_copy,
load_table_destination,
Expand Down Expand Up @@ -1144,13 +1130,6 @@ def _check_file_size(self, filepath: str):
"for large files to avoid loading the file into local memory."
)

def _create_session_table(self) -> bigquery.TableReference:
    """Return a reference to a randomly named table in the `_SESSION` dataset.

    The table name is the hex form of a fresh UUID4, and the dataset
    reference is built from the BigQuery client's project. Only a
    reference object is constructed and returned here.

    Returns:
        bigquery.TableReference: Reference to the randomly named table
        under ``<client project>._SESSION``.
    """
    # Random hex name so that repeated calls yield distinct table ids.
    random_name = f"{uuid.uuid4().hex}"
    session_dataset_ref = bigquery.DatasetReference(
        self.bqclient.project, "_SESSION"
    )
    session_dataset = bigquery.Dataset(session_dataset_ref)
    return session_dataset.table(random_name)

def _create_empty_temp_table(
self,
schema: Iterable[bigquery.SchemaField],
Expand Down Expand Up @@ -1457,7 +1436,7 @@ def _convert_to_string(column: ibis_types.Column) -> ibis_types.StringColumn:
# Some of these probably don't work
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we check? Sounds like we need some targeted tests to cover the branches in this function.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These should all work now. I added a test specifically for the JSON case; the other datatypes are covered by existing tests.

col_type = column.type()
if col_type.is_array() or col_type.is_struct():
result = vendored_ibis_ops.ToJsonString(column).to_expr()
result = vendored_ibis_ops.ToJsonString(column).to_expr() # type: ignore
elif col_type.is_geospatial():
result = typing.cast(ibis_types.GeoSpatialColumn, column).as_text()
elif col_type.is_string():
Expand Down
104 changes: 54 additions & 50 deletions tests/system/small/ml/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,58 +78,62 @@ def test_model_eval_with_data(penguins_bqml_linear_model, penguins_df_default_in

def test_model_centroids(penguins_bqml_kmeans_model: core.BqmlModel):
result = penguins_bqml_kmeans_model.centroids().to_pandas()
expected = pd.DataFrame(
{
"centroid_id": [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3],
"feature": [
"culmen_length_mm",
"culmen_depth_mm",
"flipper_length_mm",
"sex",
]
* 3,
"numerical_value": [
47.509677,
14.993548,
217.040123,
pd.NA,
38.207813,
18.03125,
187.992188,
pd.NA,
47.036346,
18.834808,
197.1612,
pd.NA,
],
"categorical_value": [
[],
[],
[],
[
{"category": ".", "value": 0.008064516129032258},
{"category": "MALE", "value": 0.49193548387096775},
{"category": "FEMALE", "value": 0.47580645161290325},
{"category": "_null_filler", "value": 0.024193548387096774},
],
[],
[],
[],
[
{"category": "MALE", "value": 0.34375},
{"category": "FEMALE", "value": 0.625},
{"category": "_null_filler", "value": 0.03125},
expected = (
pd.DataFrame(
{
"centroid_id": [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3],
"feature": [
"culmen_length_mm",
"culmen_depth_mm",
"flipper_length_mm",
"sex",
]
* 3,
"numerical_value": [
47.509677,
14.993548,
217.040123,
pd.NA,
38.207813,
18.03125,
187.992188,
pd.NA,
47.036346,
18.834808,
197.1612,
pd.NA,
],
[],
[],
[],
[
{"category": "MALE", "value": 0.6847826086956522},
{"category": "FEMALE", "value": 0.2826086956521739},
{"category": "_null_filler", "value": 0.03260869565217391},
"categorical_value": [
[],
[],
[],
[
{"category": ".", "value": 0.008064516129032258},
{"category": "MALE", "value": 0.49193548387096775},
{"category": "FEMALE", "value": 0.47580645161290325},
{"category": "_null_filler", "value": 0.024193548387096774},
],
[],
[],
[],
[
{"category": "MALE", "value": 0.34375},
{"category": "FEMALE", "value": 0.625},
{"category": "_null_filler", "value": 0.03125},
],
[],
[],
[],
[
{"category": "MALE", "value": 0.6847826086956522},
{"category": "FEMALE", "value": 0.2826086956521739},
{"category": "_null_filler", "value": 0.03260869565217391},
],
],
],
},
},
)
.sort_values(["centroid_id", "feature"])
.reset_index(drop=True)
)
pd.testing.assert_frame_equal(
result,
Expand Down