File tree 2 files changed +11
-2
lines changed
2 files changed +11
-2
lines changed Original file line number Diff line number Diff line change @@ -77,7 +77,8 @@ def _apply_ml_tvf(
77
77
78
78
result_sql = apply_sql_tvf (input_sql )
79
79
df = self ._session .read_gbq (result_sql , index_col = index_col_ids )
80
- df .index .names = index_labels
80
+ if df ._has_index :
81
+ df .index .names = index_labels
81
82
# Restore column labels
82
83
df .rename (
83
84
columns = {
Original file line number Diff line number Diff line change @@ -111,7 +111,7 @@ def test_linear_regression_customized_params_fit_score(
111
111
assert reloaded_model .learning_rate == 0.2
112
112
113
113
114
- def test_unordered_mode_regression_configure_fit_score (
114
+ def test_unordered_mode_linear_regression_configure_fit_score_predict (
115
115
unordered_session , penguins_table_id , dataset_id
116
116
):
117
117
model = bigframes .ml .linear_model .LinearRegression ()
@@ -154,6 +154,14 @@ def test_unordered_mode_regression_configure_fit_score(
154
154
assert reloaded_model .max_iterations == 20
155
155
assert reloaded_model .tol == 0.01
156
156
157
+ pred = reloaded_model .predict (df )
158
+ utils .check_pandas_df_schema_and_index (
159
+ pred ,
160
+ columns = ("predicted_body_mass_g" ,),
161
+ col_exact = False ,
162
+ index = 334 ,
163
+ )
164
+
157
165
158
166
# TODO(garrettwu): add tests for param warm_start. Requires a trained model.
159
167
You can’t perform that action at this time.
0 commit comments