File tree 1 file changed +15
-1
lines changed
1 file changed +15
-1
lines changed Original file line number Diff line number Diff line change @@ -80,7 +80,21 @@ def test_logistic_regression_prediction(random_model_id: str) -> None:
80
80
X = training_data .drop (columns = ["income_bracket" , "dataframe" ])
81
81
y = training_data ["income_bracket" ]
82
82
83
- census_model = bigframes .ml .linear_model .LogisticRegression ()
83
+ census_model = bigframes .ml .linear_model .LogisticRegression (
84
+ # Balance the class labels in the training data by setting
85
+ # class_weight="balanced".
86
+ #
87
+ # By default, the training data is unweighted. If the labels
88
+ # in the training data are imbalanced, the model may learn to
89
+ # predict the most common class of labels too often. In
90
+ # this case, most of the respondents in the dataset are in the
91
+ # lower income bracket. This may lead to a model that predicts
92
+ # the lower income bracket too often. Class weights balance
93
+ # the class labels by calculating the weights for each class in
94
+ # inverse proportion to the frequency of that class.
95
+ class_weight = "balanced" ,
96
+ max_iterations = 15 ,
97
+ )
84
98
census_model .fit (X , y )
85
99
86
100
census_model .to_gbq (
You can’t perform that action at this time.
0 commit comments