Skip to content

Commit f7d52d9

Browse files
authored
fix: arima model series input. (#1237)
1 parent 0d84459 commit f7d52d9

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

bigframes/ml/forecasting.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -199,15 +199,15 @@ def _fit(
199199
Returns:
200200
ARIMAPlus: Fitted estimator.
201201
"""
202+
X, y = utils.batch_convert_to_dataframe(X, y)
203+
202204
if X.columns.size != 1:
203205
raise ValueError(
204206
"Time series timestamp input X must only contain 1 column."
205207
)
206208
if y.columns.size != 1:
207209
raise ValueError("Time series data input y must only contain 1 column.")
208210

209-
X, y = utils.batch_convert_to_dataframe(X, y)
210-
211211
self._bqml_model = self._bqml_model_factory.create_time_series_model(
212212
X,
213213
y,

tests/system/large/ml/test_forecasting.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
@pytest.fixture(scope="module")
3737
def arima_model(time_series_df_default_index):
3838
model = forecasting.ARIMAPlus()
39-
X_train = time_series_df_default_index[["parsed_date"]]
39+
X_train = time_series_df_default_index["parsed_date"]
4040
y_train = time_series_df_default_index[["total_visits"]]
4141
model.fit(X_train, y_train)
4242
return model
@@ -114,7 +114,7 @@ def test_arima_plus_model_fit_params(time_series_df_default_index, dataset_id):
114114
)
115115

116116
X_train = time_series_df_default_index[["parsed_date"]]
117-
y_train = time_series_df_default_index[["total_visits"]]
117+
y_train = time_series_df_default_index["total_visits"]
118118
model.fit(X_train, y_train)
119119

120120
# save, load to ensure configuration was kept

0 commit comments

Comments
 (0)