Skip to content

feat: add ARIMAPlus.predict parameters #264

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions bigframes/ml/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ def generate_text_embedding(
),
)

def forecast(self) -> bpd.DataFrame:
sql = self._model_manipulation_sql_generator.ml_forecast()
def forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
sql = self._model_manipulation_sql_generator.ml_forecast(struct_options=options)
return self._session.read_gbq(sql, index_col="forecast_timestamp").reset_index()

def evaluate(self, input_data: Optional[bpd.DataFrame] = None):
Expand Down
21 changes: 19 additions & 2 deletions bigframes/ml/forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,21 +86,38 @@ def _fit(
options=self._bqml_options,
)

def predict(self, X=None) -> bpd.DataFrame:
def predict(
self, X=None, horizon: int = 3, confidence_level: float = 0.95
) -> bpd.DataFrame:
"""Predict the closest cluster for each sample in X.

Args:
X (default None):
ignored, to be compatible with other APIs.
horizon (int, default: 3):
an int value that specifies the number of time points to forecast.
The default value is 3, and the maximum value is 1000.
confidence_level (float, default 0.95):
a float value that specifies percentage of the future values that fall in the prediction interval.
The valid input range is [0.0, 1.0).

Returns:
bigframes.dataframe.DataFrame: The predicted DataFrames. Which
contains 2 columns "forecast_timestamp" and "forecast_value".
"""
if horizon < 1 or horizon > 1000:
raise ValueError(f"horizon must be [1, 1000], but is {horizon}.")
if confidence_level < 0.0 or confidence_level >= 1.0:
raise ValueError(
f"confidence_level must be [0.0, 1.0), but is {confidence_level}."
)

if not self._bqml_model:
raise RuntimeError("A model must be fitted before predict")

return self._bqml_model.forecast()
return self._bqml_model.forecast(
options={"horizon": horizon, "confidence_level": confidence_level}
)

def score(
self,
Expand Down
6 changes: 4 additions & 2 deletions bigframes/ml/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,11 @@ def ml_predict(self, source_df: bpd.DataFrame) -> str:
return f"""SELECT * FROM ML.PREDICT(MODEL `{self._model_name}`,
({self._source_sql(source_df)}))"""

def ml_forecast(self) -> str:
def ml_forecast(self, struct_options: Mapping[str, Union[int, float]]) -> str:
"""Encode ML.FORECAST for BQML"""
return f"""SELECT * FROM ML.FORECAST(MODEL `{self._model_name}`)"""
struct_options_sql = self.struct_options(**struct_options)
return f"""SELECT * FROM ML.FORECAST(MODEL `{self._model_name}`,
{struct_options_sql})"""

def ml_generate_text(
self, source_df: bpd.DataFrame, struct_options: Mapping[str, Union[int, float]]
Expand Down
9 changes: 5 additions & 4 deletions tests/system/small/ml/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,17 +336,18 @@ def test_model_generate_text(

def test_model_forecast(time_series_bqml_arima_plus_model: core.BqmlModel):
utc = pytz.utc
forecast = time_series_bqml_arima_plus_model.forecast().to_pandas()[
["forecast_timestamp", "forecast_value"]
]
forecast = time_series_bqml_arima_plus_model.forecast(
{"horizon": 4, "confidence_level": 0.8}
).to_pandas()[["forecast_timestamp", "forecast_value"]]
expected = pd.DataFrame(
{
"forecast_timestamp": [
datetime(2017, 8, 2, tzinfo=utc),
datetime(2017, 8, 3, tzinfo=utc),
datetime(2017, 8, 4, tzinfo=utc),
datetime(2017, 8, 5, tzinfo=utc),
],
"forecast_value": [2724.472284, 2593.368389, 2353.613034],
"forecast_value": [2724.472284, 2593.368389, 2353.613034, 1781.623071],
}
)
expected["forecast_value"] = expected["forecast_value"].astype(pd.Float64Dtype())
Expand Down
43 changes: 40 additions & 3 deletions tests/system/small/ml/test_forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
import pyarrow as pa
import pytz

from bigframes.ml import forecasting

def test_model_predict(time_series_arima_plus_model):

def test_model_predict_default(time_series_arima_plus_model: forecasting.ARIMAPlus):
utc = pytz.utc
predictions = time_series_arima_plus_model.predict().to_pandas()
assert predictions.shape == (3, 8)
Expand Down Expand Up @@ -47,7 +49,40 @@ def test_model_predict(time_series_arima_plus_model):
)


def test_model_score(time_series_arima_plus_model, new_time_series_df):
def test_model_predict_params(time_series_arima_plus_model: forecasting.ARIMAPlus):
utc = pytz.utc
predictions = time_series_arima_plus_model.predict(
horizon=4, confidence_level=0.9
).to_pandas()
assert predictions.shape == (4, 8)
result = predictions[["forecast_timestamp", "forecast_value"]]
expected = pd.DataFrame(
{
"forecast_timestamp": [
datetime(2017, 8, 2, tzinfo=utc),
datetime(2017, 8, 3, tzinfo=utc),
datetime(2017, 8, 4, tzinfo=utc),
datetime(2017, 8, 5, tzinfo=utc),
],
"forecast_value": [2724.472284, 2593.368389, 2353.613034, 1781.623071],
}
)
expected["forecast_value"] = expected["forecast_value"].astype(pd.Float64Dtype())
expected["forecast_timestamp"] = expected["forecast_timestamp"].astype(
pd.ArrowDtype(pa.timestamp("us", tz="UTC"))
)

pd.testing.assert_frame_equal(
result,
expected,
rtol=0.1,
check_index_type=False,
)


def test_model_score(
time_series_arima_plus_model: forecasting.ARIMAPlus, new_time_series_df
):
result = time_series_arima_plus_model.score(
new_time_series_df[["parsed_date"]], new_time_series_df[["total_visits"]]
).to_pandas()
Expand All @@ -69,7 +104,9 @@ def test_model_score(time_series_arima_plus_model, new_time_series_df):
)


def test_model_score_series(time_series_arima_plus_model, new_time_series_df):
def test_model_score_series(
time_series_arima_plus_model: forecasting.ARIMAPlus, new_time_series_df
):
result = time_series_arima_plus_model.score(
new_time_series_df["parsed_date"], new_time_series_df["total_visits"]
).to_pandas()
Expand Down
16 changes: 16 additions & 0 deletions tests/unit/ml/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,22 @@ def test_ml_centroids_produces_correct_sql(
)


def test_forecast_correct_sql(
model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
mock_df: bpd.DataFrame,
):
sql = model_manipulation_sql_generator.ml_forecast(
struct_options={"option_key1": 1, "option_key2": 2.2},
)
assert (
sql
== """SELECT * FROM ML.FORECAST(MODEL `my_project_id.my_dataset_id.my_model_id`,
STRUCT(
1 AS option_key1,
2.2 AS option_key2))"""
)


def test_ml_generate_text_produces_correct_sql(
model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
mock_df: bpd.DataFrame,
Expand Down