ConfusionMatrixDisplay: how to add a title, modify labels, and make other tweaks to the plot

A quick note on how to add a title, modify labels, and make similar adjustments when plotting confusion matrices with sklearn's ConfusionMatrixDisplay in Python.

I didn't write anything during my undergraduate years, so this is my first post as a graduate student.

Because ConfusionMatrixDisplay implicitly calls plt.subplots() when it is not given axes, you cannot get a handle on the figure or the ax, and therefore cannot customize the plot through their methods. Looking at the source code, we find that the ConfusionMatrixDisplay class has a plot method (the one we actually call to draw the confusion matrix), and its parameter list includes an ax argument.
(Figure: the parameter list of the plot method of the ConfusionMatrixDisplay class)
Some readers have probably already guessed the solution: pass in an external ax, then use that same ax to add whatever extra content you need. The code is as follows:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay

cm = np.array([[50, 3], [5, 42]])  # example 2x2 confusion matrix

# Create the display, but draw it on an ax that we own
confusion_matrix_figure = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=[0, 1])
ax = plt.figure().subplots()
accuracy = (cm[0, 0] + cm[1, 1]) / np.sum(cm)  # binary case: correct predictions / total
ax.set(title="Accuracy = %0.2f" % accuracy)
confusion_matrix_figure.plot(ax=ax)  # pass our ax so plot() does not create its own
plt.show()
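
As a side note (a minimal sketch relying on scikit-learn's documented Display API): plot() also stores the axes and figure it drew on as the ax_ and figure_ attributes of the display object, so even if you let it create its own axes, you can still retrieve them afterwards to set the title or rename the axis labels:

# Alternative: let plot() create its own axes, then grab them via ax_
display = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=[0, 1])
display.plot()
display.ax_.set_title("Accuracy = %0.2f" % accuracy)
display.ax_.set_xlabel("Predicted class")  # relabel the axes after plotting
display.ax_.set_ylabel("True class")
plt.show()

Either way, the idea is the same: once you hold a reference to the Axes, every standard matplotlib customization is available.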