Skip to content

Commit f959b65

Browse files
authored
feat: add ml LogisticRegression model params (#481)
* feat: add ml LogisticRegression model params * fix tests * fix tests
1 parent 352cb85 commit f959b65

File tree

5 files changed

+123
-21
lines changed

5 files changed

+123
-21
lines changed

bigframes/ml/linear_model.py

+58-6
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@
3838
"learn_rate_strategy": "learnRateStrategy",
3939
"learn_rate": "learnRate",
4040
"early_stop": "earlyStop",
41+
# To rename to tol.
4142
"min_rel_progress": "minRelativeProgress",
43+
"tol": "minRelativeProgress",
4244
"ls_init_learn_rate": "initialLearnRate",
4345
"warm_start": "warmStart",
4446
"calculate_p_values": "calculatePValues",
@@ -59,7 +61,7 @@ def __init__(
5961
*,
6062
optimize_strategy: Literal[
6163
"auto_strategy", "batch_gradient_descent", "normal_equation"
62-
] = "normal_equation",
64+
] = "auto_strategy",
6365
fit_intercept: bool = True,
6466
l1_reg: Optional[float] = None,
6567
l2_reg: float = 0.0,
@@ -139,7 +141,7 @@ def _bqml_options(self) -> dict:
139141
if self.ls_init_learn_rate is not None:
140142
options["ls_init_learn_rate"] = self.ls_init_learn_rate
141143
# Even presenting warm_start returns error for NORMAL_EQUATION optimizer
142-
if self.warm_start is True:
144+
if self.warm_start:
143145
options["warm_start"] = self.warm_start
144146

145147
return options
@@ -212,10 +214,34 @@ class LogisticRegression(
212214
def __init__(
213215
self,
214216
*,
217+
optimize_strategy: Literal[
218+
"auto_strategy", "batch_gradient_descent", "normal_equation"
219+
] = "auto_strategy",
215220
fit_intercept: bool = True,
221+
l1_reg: Optional[float] = None,
222+
l2_reg: float = 0.0,
223+
max_iterations: int = 20,
224+
warm_start: bool = False,
225+
learn_rate: Optional[float] = None,
226+
learn_rate_strategy: Literal["line_search", "constant"] = "line_search",
227+
tol: float = 0.01,
228+
ls_init_learn_rate: Optional[float] = None,
229+
calculate_p_values: bool = False,
230+
enable_global_explain: bool = False,
216231
class_weights: Optional[Union[Literal["balanced"], Dict[str, float]]] = None,
217232
):
233+
self.optimize_strategy = optimize_strategy
218234
self.fit_intercept = fit_intercept
235+
self.l1_reg = l1_reg
236+
self.l2_reg = l2_reg
237+
self.max_iterations = max_iterations
238+
self.warm_start = warm_start
239+
self.learn_rate = learn_rate
240+
self.learn_rate_strategy = learn_rate_strategy
241+
self.tol = tol
242+
self.ls_init_learn_rate = ls_init_learn_rate
243+
self.calculate_p_values = calculate_p_values
244+
self.enable_global_explain = enable_global_explain
219245
self.class_weights = class_weights
220246
self._auto_class_weight = class_weights == "balanced"
221247
self._bqml_model: Optional[core.BqmlModel] = None
@@ -231,8 +257,16 @@ def _from_bq(
231257

232258
# See https://cloud.google.com/bigquery/docs/reference/rest/v2/models#trainingrun
233259
last_fitting = model.training_runs[-1]["trainingOptions"]
234-
if "fitIntercept" in last_fitting:
235-
kwargs["fit_intercept"] = last_fitting["fitIntercept"]
260+
dummy_logistic = cls()
261+
for bf_param, bf_value in dummy_logistic.__dict__.items():
262+
bqml_param = _BQML_PARAMS_MAPPING.get(bf_param)
263+
if bqml_param in last_fitting:
264+
# Convert types
265+
kwargs[bf_param] = (
266+
float(last_fitting[bqml_param])
267+
if bf_param in ["l1_reg", "learn_rate", "ls_init_learn_rate"]
268+
else type(bf_value)(last_fitting[bqml_param])
269+
)
236270
if last_fitting["autoClassWeights"]:
237271
kwargs["class_weights"] = "balanced"
238272
# TODO(ashleyxu) support class_weights in the constructor.
@@ -244,16 +278,34 @@ def _from_bq(
244278
return new_logistic_regression
245279

246280
@property
247-
def _bqml_options(self) -> Dict[str, str | int | float | List[str]]:
281+
def _bqml_options(self) -> dict:
248282
"""The model options as they will be set for BQML"""
249-
return {
283+
options = {
250284
"model_type": "LOGISTIC_REG",
251285
"data_split_method": "NO_SPLIT",
252286
"fit_intercept": self.fit_intercept,
253287
"auto_class_weights": self._auto_class_weight,
288+
"optimize_strategy": self.optimize_strategy,
289+
"l2_reg": self.l2_reg,
290+
"max_iterations": self.max_iterations,
291+
"learn_rate_strategy": self.learn_rate_strategy,
292+
"min_rel_progress": self.tol,
293+
"calculate_p_values": self.calculate_p_values,
294+
"enable_global_explain": self.enable_global_explain,
254295
# TODO(ashleyxu): support class_weights (struct array as dict in our API)
255296
# "class_weights": self.class_weights,
256297
}
298+
if self.l1_reg is not None:
299+
options["l1_reg"] = self.l1_reg
300+
if self.learn_rate is not None:
301+
options["learn_rate"] = self.learn_rate
302+
if self.ls_init_learn_rate is not None:
303+
options["ls_init_learn_rate"] = self.ls_init_learn_rate
304+
# Even presenting warm_start returns error for NORMAL_EQUATION optimizer
305+
if self.warm_start:
306+
options["warm_start"] = self.warm_start
307+
308+
return options
257309

258310
def _fit(
259311
self,

tests/system/large/ml/test_linear_model.py

+26-8
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,15 @@ def test_logistic_regression_customized_params_fit_score(
184184
penguins_df_default_index, dataset_id
185185
):
186186
model = bigframes.ml.linear_model.LogisticRegression(
187-
fit_intercept=False, class_weights="balanced"
187+
fit_intercept=False,
188+
class_weights="balanced",
189+
l2_reg=0.2,
190+
tol=0.02,
191+
l1_reg=0.2,
192+
max_iterations=30,
193+
optimize_strategy="batch_gradient_descent",
194+
learn_rate_strategy="constant",
195+
learn_rate=0.2,
188196
)
189197
df = penguins_df_default_index.dropna()
190198
X_train = df[
@@ -203,12 +211,12 @@ def test_logistic_regression_customized_params_fit_score(
203211
result = model.score(X_train, y_train).to_pandas()
204212
expected = pd.DataFrame(
205213
{
206-
"precision": [0.58483],
207-
"recall": [0.586616],
208-
"accuracy": [0.877246],
209-
"f1_score": [0.58571],
210-
"log_loss": [1.032699],
211-
"roc_auc": [0.924132],
214+
"precision": [0.487],
215+
"recall": [0.602],
216+
"accuracy": [0.464],
217+
"f1_score": [0.379],
218+
"log_loss": [0.972],
219+
"roc_auc": [0.700],
212220
},
213221
dtype="Float64",
214222
)
@@ -223,5 +231,15 @@ def test_logistic_regression_customized_params_fit_score(
223231
f"{dataset_id}.temp_configured_logistic_reg_model"
224232
in reloaded_model._bqml_model.model_name
225233
)
234+
# TODO(garrettwu) optimize_strategy isn't logged in BQML
235+
# assert reloaded_model.optimize_strategy == "BATCH_GRADIENT_DESCENT"
226236
assert reloaded_model.fit_intercept is False
227-
assert reloaded_model.class_weights == "balanced"
237+
assert reloaded_model.calculate_p_values is False
238+
assert reloaded_model.enable_global_explain is False
239+
assert reloaded_model.l1_reg == 0.2
240+
assert reloaded_model.l2_reg == 0.2
241+
assert reloaded_model.ls_init_learn_rate is None
242+
assert reloaded_model.max_iterations == 30
243+
assert reloaded_model.tol == 0.02
244+
assert reloaded_model.learn_rate_strategy == "CONSTANT"
245+
assert reloaded_model.learn_rate == 0.2

tests/unit/ml/test_golden_sql.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def test_linear_regression_default_fit(
105105
model.fit(mock_X, mock_y)
106106

107107
mock_session._start_query_ml_ddl.assert_called_once_with(
108-
'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type="LINEAR_REG",\n data_split_method="NO_SPLIT",\n optimize_strategy="normal_equation",\n fit_intercept=True,\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy="line_search",\n early_stop=True,\n min_rel_progress=0.01,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql'
108+
'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type="LINEAR_REG",\n data_split_method="NO_SPLIT",\n optimize_strategy="auto_strategy",\n fit_intercept=True,\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy="line_search",\n early_stop=True,\n min_rel_progress=0.01,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql'
109109
)
110110

111111

@@ -115,7 +115,7 @@ def test_linear_regression_params_fit(bqml_model_factory, mock_session, mock_X,
115115
model.fit(mock_X, mock_y)
116116

117117
mock_session._start_query_ml_ddl.assert_called_once_with(
118-
'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type="LINEAR_REG",\n data_split_method="NO_SPLIT",\n optimize_strategy="normal_equation",\n fit_intercept=False,\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy="line_search",\n early_stop=True,\n min_rel_progress=0.01,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql'
118+
'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type="LINEAR_REG",\n data_split_method="NO_SPLIT",\n optimize_strategy="auto_strategy",\n fit_intercept=False,\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy="line_search",\n early_stop=True,\n min_rel_progress=0.01,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql'
119119
)
120120

121121

@@ -148,21 +148,29 @@ def test_logistic_regression_default_fit(
148148
model.fit(mock_X, mock_y)
149149

150150
mock_session._start_query_ml_ddl.assert_called_once_with(
151-
'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type="LOGISTIC_REG",\n data_split_method="NO_SPLIT",\n fit_intercept=True,\n auto_class_weights=False,\n INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql'
151+
'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type="LOGISTIC_REG",\n data_split_method="NO_SPLIT",\n fit_intercept=True,\n auto_class_weights=False,\n optimize_strategy="auto_strategy",\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy="line_search",\n min_rel_progress=0.01,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql'
152152
)
153153

154154

155155
def test_logistic_regression_params_fit(
156156
bqml_model_factory, mock_session, mock_X, mock_y
157157
):
158158
model = linear_model.LogisticRegression(
159-
fit_intercept=False, class_weights="balanced"
159+
fit_intercept=False,
160+
class_weights="balanced",
161+
l2_reg=0.2,
162+
tol=0.02,
163+
l1_reg=0.2,
164+
max_iterations=30,
165+
optimize_strategy="batch_gradient_descent",
166+
learn_rate_strategy="constant",
167+
learn_rate=0.2,
160168
)
161169
model._bqml_model_factory = bqml_model_factory
162170
model.fit(mock_X, mock_y)
163171

164172
mock_session._start_query_ml_ddl.assert_called_once_with(
165-
'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type="LOGISTIC_REG",\n data_split_method="NO_SPLIT",\n fit_intercept=False,\n auto_class_weights=True,\n INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql'
173+
'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type="LOGISTIC_REG",\n data_split_method="NO_SPLIT",\n fit_intercept=False,\n auto_class_weights=True,\n optimize_strategy="batch_gradient_descent",\n l2_reg=0.2,\n max_iterations=30,\n learn_rate_strategy="constant",\n min_rel_progress=0.02,\n calculate_p_values=False,\n enable_global_explain=False,\n l1_reg=0.2,\n learn_rate=0.2,\n INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql'
166174
)
167175

168176

third_party/bigframes_vendored/sklearn/linear_model/_base.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,10 @@ class LinearRegression(RegressorMixin, LinearModel):
6363
the dataset, and the targets predicted by the linear approximation.
6464
6565
Args:
66-
optimize_strategy (str, default "normal_equation"):
66+
optimize_strategy (str, default "auto_strategy"):
6767
The strategy to train linear regression models. Possible values are
6868
"auto_strategy", "batch_gradient_descent", "normal_equation". Default
69-
to "normal_equation".
69+
to "auto_strategy".
7070
fit_intercept (bool, default True):
7171
Default ``True``. Whether to calculate the intercept for this
7272
model. If set to False, no intercept will be used in calculations

third_party/bigframes_vendored/sklearn/linear_model/_logistic.py

+24
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ class LogisticRegression(LinearClassifierMixin, BaseEstimator):
2424
"""Logistic Regression (aka logit, MaxEnt) classifier.
2525
2626
Args:
27+
optimize_strategy (str, default "auto_strategy"):
28+
The strategy to train logistic regression models. Possible values are
29+
"auto_strategy", "batch_gradient_descent", "normal_equation". Default
30+
to "auto_strategy".
2731
fit_intercept (default True):
2832
Default True. Specifies if a constant (a.k.a. bias or intercept)
2933
should be added to the decision function.
@@ -35,6 +39,26 @@ class LogisticRegression(LinearClassifierMixin, BaseEstimator):
3539
frequencies in the input data as
3640
``n_samples / (n_classes * np.bincount(y))``. Dict isn't
3741
supported now.
42+
l1_reg (float or None, default None):
43+
The amount of L1 regularization applied. Default to None. Can't be set in "normal_equation" mode. If unset, value 0 is used.
44+
l2_reg (float, default 0.0):
45+
The amount of L2 regularization applied. Default to 0.
46+
max_iterations (int, default 20):
47+
The maximum number of training iterations or steps. Default to 20.
48+
warm_start (bool, default False):
49+
Determines whether to train a model with new training data, new model options, or both. Unless you explicitly override them, the initial options used to train the model are used for the warm start run. Default to False.
50+
learn_rate (float or None, default None):
51+
The learn rate for gradient descent when learn_rate_strategy='constant'. If unset, value 0.1 is used. If learn_rate_strategy='line_search', an error is returned.
52+
learn_rate_strategy (str, default "line_search"):
53+
The strategy for specifying the learning rate during training. Default to "line_search".
54+
tol (float, default 0.01):
55+
The minimum relative loss improvement that is necessary to continue training when EARLY_STOP is set to true. For example, a value of 0.01 specifies that each iteration must reduce the loss by 1% for training to continue. Default to 0.01.
56+
ls_init_learn_rate (float or None, default None):
57+
Sets the initial learning rate that learn_rate_strategy='line_search' uses. This option can only be used if line_search is specified. If unset, value 0.1 is used.
58+
calculate_p_values (bool, default False):
59+
Specifies whether to compute p-values and standard errors during training. Default to False.
60+
enable_global_explain (bool, default False):
61+
Whether to compute global explanations using explainable AI to evaluate global feature importance to the model. Default to False.
3862
"""
3963

4064
def fit(

0 commit comments

Comments
 (0)