Skip to content

Commit b951549

Browse files
authored
docs: use class_weight="balanced" in the logistic regression prediction tutorial (#678)
This aligns the Python code with the SQL at https://cloud.google.com/bigquery/docs/logistic-regression-prediction#create_a_logistic_regression_model ```sql CREATE OR REPLACE MODEL `census.census_model` OPTIONS ( model_type='LOGISTIC_REG', auto_class_weights=TRUE, data_split_method='NO_SPLIT', input_label_cols=['income_bracket'], max_iterations=15) AS SELECT * EXCEPT(dataframe) FROM `census.input_data` WHERE dataframe = 'training' ```
1 parent a58dcd2 commit b951549

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

samples/snippets/logistic_regression_prediction_test.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,21 @@ def test_logistic_regression_prediction(random_model_id: str) -> None:
8080
X = training_data.drop(columns=["income_bracket", "dataframe"])
8181
y = training_data["income_bracket"]
8282

83-
census_model = bigframes.ml.linear_model.LogisticRegression()
83+
census_model = bigframes.ml.linear_model.LogisticRegression(
84+
# Balance the class labels in the training data by setting
85+
# class_weight="balanced".
86+
#
87+
# By default, the training data is unweighted. If the labels
88+
# in the training data are imbalanced, the model may learn to
89+
# predict the most popular class of labels more heavily. In
90+
# this case, most of the respondents in the dataset are in the
91+
# lower income bracket. This may lead to a model that predicts
92+
# the lower income bracket too heavily. Class weights balance
93+
# the class labels by calculating the weights for each class in
94+
# inverse proportion to the frequency of that class.
95+
class_weight="balanced",
96+
max_iterations=15,
97+
)
8498
census_model.fit(X, y)
8599

86100
census_model.to_gbq(

0 commit comments

Comments
 (0)