def fit(cfg_dict, x_train, y_train, x_test, y_test): # fit __spec__ = None model = Tree_Model(cfg_dict, x_train, y_train, x_test, y_test) model_name = cfg_dict['train']['model'] if model_name == 'lightgbm': bst = model.lgb_fit() elif model_name == 'xgboost': if cfg_dict['train']['if_grid_search'] == 'True': print('GS_CV......') model.GS_CV_xgb(int(cfg_dict['train']['grid_search_group'])) print('GS_CV finished!') return 0, 0, 0 bst = model.xgb_fit() elif model_name == 'catboost': bst = model.cat_fit() else: bst = 0 print('model name error') sys.exit() if bst == 'gscv': sys.exit() return bst
时间: 2024-04-18 15:30:08 浏览: 127
这段代码定义了一个名为 fit 的函数,用于训练模型。
首先,在函数中创建了一个 Tree_Model 的实例 model,该实例用于模型的训练和预测。
接着,代码从配置参数 cfg_dict 中获取了模型的名称,保存在 model_name 变量中。
然后,根据 model_name 的取值,选择不同的模型进行训练和预测。如果 model_name 为 'lightgbm',则调用 model 的 lgb_fit 方法进行 LightGBM 模型的训练;如果 model_name 为 'xgboost',则根据配置参数 cfg_dict['train']['if_grid_search'] 的取值判断是否进行网格搜索,若为 'True' 则调用 model 的 GS_CV_xgb 方法进行 XGBoost 模型的网格搜索交叉验证,否则调用 model 的 xgb_fit 方法进行 XGBoost 模型的训练;如果 model_name 为 'catboost',则调用 model 的 cat_fit 方法进行 CatBoost 模型的训练;否则打印错误信息并退出程序。
接下来,根据模型训练的结果,将训练好的模型保存在 bst 变量中。
最后,根据 bst 的取值判断是否进行了网格搜索交叉验证,若是则退出程序。
函数返回 bst 变量,即训练好的模型。
相关问题
def main(cfg_dict): df_0, df_1, df_9 = load_data(cfg_dict) # 设置训练集、测试集、仿真集 x_train, x_test, y_train, y_test, df_ft = set_data(df_0, df_1, df_9, cfg_dict) bst = fit(cfg_dict, x_train, y_train, x_test, y_test) # 查看模型中重要的特征 df_importances = feature_imp(model=bst, x_train=x_train, plot=False) df_importances.reset_index(drop=True, inplace=True)
这段代码定义了一个名为 main 的函数,用于主程序的执行。
首先,函数调用 load_data 函数,将配置参数 cfg_dict 传递给该函数,获取返回的三个数据框 df_0, df_1, df_9,分别表示类别为0、类别为1和类别为9的数据集。
接着,函数调用 set_data 函数,将 df_0, df_1, df_9 和 cfg_dict 作为参数传递给该函数,获取返回的训练集 x_train, 测试集 x_test, 训练集标签 y_train, 测试集标签 y_test,以及仿真集 df_ft。
然后,函数调用 fit 函数,将 cfg_dict, x_train, y_train, x_test, y_test 作为参数传递给该函数,获取训练好的模型 bst。
接下来,函数调用 feature_imp 函数,将 bst 和 x_train 作为参数传递给该函数,获取特征重要性信息,并将其保存到一个名为 df_importances 的数据框中。然后,通过调用 reset_index 方法重置 df_importances 的索引,使其从零开始。
这段代码的主要作用是加载数据、设置训练集、测试集和仿真集、训练模型,并获取模型中重要的特征信息。
if not os.path.exists('model/easy_feature_select.csv'): df_importances = df_importances[:150] df_importances.to_csv('model/easy_feature_select.csv', encoding='gbk', index=False) # 根据筛选后的特征重新加载数据 x_train, x_test, y_train, y_test, df_ft = set_data(df_0, df_1, df_9, cfg_dict) # 相关系数,补充未被筛选为重要特征但与重要特征相关性较大的其他特征 feature_list = x_train.columns.tolist() df_corr = x_train.corr() df_corr = df_corr.replace(1, 0) # 筛选出相关系数大于0.85的特征 for i in range(len(df_corr.columns)): if i >= len(df_corr.columns): break column = df_corr.columns[i] names = df_corr[abs(df_corr[column]) >= 0.85].index.tolist() if names: print(column, '的强相关特征:', names) feature_list = [i for i in feature_list if i not in names] df_corr = x_train[feature_list].corr() continue #feature_list = list(set(feature_list + ['呼叫次数', '入网时长(月)', # 'MOU_avg', 'DOU_avg', '省外流量占比_avg'])) df_feature = pd.DataFrame(feature_list, columns=['features']) df_importances = pd.merge(df_feature, df_importances, on='features', how='left') df_importances.to_csv('model/easy_feature_select.csv', encoding='gbk', index=False) # 根据筛选后的特征重新加载数据 x_train, x_test, y_train, y_test, df_ft = set_data(df_0, df_1, df_9, cfg_dict) # 重新训练 bst = fit(cfg_dict, x_train, y_train, x_test, y_test) df_importances = feature_imp(model=bst, x_train=x_train, plot=True) df_importances.to_csv('model/easy_feature_select.csv', encoding='gbk', index=False) # 根据重新排序的特征训练模型 x_train, x_test, y_train, y_test, df_ft = set_data(df_0, df_1, df_9, cfg_dict) bst = fit(cfg_dict, x_train, y_train, x_test, y_test)
这段代码用于特征选择和重新训练模型的过程。
首先,代码通过判断是否存在 'model/easy_feature_select.csv' 文件来判断是否需要进行特征选择。如果文件不存在,则进行特征选择的步骤。
在特征选择的过程中,首先将 df_importances 数据框的前150个特征保存到 'model/easy_feature_select.csv' 文件中,然后重新加载数据,获取新的训练集和测试集。
接着,计算特征之间的相关系数,并筛选出相关系数大于0.85的特征。如果存在相关系数大于0.85的特征,则将这些特征从特征列表 feature_list 中移除,并重新计算相关系数。
然后,将筛选后的特征列表 feature_list 和 df_importances 数据框进行合并,并将合并结果保存到 'model/easy_feature_select.csv' 文件中。再次重新加载数据,获取新的训练集和测试集。
接下来,重新训练模型,并将训练好的模型保存到 bst 变量中。然后,通过调用 feature_imp 函数获取新模型的特征重要性信息,并将其保存到 'model/easy_feature_select.csv' 文件中。
最后,再次重新加载数据,获取新的训练集和测试集。重新训练模型,并将训练好的模型保存到 bst 变量中。
这段代码的作用是进行特征选择和重新训练模型的过程。在特征选择中,通过计算特征之间的相关系数,筛选出与重要特征相关性较大的其他特征,并将这些特征从特征列表中移除。然后,重新训练模型,并保存新模型的特征重要性信息。最后,再次重新加载数据,重新训练模型。
阅读全文
相关推荐















