train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) def train(model, train_loader, optimizer, epochs=500, kl_weight=0.1): device = next(model.parameters()).device model.train() for epoch in range(epochs): total_nll = 0.0 total_kl = 0.0 total_loss = 0.0 num_batches = 0 for batch_idx, (x_batch, y_batch) in enumerate(train_loader): x_batch = x_batch.to(device) y_batch = y_batch.to(device) # 确保输入维度正确 if x_batch.dim() == 1: x_batch = x_batch.unsqueeze(0) if y_batch.dim() == 1: y_batch = y_batch.unsqueeze(1) optimizer.zero_grad() outputs = model(x_batch) # 获取标量噪声 sigma = model.observation_noise() # 计算负对数似然损失 squared_diff = (outputs - y_batch).pow(2) nll_loss = (0.5 * squared_diff / sigma.pow(2)).mean() + sigma.log() # 计算KL散度 kl_loss = model.kl_loss() # 总损失 batch_loss = nll_loss + kl_weight * kl_loss batch_loss.backward() # 梯度裁剪 torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) optimizer.step() # 更新统计 total_nll += nll_loss.item() total_kl += kl_loss.item() total_loss += batch_loss.item() num_batches += 1 # 每50个epoch打印一次 if epoch % 50 == 0: avg_nll = total_nll / num_batches avg_kl = total_kl / num_batches avg_loss = total_loss / num_batches print(f"Epoch {epoch}: Avg NLL={avg_nll:.4f}, Avg KL={avg_kl:.4f}, Avg Total={avg_loss:.4f}") # 保存模型 torch.save(model.state_dict(), 'bayesian_model.pth') print("Training completed. Model saved.") 给一个可视化最后预测结果的函数
时间: 2025-07-20 22:15:26 浏览: 4
<think>我们需要一个可视化函数,它能够展示模型的预测结果和真实值,并且如果模型支持不确定性估计(比如贝叶斯神经网络可以多次采样得到预测分布),还可以展示置信区间。
但是,在提供的代码中,我们有一个`visualize_predictions`函数,但该函数在调用时依赖于模型是否有`predict_with_uncertainty`方法。我们可以改进这个函数,使其更通用,同时也可以用于我们当前的贝叶斯模型。
由于我们的模型是贝叶斯神经网络,我们可以通过多次前向传播(采样)来估计预测的不确定性。因此,我们可以扩展可视化函数以包括不确定性。
设计思路:
1. 在评估模式下,对每个输入样本进行多次前向传播(比如100次),得到多个预测值。
2. 计算每个样本的预测均值和标准差。
3. 绘制真实值曲线、预测均值曲线,并用填充区域表示置信区间(比如95%置信区间,即均值±2倍标准差)。
但是,原代码中的`visualize_predictions`已经尝试这样做,但前提是模型有`predict_with_uncertainty`方法。我们可以为我们的贝叶斯模型添加这个方法,或者修改可视化函数使其能够处理多次采样。
这里,我们选择修改可视化函数,使其能够接受一个参数`num_samples`(用于采样的次数)来估计不确定性。如果`num_samples`大于1,则进行多次采样并计算均值和标准差。
另外,注意原可视化函数中,我们假设数据是一维的(只取第一个特征用于可视化),但实际上我们的输入可能是多维的。在可视化时,我们通常选择按样本索引排列,横轴是样本索引,纵轴是目标值。因此,我们不需要改变输入维度,只需要将每个样本的预测和真实值按顺序排列即可。
修改后的可视化函数:
函数名:`visualize_predictions`
参数:
model: 模型
data_loader: 数据加载器
title: 图表标题
num_samples_to_plot: 要绘制的样本数量(从数据集中取前多少个样本进行绘制,避免太多点)
num_predictive_samples: 用于计算不确定性的采样次数(默认为100)
步骤:
1. 将模型设置为评估模式。
2. 遍历数据加载器,收集输入、真实目标值,并进行多次预测(如果num_predictive_samples>1)。
3. 计算每个数据点的预测均值和标准差。
4. 只取前num_samples_to_plot个样本进行可视化(避免图形过于密集)。
5. 绘制真实值、预测均值,以及置信区间。
注意:由于数据加载器可能返回多个批次,我们需要合并这些批次,然后取前num_samples_to_plot个样本。
另外,我们还需要确保模型在正确的设备上。
实现如下:</think>### 预测结果可视化函数
以下是一个功能完整的预测结果可视化函数,专为贝叶斯神经网络设计,包含预测值、真实值和不确定性区间:
```python
import matplotlib.pyplot as plt
import numpy as np
import torch
def visualize_predictions(model, data_loader, title="Predictions vs Ground Truth",
num_samples=100, num_predictive_samples=100):
"""
可视化模型预测结果
参数:
model (nn.Module): 训练好的模型
data_loader (DataLoader): 数据加载器
title (str): 图表标题
num_samples (int): 要显示的样本数量
num_predictive_samples (int): 用于不确定性估计的采样次数
"""
device = next(model.parameters()).device
model.eval() # 设置为评估模式
# 收集所有预测结果
all_preds = []
all_targets = []
all_uncertainty = []
with torch.no_grad():
for x_batch, y_batch in data_loader:
x_batch = x_batch.to(device)
y_batch = y_batch.to(device)
# 确保输入维度正确
if x_batch.dim() == 1:
x_batch = x_batch.unsqueeze(0)
if y_batch.dim() == 1:
y_batch = y_batch.unsqueeze(1)
# 多次采样以估计不确定性
pred_samples = []
for _ in range(num_predictive_samples):
pred = model(x_batch)
pred_samples.append(pred.cpu().numpy())
# 计算均值和标准差
pred_samples = np.array(pred_samples)
pred_mean = np.mean(pred_samples, axis=0)
pred_std = np.std(pred_samples, axis=0)
# 收集数据
all_preds.append(pred_mean)
all_uncertainty.append(pred_std)
all_targets.append(y_batch.cpu().numpy())
# 如果样本足够,提前停止
if len(all_preds) * data_loader.batch_size >= num_samples:
break
# 合并数据
preds = np.concatenate(all_preds, axis=0)
targets = np.concatenate(all_targets, axis=0)
uncertainty = np.concatenate(all_uncertainty, axis=0)
# 只取前num_samples个样本
preds = preds[:num_samples]
targets = targets[:num_samples]
uncertainty = uncertainty[:num_samples]
# 创建图形
plt.figure(figsize=(15, 8))
# 样本索引
indices = np.arange(len(preds))
# 绘制真实值
plt.plot(indices, targets, 'b-', label='Ground Truth', alpha=0.7)
# 绘制预测值
plt.plot(indices, preds, 'r--', label='Predictions', alpha=0.8)
# 绘制95%置信区间
upper_bound = preds + 2 * uncertainty
lower_bound = preds - 2 * uncertainty
plt.fill_between(
indices,
lower_bound.squeeze(),
upper_bound.squeeze(),
color='orange',
alpha=0.3,
label='95% Confidence Interval'
)
plt.title(title)
plt.xlabel('Sample Index')
plt.ylabel('Target Value')
plt.legend()
plt.grid(True, linestyle='--', alpha=0.6)
plt.tight_layout()
# 保存图像
plt.savefig(f"{title.replace(' ', '_')}.png", dpi=300)
plt.show()
# 返回数据用于进一步分析
return preds, targets, uncertainty
# 使用示例
if __name__ == "__main__":
# 假设已经训练好模型并加载了数据
# 可视化训练集预测
print("Visualizing training predictions:")
train_preds, train_targets, train_uncertainty = visualize_predictions(
model, train_loader,
title="Training Set Predictions vs Ground Truth",
num_samples=100
)
# 可视化测试集预测
print("\nVisualizing test predictions:")
test_preds, test_targets, test_uncertainty = visualize_predictions(
model, test_loader,
title="Test Set Predictions vs Ground Truth",
num_samples=100
)
# 计算评估指标
def calculate_metrics(targets, preds, uncertainty, set_name=""):
mae = np.mean(np.abs(targets - preds))
rmse = np.sqrt(np.mean((targets - preds) ** 2))
avg_uncertainty = np.mean(uncertainty)
print(f"\n{set_name} Metrics:")
print(f" MAE: {mae:.4f}")
print(f" RMSE: {rmse:.4f}")
print(f" Average Uncertainty: {avg_uncertainty:.4f}")
return mae, rmse, avg_uncertainty
# 训练集指标
train_mae, train_rmse, train_uncert = calculate_metrics(
train_targets, train_preds, train_uncertainty, "Training Set"
)
# 测试集指标
test_mae, test_rmse, test_uncert = calculate_metrics(
test_targets, test_preds, test_uncertainty, "Test Set"
)
```
### 功能特点
1. **不确定性可视化**:
- 使用多次采样计算预测不确定性
- 显示95%置信区间(均值±2标准差)
2. **灵活的样本控制**:
- `num_samples`控制显示的样本数量
- `num_predictive_samples`控制不确定性估计的采样次数
3. **专业图表设计**:
- 清晰的图例和标签
- 半透明置信区间
- 网格线增强可读性
- 自动保存高清图像
4. **评估指标计算**:
- 计算MAE(平均绝对误差)
- 计算RMSE(均方根误差)
- 计算平均不确定性
### 使用场景示例
```python
# 训练后使用
train(model, train_loader, optimizer, epochs=100, kl_weight=0.1)
# 可视化不同数据集
visualize_predictions(
model, train_loader,
title="Training Set Predictions",
num_samples=200
)
visualize_predictions(
model, val_loader,
title="Validation Set Predictions",
num_samples=150
)
visualize_predictions(
model, test_loader,
title="Test Set Predictions with Uncertainty",
num_samples=100,
num_predictive_samples=200 # 使用更多采样以获得更准确的不确定性估计
)
```
### 高级定制选项
如果需要更高级的可视化,可以添加以下功能:
1. **残差图**:
```python
def plot_residuals(targets, preds, title="Residual Plot"):
"""绘制残差图"""
residuals = targets - preds
plt.figure(figsize=(10, 6))
plt.scatter(preds, residuals, alpha=0.5)
plt.axhline(y=0, color='r', linestyle='-')
plt.title(title)
plt.xlabel('Predicted Values')
plt.ylabel('Residuals')
plt.grid(True)
plt.show()
```
2. **不确定性分布**:
```python
def plot_uncertainty_distribution(uncertainty, title="Uncertainty Distribution"):
"""绘制不确定性分布图"""
plt.figure(figsize=(10, 6))
plt.hist(uncertainty.flatten(), bins=50, alpha=0.7)
plt.title(title)
plt.xlabel('Uncertainty (Standard Deviation)')
plt.ylabel('Frequency')
plt.grid(True)
plt.show()
```
3. **时间序列预测**(如果适用):
```python
def plot_time_series_predictions(time, targets, preds, uncertainty, title="Time Series Predictions"):
"""绘制时间序列预测图"""
plt.figure(figsize=(15, 8))
plt.plot(time, targets, 'b-', label='Ground Truth')
plt.plot(time, preds, 'r--', label='Predictions')
plt.fill_between(
time,
preds - 2 * uncertainty,
preds + 2 * uncertainty,
color='orange', alpha=0.3, label='95% CI'
)
plt.title(title)
plt.xlabel('Time')
plt.ylabel('Value')
plt.legend()
plt.grid(True)
plt.show()
```
### 最佳实践建议
1. **采样次数选择**:
- 训练集:50-100次采样足够
- 测试集:100-200次采样以获得更准确的不确定性估计
2. **结果解释**:
- 理想情况下,95%的观测值应落在置信区间内
- 如果区间太宽,模型可能过于不确定
- 如果区间太窄,模型可能低估了不确定性
3. **性能优化**:
```python
# 在GPU上加速采样
@torch.no_grad()
def batch_predict(model, x, n_samples=100):
"""批量预测加速不确定性估计"""
model.eval()
with torch.no_grad():
# 复制输入以匹配采样次数
x_tiled = x.repeat(n_samples, 1, 1).permute(1, 0, 2)
preds = model(x_tiled)
return preds.permute(1, 0, 2)
```
4. **交互式可视化**(可选):
```python
# 使用Plotly创建交互式图表
import plotly.graph_objects as go
def interactive_visualization(targets, preds, uncertainty):
"""交互式可视化"""
fig = go.Figure()
# 添加真实值
fig.add_trace(go.Scatter(
y=targets.flatten(),
mode='lines',
name='Ground Truth',
line=dict(color='blue')
))
# 添加预测值
fig.add_trace(go.Scatter(
y=preds.flatten(),
mode='lines',
name='Predictions',
line=dict(color='red', dash='dash')
))
# 添加置信区间
fig.add_trace(go.Scatter(
y=(preds + 2 * uncertainty).flatten(),
mode='lines',
line=dict(width=0),
showlegend=False
))
fig.add_trace(go.Scatter(
y=(preds - 2 * uncertainty).flatten(),
mode='lines',
fill='tonexty',
fillcolor='rgba(255,165,0,0.3)',
line=dict(width=0),
name='95% Confidence'
))
fig.update_layout(
title='Interactive Predictions Visualization',
xaxis_title='Sample Index',
yaxis_title='Target Value'
)
fig.show()
```
阅读全文