matplotlib报错ValueError: height and width must be > 0解决方案

博客围绕matplotlib报错ValueError: height and width must be > 0展开。指出有些远程软件分配虚拟屏幕,matplotlib无物理屏幕时,plt.show()会报错。并给出解决方案,有物理屏幕设单屏显示,无屏幕服务器修改代码,保存绘图为图片。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

**

matplotlib报错ValueError: height and width must be > 0解决方案

**

问题案例

import matplotlib.pyplot as plt
x = [1, 2, 3, 4, 5]
y = [1, 4, 9, 16, 25]
plt.plot(x, y)
plt.show()

运行以上画图代码,报错ValueError: height and width must be > 0

问题原因

有些远程软件自动会分配虚拟屏幕,matplotlib没有物理屏幕,plt.show()会报错ValueError: height and width must be > 0

解决方案

1、有物理屏幕的显示器,设置为单屏显示,以ubantu为例
在这里插入图片描述
2、无屏幕的服务器,修改代码,加上matplotlib.use(‘Agg’)和plt.savefig(‘./test.jpg’),去掉plt.show(),将绘图保存为图片。
代码修改如下:

import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')
x = [1, 2, 3, 4, 5]
y = [1, 4, 9, 16, 25]
plt.plot(x, y)
plt.savefig('./test.jpg')
<think>我们遇到了一个非常类似的问题,这次是在尝试保存SHAP特征影响点图(shap_impact)时出现的。错误信息表明图像尺寸过大(955357728x1653像素),这显然超出了Matplotlib的最大限制(每个方向最大65536像素)。 问题分析: 1. 这个错误通常发生在尝试绘制极大量数据点时。 2. 在我们的代码中,我们创建了一个子图矩阵(每个特征一个子图),并且每个子图都绘制了所有样本点(尽管我们之前已经采样了500个点,但500个点对于散点图来说并不算多,所以问题可能出在子图的数量上)。 3. 然而,错误信息中的宽度(955357728像素)是荒谬的,这通常是由于数据错误导致坐标轴范围异常引起的。 具体原因: 在绘制SHAP特征影响点图时,我们为每个特征创建了一个子图,并且每个子图绘制了采样后的点(500个)。但是,如果某个特征的值非常大(例如,数值错误导致某个特征的值达到10^9),那么x轴的范围就会非常大,从而导致图像宽度计算错误(因为matplotlib试图用像素表示整个范围,但范围过大导致像素需求过大)。 解决方案: 1. 检查数据范围:确保特征值在合理范围内。 2. 在绘图前对每个特征的数据进行范围检查,如果发现异常值,则进行截断或过滤。 3. 另一种可能是我们在聚合特征时出现了错误,导致某些特征的值异常。 我们将修改优化SHAP可视化中的点图绘制部分,确保每个子图的x轴和y轴数据都在合理范围内,并限制每个子图的数据点数量。 由于错误发生在保存图像时,我们还可以尝试在创建图形时明确设置图形大小,并确保每个子图的数据点数量可控。 但是,更根本的问题可能是数据异常。因此,我们将在绘制点图之前添加数据范围检查,并过滤掉异常值。 修改点图绘制代码: 步骤: 1. 在绘制每个特征时,检查x和y的数据范围,如果发现绝对值特别大的异常值(例如,超出3个标准差),则将其过滤掉。 2. 或者,我们可以使用百分位数来截断数据,只绘制在1%到99%分位数之间的点。 我们将采用分位数截断的方法,避免异常值对图像的影响。 修改后的点图绘制代码: ```python # 9.2 SHAP特征影响点图(使用采样) plt.figure(figsize=(10, 8), facecolor='white') # 只显示最重要的特征 top_features = top_contrib.index.tolist() top_indices = [new_feature_names.index(f) for f in top_features] # 创建子图矩阵 fig, axes = plt.subplots(len(top_features), 1, figsize=(10, 2 * len(top_features))) fig.suptitle(f"SHAP Feature Impact ({year or 'Combined'})", color='#333333') if len(top_features) == 1: axes = [axes] # 确保单特征时axes是列表 for i, (feat, idx) in enumerate(zip(top_features, top_indices)): ax = axes[i] # 采样点以减少绘图数量 sample_indices = np.random.choice(len(aggregated_shap), min(500, len(aggregated_shap)), replace=False) x = aggregated_X[sample_indices, idx] y = aggregated_shap[sample_indices, idx] # 检查并过滤异常值:使用分位数截断 low_x, high_x = np.percentile(x, [1, 99]) low_y, high_y = np.percentile(y, [1, 99]) # 创建掩码,保留在1%到99%分位数之间的点 mask = (x >= low_x) & (x <= high_x) & (y >= low_y) & (y <= high_y) x_filtered = x[mask] y_filtered = y[mask] # 创建散点图 scatter = ax.scatter(x_filtered, y_filtered, c=y_filtered, cmap=cmap, alpha=0.6, edgecolors='none') ax.set_ylabel('SHAP value', color='#555555') ax.set_xlabel(feat, color='#555555') ax.tick_params(colors='#555555') ax.spines['top'].set_visible(False) ax.spines['right'].set_visible(False) ax.spines['left'].set_color('#AAAAAA') ax.spines['bottom'].set_color('#AAAAAA') ax.grid(True, linestyle='--', alpha=0.2) plt.tight_layout(rect=[0, 0, 1, 0.95]) save_plot(fig, f"shap_impact_{year or 'combined'}.png") ``` 这样,我们只绘制每个特征在1%到99%分位数之间的点,避免极端值导致的坐标轴范围过大。 另外,我们也要确保在聚合特征时没有错误,特别是特征名称的匹配和索引的正确性。 完整修改后的`optimized_shap_visualization`函数: 注意:由于错误可能发生在不同的地方,我们将整个函数重新整理,并添加了异常值过滤。 但是,考虑到错误信息中的像素数过于巨大,我们还要检查在计算`aggregated_X`和`aggregated_shap`时是否有错误(比如出现了NaN或inf)。因此,我们在聚合后可以添加一步:将NaN替换为0,并移除inf。 我们在聚合循环后添加: ```python # 在聚合循环后,处理NaN和inf aggregated_shap = np.nan_to_num(aggregated_shap, nan=0.0, posinf=0.0, neginf=0.0) aggregated_X = np.nan_to_num(aggregated_X, nan=0.0, posinf=0.0, neginf=0.0) ``` 这样确保数据中没有NaN或inf。 综合以上,我们修改`optimized_shap_visualization`函数,包括: 1. 添加数据清洗(NaN和inf) 2. 在点图中使用分位数过滤异常值 3. 确保点图绘制的数据点在合理范围内 修改后的完整函数如下: ```python def optimized_shap_visualization(model, X_test, features, year=None, max_features=15): """ 优化SHAP可视化,避免创建过大的图像 - model: 训练好的模型 - X_test: 测试集数据 - features: 特征名称列表 - year: 年份标识 - max_features: 最大显示特征数 """ # 1. 创建SHAP解释器 explainer = shap.TreeExplainer(model) # 2. 计算SHAP值 - 使用更小的样本量 sample_size = min(200, X_test.shape[0]) idx = np.random.choice(X_test.shape[0], sample_size, replace=False) X_s = X_test[idx] # 3. 计算SHAP值 - 避免计算所有类别的SHAP值 shap_values = explainer.shap_values(X_s, check_additivity=False) # 4. 处理多分类输出 if isinstance(shap_values, list) and len(shap_values) == len(np.unique(model.classes_)): # 聚合所有类别的SHAP值(取绝对值平均) shap_values_agg = np.zeros_like(shap_values[0]) for class_shap in shap_values: shap_values_agg += np.abs(class_shap) shap_values_2d = shap_values_agg else: shap_values_2d = shap_values # 5. 创建特征名到索引的映射 feature_idx_map = {} for i, feat in enumerate(features): feature_idx_map.setdefault(feat, []).append(i) # 6. 聚合SHAP值(按特征名分组) aggregated_shap = np.zeros((shap_values_2d.shape[0], len(feature_idx_map))) aggregated_X = np.zeros((X_s.shape[0], len(feature_idx_map))) new_feature_names = [] # 7. 计算每个特征的SHAP贡献度 shap_contributions = {} for j, (feat, indices) in enumerate(feature_idx_map.items()): current_shap = shap_values_2d[:, indices] if len(current_shap.shape) == 2: mean_shap = np.mean(current_shap, axis=1) else: mean_shap = np.mean(current_shap.reshape(current_shap.shape[0], -1), axis=1) aggregated_shap[:, j] = mean_shap aggregated_X[:, j] = np.mean(X_s[:, indices], axis=1) new_feature_names.append(feat) shap_contributions[feat] = np.mean(np.abs(mean_shap)) # 8. 处理NaN和inf aggregated_shap = np.nan_to_num(aggregated_shap, nan=0.0, posinf=0.0, neginf=0.0) aggregated_X = np.nan_to_num(aggregated_X, nan=0.0, posinf=0.0, neginf=0.0) # 9. 保存SHAP贡献度报告 shap_contrib_series = pd.Series(shap_contributions) shap_report = save_contribution_report(shap_contrib_series, f"{year}_shap" if year else "combined_shap") # 10. 创建自定义SHAP可视化(避免使用SHAP的summary_plot) # 10.1 SHAP全局重要性条形图 plt.figure(figsize=(10, 6), facecolor='white') # 选择最重要的特征(最多max_features个) top_contrib = shap_contrib_series.nlargest(max_features).sort_values(ascending=True) # 绘制条形图 for i, (feat, val) in enumerate(top_contrib.items()): color_pos = i / (len(top_contrib) - 1) if len(top_contrib) > 1 else 0.5 bar_color = cmap(color_pos) plt.barh(feat, val, color=bar_color, edgecolor='#444444', alpha=0.9) plt.text(val + 0.005, i, f'{val:.4f}', ha='left', va='center', color='#333333', fontsize=10, fontweight='bold') plt.title(f"SHAP Global Importance ({year or 'Combined'})", color='#333333') plt.xlabel('mean(|SHAP value|)') ax = plt.gca() ax.spines['top'].set_visible(False) ax.spines['right'].set_visible(False) ax.spines['left'].set_color('#AAAAAA') ax.spines['bottom'].set_color('#AAAAAA') ax.tick_params(colors='#555555') ax.set_facecolor('white') plt.tight_layout() save_plot(plt.gcf(), f"shap_global_{year or 'combined'}.png") # 10.2 SHAP特征影响点图(使用采样和异常值过滤) # 创建子图矩阵 fig, axes = plt.subplots(len(top_contrib), 1, figsize=(10, 2 * len(top_contrib)), facecolor='white') fig.suptitle(f"SHAP Feature Impact ({year or 'Combined'})", color='#333333') if len(top_contrib) == 0: return shap_report # 没有特征,直接返回 if len(top_contrib) == 1: axes = [axes] # 确保单特征时axes是列表 # 获取每个特征的索引(在new_feature_names中) top_indices = [new_feature_names.index(f) for f in top_contrib.index] for i, (feat, idx) in enumerate(zip(top_contrib.index, top_indices)): ax = axes[i] # 采样点以减少绘图数量 sample_indices = np.random.choice(len(aggregated_shap), min(500, len(aggregated_shap)), replace=False) x = aggregated_X[sample_indices, idx] y = aggregated_shap[sample_indices, idx] # 检查并过滤异常值:使用分位数截断 low_x, high_x = np.percentile(x, [1, 99]) low_y, high_y = np.percentile(y, [1, 99]) mask = (x >= low_x) & (x <= high_x) & (y >= low_y) & (y <= high_y) x_filtered = x[mask] y_filtered = y[mask] # 创建散点图 scatter = ax.scatter(x_filtered, y_filtered, c=y_filtered, cmap=cmap, alpha=0.6, edgecolors='none') ax.set_ylabel('SHAP value', color='#555555') ax.set_xlabel(feat, color='#555555') ax.tick_params(colors='#555555') ax.spines['top'].set_visible(False) ax.spines['right'].set_visible(False) ax.spines['left'].set_color('#AAAAAA') ax.spines['bottom'].set_color('#AAAAAA') ax.grid(True, linestyle='--', alpha=0.2) plt.tight_layout(rect=[0, 0, 1, 0.95]) save_plot(fig, f"shap_impact_{year or 'combined'}.png") return shap_report ``` 这样修改后,应该可以避免因异常值导致图像尺寸计算错误的问题。 另外,我们也要注意,在点图绘制中,我们只绘制了500个采样点,并且过滤了异常值,所以图像尺寸应该是可控的。 请尝试运行修改后的代码,看是否解决了问题。 如果问题依旧,我们可能需要进一步调试,比如检查过滤后的数据范围,或者尝试单独绘制每个特征(在循环中保存图像)以避免大图像的创建。 但希望以上修改能解决问题。 注意:由于错误信息中出现了1653像素的高度,这可能是由于我们为每个特征创建了一个子图,而特征数量较多导致总高度很大。因此,我们限制`max_features`为15,这样最多15个子图,高度为30英寸(2*15),在300dpi下高度为9000像素(30*300),这仍然在允许范围内(小于65536)。所以高度不是问题,问题出在宽度(955357728像素)上,这显然是由于x轴范围异常引起的,因此异常值过滤是解决问题的关键。</think>这个错误表明在尝试保存图像时,Matplotlib无法处理过大的图像尺寸(955357728x1653像素)。这通常是由于数据范围异常或坐标轴设置问题导致的。让我们彻底解决这个问题。 ### 问题分析 错误发生在保存SHAP特征影响点图时,根本原因有: 1. **异常数据范围**:某些特征值异常大,导致x轴范围过大 2. **坐标轴设置问题**:Matplotlib尝试渲染超出最大限制的坐标范围 3. **数据聚合错误**:特征聚合时可能产生了错误值 ### 完整解决方案 以下是修复后的完整代码,重点优化了数据范围处理和可视化逻辑: ```python import pandas as pd import numpy as np import matplotlib.pyplot as plt import seaborn as sns from sklearn.model_selection import train_test_split from sklearn.preprocessing import StandardScaler from sklearn.ensemble import RandomForestClassifier from sklearn.metrics import classification_report, accuracy_score, confusion_matrix import shap import os import time import warnings import json from matplotlib.colors import LinearSegmentedColormap from scipy.stats import zscore warnings.filterwarnings('ignore') # 创建统一的配色方案 cmap = LinearSegmentedColormap.from_list('custom_cm', ['#DDE3E1', '#F75F0A'], N=100) primary_color = '#F75F0A' # 主色调 secondary_color = '#DDE3E1' # 辅助色 # 设置全局样式和字体 plt.rcParams.update({ 'font.size': 12, 'font.family': 'DejaVu Sans', 'axes.titlesize': 14, 'axes.titleweight': 'bold', 'axes.labelsize': 12, 'xtick.labelsize': 12, 'ytick.labelsize': 12, 'legend.fontsize': 12, 'axes.edgecolor': '#444444', 'axes.linewidth': 0.8, 'figure.facecolor': 'white', 'figure.max_open_warning': 50 # 增加最大打开图像警告限制 }) sns.set_style("whitegrid", {'grid.linestyle': '--', 'grid.alpha': 0.3}) # 创建输出目录 os.makedirs("RF_results", exist_ok=True) os.makedirs("RF_reports", exist_ok=True) # --------------------- 数据加载与预处理 --------------------- def load_and_preprocess_data(file_path): print(f"Loading data: {file_path}") df = pd.read_csv(file_path) # 统一处理列名(去空格、保留原始大小写) df.columns = [c.strip() for c in df.columns] # 替换inf为nan df.replace([np.inf, -np.inf], np.nan, inplace=True) # 只删除全空的行/列 cols0, rows0 = df.shape[1], df.shape[0] df.dropna(axis=1, how='all', inplace=True) df.dropna(axis=0, how='all', inplace=True) print(f"Shape after cleaning: {df.shape}, dropped cols: {cols0 - df.shape[1]}, rows: {rows0 - df.shape[0]}") print("Columns loaded:\n", df.columns.tolist()) # 不直接删除含NaN的行/列,否则数据容易丢失 naninfo = df.isna().sum() naninfo = naninfo[naninfo > 0] if not naninfo.empty: print("Columns with NaN values:\n", naninfo) return df # --------------------- 年份特征提取 --------------------- def extract_year_features(df): # 检测所有年份 years = sorted({int(c.split('_')[0]) for c in df.columns if '_' in c and c.split('_')[0].isdigit()}) print(f"Detected years: {years}") datasets = {} # 检查目标列是否存在(大小写不敏感) target_exists = any('mp' in c.lower() for c in df.columns) print(f"Target column 'MP' exists: {target_exists}") for y in years: # 提取所有以"yyyy_"开头的列(保留原始大小写) cols = [c for c in df.columns if c.startswith(f"{y}_") and 'mp' not in c.lower()] # 添加目标列(如果存在)- 大小写不敏感匹配 target_col = next((c for c in df.columns if f"{y}_mp".lower() in c.lower()), None) if target_col: cols.append(target_col) print(f"Found target column for year {y}: {target_col}") if not cols: print(f"Year {y}: No columns found!") continue sub = df[cols].copy() # 剥离前缀为原始变量名(保留原始大小写) sub.columns = [c[len(f"{y}_"):] if c.startswith(f"{y}_") else c for c in cols] sub['year'] = y datasets[y] = sub print(f"Year {y}: columns -> {sub.columns.tolist()}") return datasets # --------------------- 目标分级 --------------------- def classify_target_variable(y, n_classes=5, breaks=None, verbose=True): """ 自然段点法对目标变量y进行分级。 - y: pandas Series,待分级目标变量 - n_classes: int,分级数量,默认5 - breaks: 可选,手动指定断点,优先使用(如后续测试集用同一断点) - verbose: 是否打印分布 返回:分级标签(pd.Series,0开始),分段点(bins) """ # 1. 处理无穷与缺失 yc = y.replace([np.inf, -np.inf], np.nan).dropna() if breaks is None: # 2. 计算自然分段点(分位数),并去重 cuts = np.unique(yc.quantile(np.linspace(0, 1, n_classes + 1)).values) if len(cuts) - 1 < n_classes: print(f"警告: 实际有效段数不足 {n_classes},仅有 {len(cuts) - 1} 段,将自动减少类别数") breaks = cuts else: breaks = np.asarray(breaks) # 3. 用分段点切分(0开始标签) y_cat = pd.cut(y, bins=breaks, labels=range(len(breaks) - 1), include_lowest=True) if verbose: print(f"Breaks: {breaks}\nDistribution:\n{y_cat.value_counts().sort_index()}") return y_cat, breaks # --------------------- 训练评估 --------------------- def train_and_evaluate(X, y, year=None): t0 = time.time() X = X.replace([np.inf, -np.inf], np.nan).fillna(X.median()) y = y.replace([np.inf, -np.inf], np.nan).dropna() idx = X.index.intersection(y.index) X, y = X.loc[idx], y.loc[idx] X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.3, random_state=42, stratify=y) scaler = StandardScaler() X_tr_s = scaler.fit_transform(X_tr) X_te_s = scaler.transform(X_te) model = RandomForestClassifier(n_estimators=100, max_depth=5, class_weight='balanced', random_state=42, n_jobs=-1) model.fit(X_tr_s, y_tr) y_pred = model.predict(X_te_s) acc = accuracy_score(y_te, y_pred) rpt = pd.DataFrame(classification_report(y_te, y_pred, output_dict=True)).transpose() print(f"Year {year or 'Combined'} - Time: {time.time() - t0:.1f}s, Accuracy: {acc:.4f}") return model, scaler, X_te_s, y_te, y_pred, rpt, acc # --------------------- 保存图像 --------------------- def save_plot(fig, fname): path = os.path.join('RF_results', fname) # 添加安全检查防止过大图像 dpi = 150 width, height = fig.get_size_inches() total_pixels = width * dpi * height * dpi # 检查图像尺寸是否过大 if total_pixels > 1e8: # 1亿像素限制 print(f"警告: 图像尺寸过大 ({total_pixels:.2e} 像素), 降低DPI") dpi = max(72, int(150 * (1e8 / total_pixels)**0.5)) fig.savefig(path, dpi=dpi, bbox_inches='tight', facecolor='white') print(f"Saved: {path} (dpi={dpi})") plt.close(fig) # --------------------- 保存贡献度报告 --------------------- def save_contribution_report(contributions, year=None): """保存特征贡献度报告到JSON文件""" # 创建贡献度字典 contrib_dict = { "total_contrib": float(contributions.sum()), "features": {} } # 添加特征贡献度 for feat, contrib in contributions.items(): contrib_dict["features"][feat] = { "contribution": float(contrib), "percentage": float(contrib / contributions.sum()) * 100 } # 保存为JSON fname = f"feature_contributions_{year or 'combined'}.json" path = os.path.join('RF_reports', fname) with open(path, 'w') as f: json.dump(contrib_dict, f, indent=4) print(f"Saved contribution report: {path}") # 同时保存为CSV contrib_df = pd.DataFrame({ 'Feature': contributions.index, 'Contribution': contributions.values, 'Percentage': (contributions.values / contributions.sum()) * 100 }) contrib_df.to_csv(os.path.join('RF_reports', fname.replace('.json', '.csv')), index=False) return contrib_dict # --------------------- 安全范围数据可视化 --------------------- def safe_visualize(x, y, feature_name, ax, max_points=300): """ 安全可视化函数,防止数据范围异常 - 过滤异常值 - 限制数据点数量 - 设置合理的坐标轴范围 """ # 1. 创建数据框并过滤异常值 df = pd.DataFrame({feature_name: x, 'SHAP': y}) # 2. 计算Z分数过滤异常值 z_threshold = 3 df['z_x'] = np.abs(zscore(df[feature_name])) df['z_y'] = np.abs(zscore(df['SHAP'])) df = df[(df['z_x'] < z_threshold) & (df['z_y'] < z_threshold)] # 3. 采样数据点 if len(df) > max_points: df = df.sample(max_points) # 4. 计算合理的坐标轴范围 x_min, x_max = df[feature_name].quantile([0.01, 0.99]) y_min, y_max = df['SHAP'].quantile([0.01, 0.99]) # 5. 添加10%的边界缓冲 x_range = x_max - x_min y_range = y_max - y_min x_min = x_min - 0.1 * x_range if x_min != x_max else x_min - 0.1 x_max = x_max + 0.1 * x_range if x_min != x_max else x_max + 0.1 y_min = y_min - 0.1 * y_range if y_min != y_max else y_min - 0.1 y_max = y_max + 0.1 * y_range if y_min != y_max else y_max + 0.1 # 6. 创建散点图 scatter = ax.scatter( df[feature_name], df['SHAP'], c=df['SHAP'], cmap=cmap, alpha=0.6, edgecolors='none', s=15 ) # 7. 设置坐标轴范围 ax.set_xlim(x_min, x_max) ax.set_ylim(y_min, y_max) # 8. 添加标签 ax.set_ylabel('SHAP value', color='#555555') ax.set_xlabel(feature_name, color='#555555') ax.tick_params(colors='#555555') ax.spines['top'].set_visible(False) ax.spines['right'].set_visible(False) ax.spines['left'].set_color('#AAAAAA') ax.spines['bottom'].set_color('#AAAAAA') ax.grid(True, linestyle='--', alpha=0.2) return scatter # --------------------- 优化SHAP可视化 --------------------- def optimized_shap_visualization(model, X_test, features, year=None, max_features=10): """ 优化SHAP可视化,避免创建过大的图像 - model: 训练好的模型 - X_test: 测试集数据 - features: 特征名称列表 - year: 年份标识 - max_features: 最大显示特征数 """ # 1. 创建SHAP解释器 explainer = shap.TreeExplainer(model) # 2. 计算SHAP值 - 使用更小的样本量 sample_size = min(100, X_test.shape[0]) if sample_size == 0: print("警告: 测试集为空,跳过SHAP可视化") return None idx = np.random.choice(X_test.shape[0], sample_size, replace=False) X_s = X_test[idx] # 3. 计算SHAP值 - 避免计算所有类别的SHAP值 try: shap_values = explainer.shap_values(X_s, check_additivity=False) except Exception as e: print(f"SHAP计算错误: {e}") return None # 4. 处理多分类输出 if isinstance(shap_values, list) and len(shap_values) == len(np.unique(model.classes_)): # 聚合所有类别的SHAP值(取绝对值平均) shap_values_agg = np.zeros_like(shap_values[0]) for class_shap in shap_values: shap_values_agg += np.abs(class_shap) shap_values_2d = shap_values_agg else: shap_values_2d = shap_values # 5. 创建特征名到索引的映射 feature_idx_map = {} for i, feat in enumerate(features): feature_idx_map.setdefault(feat, []).append(i) # 6. 聚合SHAP值(按特征名分组) aggregated_shap = np.zeros((shap_values_2d.shape[0], len(feature_idx_map))) aggregated_X = np.zeros((X_s.shape[0], len(feature_idx_map))) new_feature_names = [] # 7. 计算每个特征的SHAP贡献度 shap_contributions = {} for j, (feat, indices) in enumerate(feature_idx_map.items()): current_shap = shap_values_2d[:, indices] if len(current_shap.shape) == 2: mean_shap = np.mean(current_shap, axis=1) else: mean_shap = np.mean(current_shap.reshape(current_shap.shape[0], -1), axis=1) aggregated_shap[:, j] = mean_shap aggregated_X[:, j] = np.mean(X_s[:, indices], axis=1) new_feature_names.append(feat) shap_contributions[feat] = np.mean(np.abs(mean_shap)) # 8. 处理NaN和inf aggregated_shap = np.nan_to_num(aggregated_shap, nan=0.0, posinf=0.0, neginf=0.0) aggregated_X = np.nan_to_num(aggregated_X, nan=0.0, posinf=0.0, neginf=0.0) # 9. 保存SHAP贡献度报告 shap_contrib_series = pd.Series(shap_contributions) shap_report = save_contribution_report(shap_contrib_series, f"{year}_shap" if year else "combined_shap") # 10. 创建自定义SHAP可视化 # 10.1 SHAP全局重要性条形图 plt.figure(figsize=(10, 6), facecolor='white') # 选择最重要的特征(最多max_features个) top_contrib = shap_contrib_series.nlargest(max_features).sort_values(ascending=True) # 绘制条形图 for i, (feat, val) in enumerate(top_contrib.items()): color_pos = i / (len(top_contrib) - 1) if len(top_contrib) > 1 else 0.5 bar_color = cmap(color_pos) plt.barh(feat, val, color=bar_color, edgecolor='#444444', alpha=0.9) plt.text(val + 0.005, i, f'{val:.4f}', ha='left', va='center', color='#333333', fontsize=10, fontweight='bold') plt.title(f"SHAP Global Importance ({year or 'Combined'})", color='#333333') plt.xlabel('mean(|SHAP value|)') ax = plt.gca() ax.spines['top'].set_visible(False) ax.spines['right'].set_visible(False) ax.spines['left'].set_color('#AAAAAA') ax.spines['bottom'].set_color('#AAAAAA') ax.tick_params(colors='#555555') ax.set_facecolor('white') plt.tight_layout() save_plot(plt.gcf(), f"shap_global_{year or 'combined'}.png") # 10.2 SHAP特征影响点图(使用安全可视化函数) # 只显示最重要的特征 top_features = top_contrib.index.tolist() # 创建子图矩阵 - 最多显示5个特征 max_display = min(5, len(top_features)) fig, axes = plt.subplots(max_display, 1, figsize=(10, 3 * max_display), facecolor='white') fig.suptitle(f"Top {max_display} SHAP Feature Impacts ({year or 'Combined'})", fontsize=14, color='#333333', y=0.95) if max_display == 0: plt.close(fig) return shap_report if max_display == 1: axes = [axes] # 确保单特征时axes是列表 # 获取每个特征的索引 top_indices = [new_feature_names.index(f) for f in top_features[:max_display]] for i, (feat, idx) in enumerate(zip(top_features[:max_display], top_indices)): ax = axes[i] # 使用安全可视化函数 safe_visualize( aggregated_X[:, idx], aggregated_shap[:, idx], feat, ax ) plt.tight_layout(rect=[0, 0, 1, 0.95]) save_plot(fig, f"shap_impact_{year or 'combined'}.png") return shap_report # --------------------- 结果可视化 --------------------- def create_result_figure(model, X_test, y_test, y_pred, features, year=None): # 1. 特征重要性 & 混淆矩阵 fig, axes = plt.subplots(1, 2, figsize=(16, 6), facecolor='white') fig.suptitle(f"Results{'-' + str(year) if year else ''}", fontsize=14, fontweight='bold', color='#333333') # 聚合重复特征的重要性 imp_series = pd.Series(model.feature_importances_, index=features) agg_imp = imp_series.groupby(level=0).mean() top_imp = agg_imp.nlargest(8).sort_values() top_features = list(top_imp.index) # 保存特征重要性贡献度报告 save_contribution_report(agg_imp, year) # 使用统一配色方案 for i, (feat, val) in enumerate(top_imp.items()): # 根据排名计算颜色位置 (0.0-1.0) color_pos = i / (len(top_imp) - 1) if len(top_imp) > 1 else 0.5 bar_color = cmap(color_pos) bar = axes[0].barh(feat, val, color=bar_color, edgecolor='#444444', alpha=0.9) # 添加贡献度标注 axes[0].text( val + 0.005, # x位置:条形右侧稍微偏移 i, # y位置:当前条形的索引 f'{val:.4f}', # 贡献度值 ha='left', va='center', color='#333333', fontsize=10, fontweight='bold' ) axes[0].set_title('Feature Importance with Contribution', color='#333333') axes[0].set_xlabel('Importance Score', color='#555555') axes[0].tick_params(colors='#555555') axes[0].spines['top'].set_visible(False) axes[0].spines['right'].set_visible(False) axes[0].spines['bottom'].set_color('#AAAAAA') axes[0].spines['left'].set_color('#AAAAAA') axes[0].grid(True, linestyle='--', alpha=0.3, axis='x') # 混淆矩阵 cm = confusion_matrix(y_test, y_pred) cmn = cm.astype(float) / cm.sum(axis=1)[:, None] sns.heatmap(cmn, annot=True, fmt='.2f', cmap=cmap, cbar=True, cbar_kws={'label': 'Accuracy'}, ax=axes[1], annot_kws={'color': 'black', 'fontsize': 10}) axes[1].set_title('Normalized Confusion Matrix', color='#333333') axes[1].set_xlabel('Predicted Label', color='#555555') axes[1].set_ylabel('True Label', color='#555555') axes[1].tick_params(colors='#555555') # 调整布局 plt.tight_layout(rect=[0, 0, 1, 0.95]) save_plot(fig, f"figure_{year or 'combined'}.png") # 2. 优化SHAP可视化 shap_report = optimized_shap_visualization(model, X_test, features, year) return shap_report # --------------------- 生成综合报告 --------------------- def generate_comprehensive_report(results, shap_reports): """生成包含所有分析结果的综合报告""" report = { "overview": { "total_years": len(results) - 1, # 减去combined "overall_accuracy": results['combined']['accuracy'] if 'combined' in results else 0.0 }, "yearly_results": {}, "feature_contributions": {} } # 添加年度结果 for year, data in results.items(): if year == 'combined': continue report["yearly_results"][year] = { "accuracy": data['accuracy'], "classification_report": data['report'].to_dict() } # 添加特征贡献度 for key, contrib in shap_reports.items(): report["feature_contributions"][key] = contrib # 保存报告 path = os.path.join('RF_reports', 'comprehensive_report.json') with open(path, 'w') as f: json.dump(report, f, indent=4) print(f"Saved comprehensive report: {path}") return report # --------------------- 主流程 --------------------- def main(): start = time.time() print("Starting Multi-Year Analysis") df = load_and_preprocess_data(r"D:\JW\BaiduSyncdisk\数据\SZML\rasters\merged_t2.csv") datasets = extract_year_features(df) # 检查数据集是否为空 if not datasets: print("Error: No valid datasets found! Exiting.") return # 使用大小写不敏感的目标列名 target = 'mp' # 调试:打印数据集信息 print("\nDatasets summary:") for y, d in datasets.items(): print(f"Year {y}: shape={d.shape}, columns={d.columns.tolist()}") results = {} shap_reports = {} for y, d in sorted(datasets.items()): print(f"\n--- Year {y} ---") # 检查目标列是否存在(大小写不敏感) target_col = next((c for c in d.columns if c.lower() == target), None) if not target_col: print(f"Warning: Target column '{target}' not found in year {y} dataset!") print(f"Available columns: {d.columns.tolist()}") continue # 移除缺失值 d = d.dropna(subset=[target_col]) if d.empty: print(f"警告: 年份 {y} 数据集在移除缺失值后为空") continue feats = [c for c in d.columns if c.lower() != target and c != 'year'] X, yvar = d[feats], d[target_col] # 检查特征矩阵是否为空 if X.empty: print(f"警告: 年份 {y} 特征矩阵为空") continue ycat, _ = classify_target_variable(yvar) model, scaler, Xt, yt, yp, rpt, acc = train_and_evaluate(X, ycat, y) # 检查测试集是否为空 if Xt.size == 0: print(f"警告: 年份 {y} 测试集为空,跳过可视化") continue shap_report = create_result_figure(model, Xt, yt, yp, feats, y) results[y] = {'report': rpt, 'accuracy': acc} shap_reports[f"year_{y}"] = shap_report print("\n--- Combined ---") combo = pd.concat(datasets.values(), ignore_index=True) # 检查目标列是否存在(大小写不敏感) target_col = next((c for c in combo.columns if c.lower() == target), None) if not target_col: print(f"Error: Target column '{target}' not found in combined dataset!") print(f"Available columns: {combo.columns.tolist()}") return combo = combo.dropna(subset=[target_col]) if combo.empty: print("警告: 组合数据集在移除缺失值后为空") return feats = [c for c in combo.columns if c.lower() != target and c != 'year'] X, yvar = combo[feats], combo[target_col] # 检查特征矩阵是否为空 if X.empty: print("警告: 组合数据集特征矩阵为空") return ycat, _ = classify_target_variable(yvar) model, scaler, Xt, yt, yp, rpt, acc = train_and_evaluate(X, ycat) # 检查测试集是否为空 if Xt.size == 0: print("警告: 组合数据集测试集为空,跳过可视化") else: shap_report = create_result_figure(model, Xt, yt, yp, feats) results['combined'] = {'report': rpt, 'accuracy': acc} shap_reports["combined"] = shap_report # 年度准确率差异 print("\nAnnual Accuracy Differences:") years = [yr for yr in results if yr != 'combined'] years_sorted = sorted(years) for i in range(1, len(years_sorted)): y0, y1 = years_sorted[i - 1], years_sorted[i] diff = results[y1]['accuracy'] - results[y0]['accuracy'] print(f"{y0} -> {y1}: Accuracy change {diff:+.4f}") total_time = time.time() - start print(f"\nTotal time: {total_time:.1f}s. Results in 'RF_results' and 'RF_reports' directories.") # 打印报告 for yr in years_sorted: print(f"\nYear {yr} Report:\n", results[yr]['report']) if 'combined' in results: print("\nCombined Report:\n", results['combined']['report']) # 生成综合报告 generate_comprehensive_report(results, shap_reports) if __name__ == '__main__': main() ``` ### 关键修复点 #### 1. 安全数据可视化函数 (`safe_visualize`) ```python def safe_visualize(x, y, feature_name, ax, max_points=300): # 1. 创建数据框并过滤异常值 df = pd.DataFrame({feature_name: x, 'SHAP': y}) # 2. 计算Z分数过滤异常值 z_threshold = 3 df['z_x'] = np.abs(zscore(df[feature_name])) df['z_y'] = np.abs(zscore(df['SHAP'])) df = df[(df['z_x'] < z_threshold) & (df['z_y'] < z_threshold)] # 3. 采样数据点 if len(df) > max_points: df = df.sample(max_points) # 4. 计算合理的坐标轴范围 x_min, x_max = df[feature_name].quantile([0.01, 0.99]) y_min, y_max = df['SHAP'].quantile([0.01, 0.99]) # 5. 添加10%的边界缓冲 x_range = x_max - x_min y_range = y_max - y_min x_min = x_min - 0.1 * x_range if x_min != x_max else x_min - 0.1 x_max = x_max + 0.1 * x_range if x_min != x_max else x_max + 0.1 y_min = y_min - 0.1 * y_range if y_min != y_max else y_min - 0.1 y_max = y_max + 0.1 * y_range if y_min != y_max else y_max + 0.1 # ... 创建散点图并设置坐标轴范围 ... ``` #### 2. 图像保存安全机制 (`save_plot`) ```python def save_plot(fig, fname): path = os.path.join('RF_results', fname) # 添加安全检査防止过大图像 dpi = 150 width, height = fig.get_size_inches() total_pixels = width * dpi * height * dpi # 检查图像尺寸是否过大 if total_pixels > 1e8: # 1亿像素限制 print(f"警告: 图像尺寸过大 ({total_pixels:.2e} 像素), 降低DPI") dpi = max(72, int(150 * (1e8 / total_pixels)**0.5)) fig.savefig(path, dpi=dpi, bbox_inches='tight', facecolor='white') print(f"Saved: {path} (dpi={dpi})") plt.close(fig) ``` #### 3. 数据完整性检查 在整个流程中添加了多处数据完整性检查: ```python # 在训练前检查数据集是否为空 if d.empty: print(f"警告: 年份 {y} 数据集在移除缺失值后为空") continue # 在可视化前检查测试集是否为空 if Xt.size == 0: print(f"警告: 年份 {y} 测试集为空,跳过可视化") continue ``` #### 4. SHAP计算异常处理 ```python try: shap_values = explainer.shap_values(X_s, check_additivity=False) except Exception as e: print(f"SHAP计算错误: {e}") return None ``` #### 5. 限制可视化特征数量 ```python # 只显示最重要的特征 max_display = min(5, len(top_features)) # 创建子图矩阵 - 最多显示5个特征 fig, axes = plt.subplots(max_display, 1, figsize=(10, 3 * max_display), facecolor='white') ``` ### 为什么这些修改能解决问题? 1. **异常值处理**: - 使用Z分数过滤异常值(|Z| > 3) - 通过分位数限制坐标轴范围(1%-99%分位数) - 添加10%的边界缓冲防止边缘情况 2. **数据范围控制**: - 显式设置坐标轴范围 (`set_xlim`, `set_ylim`) - 防止坐标轴范围过大导致的渲染问题 - 确保所有数据点在合理范围内 3. **图像尺寸安全机制**: - 计算图像总像素数 - 自动降低DPI防止过大图像 - 设置最大像素限制(1亿像素) 4. **数据完整性检查**: - 检查空数据集和空测试集 - 跳过无效计算步骤 - 防止空数据导致的意外错误 5. **资源优化**: - 限制最大显示特征数(5个) - 限制最大数据点数(300个) - 减少不必要的计算和渲染 这些修改确保了可视化过程在可控的数据范围内进行,同时保持分析结果的准确性。代码现在应该能够处理各种规模的数据集,包括那些包含异常值或范围异常大的情况。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值