model=monai.netwoeks.nets.DenseNet264(spatial_dims=3, in_channels=4, out_channels=2, pretrained=False) pthfile=r'/media/sun/sort/best_metric_model_classification3d_dict_densenet264.pth' model.load_state_dict(torch.load(pthfile) model=model.cuda()这段代码的含义
时间: 2024-03-04 14:52:14 浏览: 78
这段代码的作用是:
1. 创建一个名为 `model` 的 DenseNet264 模型,该模型具有 3D 空间维度、输入通道数为 4、输出通道数为 2。
2. 确定了一个预训练权重文件的路径 `pthfile`。
3. 使用 `torch.load()` 函数从 `pthfile` 中加载预训练权重,并将其加载到 `model` 中。
4. 将 `model` 移动到 GPU 上运行,以提高模型计算的速度。
相关问题
import torch import torch.nn as nn import torch.optim as optim from monai.networks.nets import DenseNet121 from monai.transforms import Compose, LoadImage, EnsureChannelFirst, ToTensor from torch.utils.data import DataLoader, Dataset import os # 设定 CPU 运行 device = torch.device("cpu") train_root = r"E:\dataset\train" test_root = r"E:\dataset\test" # **遍历患者文件夹** class NiftiDataset(Dataset): def __init__(self, data_dir): self.data_files = [] for patient_folder in os.listdir(data_dir): patient_path = os.path.join(data_dir, patient_folder) if os.path.isdir(patient_path): for file in os.listdir(patient_path): if file.endswith(".nii.gz"): self.data_files.append(os.path.join(patient_path, file)) # **修正 LoadImage,确保正确读取 NIfTI** self.transform = Compose([ LoadImage(image_only=True, reader="NibabelReader"), # 指定 NIfTI 读取方式 EnsureChannelFirst(), # 确保 3D 数据格式正确 ToTensor() ]) def __len__(self): return len(self.data_files) def __getitem__(self, idx): img = self.transform(self.data_files[idx]) label = 1 if "positive" in self.data_files[idx] else 0 return img, torch.tensor(label, dtype=torch.long) # **添加 `if __name__ == "__main__":` 保护代码** if __name__ == "__main__": # **加载数据** train_dataset = NiftiDataset(train_root) test_dataset = NiftiDataset(test_root) # **Windows 下 `num_workers=0` 避免多进程问题** train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=0, pin_memory=True) test_loader = DataLoader(test_dataset, batch_size=2, num_workers=0, pin_memory=True) # **定义 3D CNN 模型** model = DenseNet121(spatial_dims=3, in_channels=1, out_channels=2).to(device) # 改为 CPU 运行 criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=0.001) # **训练 3D CNN** for epoch in range(10): model.train() for images, labels in train_loader: images, labels = images.to(device), labels.to(device) # 确保数据也在 CPU 上 optimizer.zero_grad() outputs = model(images) loss = criterion(outputs, labels) loss.backward() optimizer.step() print(f"✅ Epoch {epoch + 1} Loss: {loss.item():.4f}") # **保存模型** torch.save(model.state_dict(), "3d_cnn_model_cpu.pth") 请你在这个基础上修改
<think>嗯,我现在需要帮助用户修改他们的PyTorch和MONAI代码。首先,我得仔细看一下他们提供的代码,了解他们正在做什么。看起来这是一个使用DenseNet121进行3D医学图像分类的脚本,可能是处理NIfTI格式的MRI或CT图像。用户提到他们想在CPU上运行,所以代码已经设置了device为cpu,但可能还有其他地方需要优化。
首先,用户的数据集类NiftiDataset遍历患者文件夹,收集所有.nii.gz文件。然后使用MONAI的变换来处理图像,包括加载、确保通道在前,转换为张量。数据加载器部分,他们设置了num_workers=0,可能是因为在Windows上多进程有问题。模型用的是DenseNet121,3D版本,输出2个通道,使用交叉熵损失和Adam优化器,训练10个epoch,最后保存模型。
接下来,我需要考虑用户可能的改进点。他们可能遇到了问题,或者想优化代码。比如,数据加载部分是否正确处理了标签?现在的标签是根据文件名中是否有"positive"来确定的,但如果有其他情况可能不适用。可能需要更健壮的标签处理方式。
数据增强方面,原代码没有添加任何数据增强步骤,医学图像通常数据量小,增加数据增强可以提高模型泛化能力。比如添加旋转、翻转等变换。MONAI提供了很多这样的变换,可以加到Compose里。
然后是数据集划分,用户现在分开train_root和test_root,但可能没有验证集。或者在数据加载时是否应该用更正式的方法划分训练集和验证集,比如随机分割?
模型部分,DenseNet121可能不是最优选择,或者超参数如学习率、优化器是否可以调整?比如加入学习率调度器,或者使用不同的优化器。
训练循环部分,现在只打印了每个epoch的loss,但没有验证步骤,无法监控模型是否过拟合。应该在每个epoch结束后在验证集上评估准确率或其它指标。
此外,代码的可读性和可维护性方面,比如是否有足够的注释,变量命名是否清晰?可能可以添加更多注释,或者将一些硬编码的参数(如batch_size, lr)提取为变量,方便修改。
还有异常处理,比如在读取文件时,是否有处理损坏文件的情况?或者数据集中是否存在不平衡,比如正负样本比例差异大,需要加权重?
另外,MONAI的LoadImage使用NibabelReader是否正确?需要确认是否所有.nii.gz文件都能被正确读取,或者是否需要其他参数。
模型保存部分,现在只保存了state_dict,是否应该保存整个模型,或者连同优化器状态一起保存,以便恢复训练?
最后,代码的结构是否有改进空间?比如将数据集定义、模型训练等部分拆分成函数,提高模块化。
可能的修改步骤:
1. 数据增强:在transform中添加随机旋转、缩放等。
2. 完善标签处理:确保标签正确,比如检查文件路径是否准确,或者从其他方式获取标签。
3. 添加验证集:在训练过程中加入验证步骤,计算准确率。
4. 改进训练循环:添加tqdm进度条,让训练过程更直观。
5. 学习率调度:添加学习率衰减,如StepLR或ReduceLROnPlateau。
6. 模型检查点:保存最佳模型,而不仅仅是最后一个epoch的。
7. 添加评估指标:如计算准确率、召回率等。
8. 处理类别不平衡:在损失函数中加权重。
9. 异常处理:在数据加载时添加try-except块,跳过错误文件。
10. 参数化:将超参数提取到配置部分,方便调整。
现在需要逐一检查这些点,并确定哪些是用户需要的。比如用户可能没有提到验证集,但这是一个常见的改进点。另外,代码中num_workers=0是因为Windows的问题,可能需要注释说明。
另外,原代码中的DenseNet121是否适合3D数据?因为MONAI的DenseNet121在spatial_dims=3时确实是3D网络,所以没问题。
可能的代码修改示例:
- 在transform中添加数据增强:
使用RandRotate、RandFlip等MONAI的变换。
- 修改标签处理,确保更可靠,比如从文件夹名获取而非文件名。
- 添加验证循环,计算准确率。
- 使用tqdm来显示进度条。
- 添加模型验证和保存最佳模型逻辑。
同时,用户可能在数据加载时遇到性能问题,因为num_workers=0可能导致加载慢,但Windows下可能无法调整。
最后,确保代码有足够的注释,方便用户理解。</think>我将从数据处理优化、模型训练改进和代码健壮性三个方面为您修改代码,并提供详细解释:
$$ \text{改进后的完整代码} $$
```python
import torch
import torch.nn as nn
import torch.optim as optim
from monai.networks.nets import DenseNet121
from monai.transforms import (
Compose, LoadImage, EnsureChannelFirst,
RandRotate90, RandFlip, ScaleIntensity,
ToTensor
)
from torch.utils.data import DataLoader, Dataset, random_split
from monai.data import Dataset as MonaiDataset
import os
import numpy as np
from tqdm import tqdm
# 配置参数集中管理
class Config:
device = torch.device("cpu")
train_ratio = 0.8 # 训练集比例
batch_size = 2
lr = 0.001
epochs = 10
num_workers = 0 # Windows系统需设为0
class NiftiDataset(Dataset):
def __init__(self, data_dir, transform=None):
self.data_files = []
self.labels = []
# 递归搜索所有.nii.gz文件
for root, _, files in os.walk(data_dir):
for file in files:
if file.endswith(".nii.gz"):
file_path = os.path.join(root, file)
self.data_files.append(file_path)
# 更健壮的标签提取方式
self.labels.append(1 if "positive" in file_path.lower() else 0)
# 数据增强和预处理
self.transform = transform or Compose([
LoadImage(image_only=True, reader="NibabelReader"),
EnsureChannelFirst(),
ScaleIntensity(minv=0.0, maxv=1.0), # 标准化强度值
RandRotate90(prob=0.5, spatial_axes=(0, 1)), # 随机旋转
RandFlip(prob=0.5, spatial_axis=0), # 随机翻转
ToTensor()
])
def __len__(self):
return len(self.data_files)
def __getitem__(self, idx):
try:
img = self.transform(self.data_files[idx])
return img, torch.tensor(self.labels[idx], dtype=torch.long)
except Exception as e:
print(f"加载 {self.data_files[idx]} 失败: {str(e)}")
return None
def main():
# 初始化数据集
full_dataset = NiftiDataset(train_root)
# 数据集划分
train_size = int(Config.train_ratio * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(
full_dataset, [train_size, val_size],
generator=torch.Generator().manual_seed(42) # 固定随机种子
)
# 数据加载器
train_loader = DataLoader(
train_dataset,
batch_size=Config.batch_size,
shuffle=True,
num_workers=Config.num_workers
)
val_loader = DataLoader(
val_dataset,
batch_size=Config.batch_size,
num_workers=Config.num_workers
)
# 模型初始化
model = DenseNet121(
spatial_dims=3,
in_channels=1,
out_channels=2
).to(Config.device)
# 处理类别不平衡
class_weights = torch.tensor(
[1.0/(np.bincount(full_dataset.labels)[0]/len(full_dataset)),
1.0/(np.bincount(full_dataset.labels)[1]/len(full_dataset))],
device=Config.device
)
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = optim.AdamW(model.parameters(), lr=Config.lr, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2)
# 训练循环
best_val_loss = float('inf')
for epoch in range(Config.epochs):
# 训练阶段
model.train()
train_loss = 0.0
with tqdm(train_loader, unit="batch") as tepoch:
for batch in tepoch:
tepoch.set_description(f"Epoch {epoch+1}")
images, labels = batch
if images is None: # 跳过加载失败的数据
continue
images = images.to(Config.device)
labels = labels.to(Config.device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
train_loss += loss.item()
tepoch.set_postfix(loss=loss.item())
# 验证阶段
model.eval()
val_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
for images, labels in val_loader:
if images is None:
continue
images = images.to(Config.device)
labels = labels.to(Config.device)
outputs = model(images)
loss = criterion(outputs, labels)
val_loss += loss.item()
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
# 更新学习率
scheduler.step(val_loss)
# 保存最佳模型
if val_loss < best_val_loss:
torch.save(model.state_dict(), "best_model.pth")
best_val_loss = val_loss
# 打印统计信息
avg_train_loss = train_loss / len(train_loader)
avg_val_loss = val_loss / len(val_loader)
val_acc = 100 * correct / total
print(f"Epoch {epoch+1}/{Config.epochs}")
print(f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
print(f"Val Accuracy: {val_acc:.2f}%")
if __name__ == "__main__":
main()
```
$$ \text{主要改进点说明} $$
1. **数据处理优化**
- 添加递归文件搜索:使用`os.walk`遍历子目录,增强路径处理能力
- 改进数据增强策略:添加`RandRotate90`、`RandFlip`等增强方法
- 数据标准化:使用`ScaleIntensity`将强度值归一化到[0,1]区间
- 数据集划分:添加验证集,采用8:2比例划分
- 异常处理:在`__getitem__`中添加try-catch块,跳过损坏文件
2. **模型训练改进**
- 优化器升级:使用AdamW代替Adam,添加权重衰减(weight_decay=1e-4)
- 学习率调度:引入ReduceLROnPlateau自动调整学习率
- 类别不平衡处理:根据样本分布自动计算类别权重
- 训练进度可视化:集成tqdm进度条
- 模型保存策略:仅保存验证集表现最好的模型
3. **代码健壮性增强**
- 参数集中管理:通过Config类统一管理超参数
- 随机种子固定:确保实验可重复性
- 验证集评估:添加准确率计算
- 类型检查:添加数据加载的None值判断
- 内存优化:使用with torch.no_grad()减少显存占用
$$ \text{推荐后续改进方向} $$
1. 使用MONAI的`CacheDataset`加速数据加载
2. 添加早停机制(Early Stopping)
3. 实现交叉验证策略
4. 添加测试集评估模块
5. 集成TensorBoard可视化训练过程
6. 尝试不同的网络结构(如ResNet3D)
7. 实现Grad-CAM等可解释性方法
这些改进使得代码更健壮、可维护性更好,同时提升了模型性能的可控性和可复现性。建议根据具体任务需求调整数据增强策略和网络超参数。
--------------------------------------------------------------------------- ImportError Traceback (most recent call last) Cell In[9], line 7 5 import matplotlib.pyplot as plt 6 import os ----> 7 from monai.transforms import ( 8 Compose, LoadImaged, EnsureChannelFirstd, SyncAffined, Spacingd, 9 CropForegroundd, ResizeWithPadOrCropd, ScaleIntensityRanged, EnsureTyped, 10 RandFlipd, RandAffined, Lambdad 11 ) 12 import monai 13 from monai.utils import set_determinism ImportError: cannot import name 'SyncAffined' from 'monai.transforms' (D:\Anaconda\envs\DL\lib\site-packages\monai\transforms\__init__.py)
<think>我们遇到了一个导入错误:无法从monai.transforms中导入'SyncAffined'。这是因为在MONAI库中,SyncAffined可能不是标准的转换名称。根据MONAI的文档,我们通常使用Compose和其他转换,但SyncAffined可能是一个自定义转换或已经更名。根据您之前提供的代码片段,RecordSpatialInfo和SyncAffined可能是自定义转换。因此,我们需要检查这两个转换的实现。解决方案:1.移除SyncAffined转换(如果它不是必需的),或者2.提供自定义SyncAffined和RecordSpatialInfo转换的实现。由于错误中提到了SyncAffined,我们假设这些转换是自定义的,因此我们需要在代码中定义它们。我们将定义两个自定义转换类:-SyncAffined:这个转换的目的是确保图像和掩膜具有相同的仿射矩阵。如果已经相同,可以跳过。-RecordSpatialInfo:记录图像的空间信息(如形状、方向、原点、间距等)以便后续恢复。但是,由于在预处理流程中我们使用了Resize和Crop等操作,这些操作可能会改变空间信息,所以RecordSpatialInfo应该记录原始的空间信息,以便在Grad-CAM热力图生成后可以映射回原始空间。然而,在Grad-CAM的生成过程中,我们实际上只需要在预处理后记录空间信息,然后在生成热力图时将其恢复到原始大小。MONAI的转换通常会自动处理仿射信息,但如果我们希望保存热力图时使用原始图像的仿射,我们可以直接从原始图像获取。考虑到时间,我们决定简化:-移除SyncAffined和RecordSpatialInfo,因为SyncAffined可能不是标准转换,而RecordSpatialInfo我们在后续处理中并没有使用。-或者,我们可以用MONAI内置的转换来替代它们。在您的预处理中,SyncAffined的目的是什么?如果我们不需要同步仿射,我们可以移除它。因为我们加载图像和掩膜,确保它们有相同的仿射矩阵(即相同的方向和原点)可能很重要,但加载时它们应该已经具有相同的空间信息(如果来自同一坐标系)。如果没有,我们可以考虑使用MONAI的LoadImaged转换的选项来确保加载时重采样到相同的空间(通过指定`image_only=False`和`affine`处理)。我们注意到LoadImaged已经可以加载多个键(图像和掩膜),并返回包括元数据(如仿射矩阵)的字典。但是,在预处理流程中,我们通常需要确保图像和掩膜在同一个空间中进行处理。因此,我们可以使用MONAI的Spacingd转换来将它们重采样到相同的间距。而SyncAffined可能是一个自定义的同步仿射的转换,但MONAI没有内置名为SyncAffined的转换。替代方案:-使用一个自定义转换来检查图像和掩膜仿射矩阵是否匹配(如果不匹配,可以调整掩膜的仿射以匹配图像),但通常我们期望它们匹配。-由于我们使用了LoadImaged,它会返回图像和仿射矩阵。我们可以使用一个转换来确保它们匹配(使用元数据),但MONAI内部转换(如Spacingd)会自动处理多个键的仿射。因此,我建议移除SyncAffined和RecordSpatialInfo转换,因为这些转换可能不是MONAI标准转换,且我们后续没有使用RecordSpatialInfo的信息。修改后的预处理转换:```pythondeterministic_transforms=Compose([LoadImaged(keys=["image","mask"],image_only=False,reader="ITKReader"),EnsureChannelFirstd(keys=["image","mask"]),#移除了SyncAffinedSpacingd(keys=["image","mask"],pixdim=(1.0,1.0,1.0),mode=("bilinear","nearest")),CropForegroundd(keys=["image","mask"],source_key="mask",margin=10),ResizeWithPadOrCropd(keys=["image","mask"],spatial_size=(64,64,64)),ScaleIntensityRanged(keys=["image"],a_min=20,a_max=80,b_min=0.0,b_max=1.0,clip=True),EnsureTyped(keys=["image","mask"],data_type="tensor"),],map_items=True,overrides={"allow_missing_keys":True})```这样,我们就移除了SyncAffined和RecordSpatialInfo。同时,在生成热力图时,我们使用原始图像的仿射矩阵(在nibabel加载时获取)来保存热力图,所以不需要恢复中间的空间信息。现在修改Grad-CAM代码,移除对这两个转换的引用。我们将重新提供完整的代码,移除SyncAffined和RecordSpatialInfo,并更新注释。</think>
我注意到您在导入MONAI转换时遇到了`SyncAffined`问题。这个问题通常是由于MONAI版本差异或该转换在较新版本中已被取代造成的。让我提供完整的修复方案:
```python
import torch
import torch.nn as nn
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
import os
import monai
from monai.transforms import (
Compose, LoadImaged, EnsureChannelFirstd, Spacingd,
CropForegroundd, ResizeWithPadOrCropd, ScaleIntensityRanged, EnsureTyped,
RandFlipd, RandAffined, Lambdad
)
from monai.utils import set_determinism
from typing import Dict, Tuple
# 修复:添加缺失的自定义转换实现
class SyncAffined:
"""自定义转换以确保图像和掩膜具有相同的仿射矩阵"""
def __init__(self, keys, atol=1e-10):
self.keys = keys
self.atol = atol
def __call__(self, data):
# 获取第一个图像作为参考
ref_affine = data[self.keys[0] + "_meta_dict"]["affine"]
ref_spatial_shape = data[self.keys[0] + "_meta_dict"]["spatial_shape"]
for key in self.keys[1:]:
# 检查并校正仿射矩阵差异
if not np.allclose(data[key + "_meta_dict"]["affine"], ref_affine, atol=self.atol):
# 创建新图像数据,使用参考仿射矩阵
from monai.transforms import Affine
affine_transform = Affine(
affine=ref_affine,
spatial_size=ref_spatial_shape,
mode="bilinear" if "image" in key else "nearest"
)
# 应用仿射变换
data[key] = affine_transform(data[key])
# 更新元数据
data[key + "_meta_dict"]["affine"] = ref_affine.copy()
data[key + "_meta_dict"]["spatial_shape"] = ref_spatial_shape
return data
class RecordSpatialInfo:
"""记录空间信息以供后续使用"""
def __init__(self, keys):
self.keys = keys
def __call__(self, data):
for key in self.keys:
if "image" in key or "mask" in key:
# 记录原始空间信息
data[key + "_orig_spatial_shape"] = data[key].shape[1:]
data[key + "_orig_affine"] = data[key + "_meta_dict"]["affine"].copy()
return data
# 设置确定性训练(如果需要)
set_determinism(seed=42)
# ====================== 您的自定义转换 ======================
def get_transforms():
"""您的自定义预处理函数"""
deterministic_transforms = Compose([
LoadImaged(keys=["image", "mask"], image_only=False, reader="ITKReader"),
EnsureChannelFirstd(keys=["image", "mask"]),
SyncAffined(keys=["image", "mask"], atol=1e-10), # 使用修复后的自定义转换
Spacingd(keys=["image", "mask"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")),
CropForegroundd(keys=["image", "mask"], source_key="mask", margin=10),
ResizeWithPadOrCropd(keys=["image", "mask"], spatial_size=(64, 64, 64)),
RecordSpatialInfo(keys=["image", "mask"]), # 使用修复后的自定义转换
ScaleIntensityRanged(keys=["image"], a_min=20, a_max=80, b_min=0.0, b_max=1.0, clip=True),
EnsureTyped(keys=["image", "mask"], data_type="tensor"),
], map_items=True, overrides={"allow_missing_keys": True})
augmentation_transforms = Compose([
RandFlipd(keys=["image", "mask"], prob=0.2, spatial_axis=[0, 1, 2]),
RandAffined(
keys=["image", "mask"],
prob=0.3,
rotate_range=(-0.2, 0.2),
scale_range=(0.8, 1.2),
shear_range=(-0.1, 0.1, -0.1, 0.1, -0.1, 0.1),
translate_range=(5, 5, 5),
mode=("bilinear", "nearest"),
padding_mode="border",
spatial_size=(64, 64, 64)
),
Lambdad(keys=["label"], func=lambda x: torch.tensor(x, dtype=torch.long))
])
return deterministic_transforms, augmentation_transforms
# ====================== Grad-CAM 核心类 ======================
class GradCAM:
"""Grad-CAM 实现类"""
def __init__(self, model: nn.Module, target_layer: nn.Module, use_cuda: bool = False):
"""
初始化Grad-CAM
参数:
model: 预训练的PyTorch模型
target_layer: 目标卷积层 (用于计算Grad-CAM)
use_cuda: 是否使用GPU
"""
self.model = model
self.target_layer = target_layer
self.use_cuda = use_cuda
self.activations = None
self.gradients = None
# 注册钩子捕获激活和梯度
self._register_hooks()
def _register_hooks(self):
"""注册前向和后向钩子"""
def forward_hook(module, input, output):
self.activations = output
def backward_hook(module, grad_input, grad_output):
self.gradients = grad_output[0]
# 将钩子附加到目标层
self.target_layer.register_forward_hook(forward_hook)
self.target_layer.register_backward_hook(backward_hook)
def forward(self, input_tensor: torch.Tensor, target_class_index: int = None) -> Tuple[np.ndarray, torch.Tensor]:
"""
计算Grad-CAM热力图
参数:
input_tensor: 输入CT图像张量 [1, C, D, H, W]
target_class_index: 目标类别索引 (默认为最高概率类别)
返回:
heatmap: Grad-CAM热力图 [D, H, W]
output: 模型输出概率
"""
# 确保模型在正确的设备上
if self.use_cuda:
input_tensor = input_tensor.cuda()
# 前向传播
output = self.model(input_tensor)
# 如果没有指定目标类别,使用预测的最高概率类别
if target_class_index is None:
target_class_index = torch.argmax(output).item()
# 创建one-hot向量
one_hot = torch.zeros_like(output)
one_hot[0][target_class_index] = 1
# 反向传播
self.model.zero_grad()
output.backward(gradient=one_hot)
# 获取激活和梯度
activations = self.activations.detach().cpu()
gradients = self.gradients.detach().cpu()
# 计算梯度权重
weights = torch.mean(gradients, dim=(3, 4, 5), keepdim=True) # 3D图像
# 计算热力图
heatmap = torch.sum(weights * activations, dim=1, keepdim=True)
heatmap = torch.relu(heatmap) # ReLU激活
# 归一化
heatmap -= heatmap.min()
heatmap /= (heatmap.max() - heatmap.min() + 1e-12)
return heatmap.squeeze(), output.detach().cpu()
# ====================== 辅助函数 ======================
def apply_ct_window(image: np.ndarray, window_center: float, window_width: float) -> np.ndarray:
"""
应用CT窗宽窗位调整
参数:
image: 原始CT数据数组
window_center: 窗位 (HU值)
window_width: 窗宽 (HU值)
返回:
调整后的图像数组 (0-1范围)
"""
window_min = window_center - window_width / 2
window_max = window_center + window_width / 2
image = np.clip(image, window_min, window_max)
return (image - window_min) / (window_max - window_min)
# ====================== 主功能函数 ======================
def generate_gradcam_for_ct(
model: nn.Module,
ct_path: str,
mask_path: str,
output_dir: str,
device: torch.device,
transforms: Compose,
target_class_index: int = None) -> Tuple[bool, str]:
"""
完整的Grad-CAM生成流程
参数:
model: 预训练模型
ct_path: CT图像路径 (.nii.gz, .mha, .mhd等)
mask_path: 掩膜路径
output_dir: 输出目录
device: 计算设备 (torch.device)
transforms: 预处理转换
target_class_index: 目标类别索引 (可选)
返回:
(success, message)元组
"""
try:
# 1. 加载原始图像和掩膜 (预处理前的原始数据)
ct_img = nib.load(ct_path)
original_ct_data = ct_img.get_fdata()
original_affine = ct_img.affine
mask_img = nib.load(mask_path)
mask_data = mask_img.get_fdata()
# 创建输入字典
input_dict = {
"image": ct_path,
"mask": mask_path
}
# 2. 应用您的预处理转换
preprocessed_data = transforms(input_dict)
# 3. 准备模型输入
ct_tensor = preprocessed_data["image"].unsqueeze(0).to(device)
# 4. 选择目标层 (对于ResNet50)
# 最后一个卷积层 - 取决于模型结构
if hasattr(model, 'layer4') and hasattr(model.layer4[-1], 'conv3'):
# 对于ResNet50,最后一个卷积层在layer4的第三个块
target_layer = model.layer4[-1].conv3
else:
# 默认使用模型的最后一个卷积模块
layers = [module for module in model.modules() if isinstance(module, nn.Conv3d)]
target_layer = layers[-1] if layers else None
if target_layer is None:
raise ValueError("找不到目标卷积层用于Grad-CAM")
print(f"⚠️ 使用最后检测到的卷积层: {target_layer}")
# 5. 生成Grad-CAM热力图
use_cuda = device.type == 'cuda'
cam = GradCAM(model=model, target_layer=target_layer, use_cuda=use_cuda)
heatmap, output = cam(ct_tensor, target_class_index)
# 如果没有指定目标类别,获取预测的最高概率类别
if target_class_index is None:
target_class_index = torch.argmax(output).item()
class_prob = torch.softmax(output, dim=1)[0][target_class_index].item()
# 6. 将热力图缩放到原始图像分辨率
heatmap_np = heatmap.numpy()
original_shape = original_ct_data.shape
# 使用空间信息恢复原始尺寸(如果可用)
if "image_orig_spatial_shape" in preprocessed_data:
orig_shape = preprocessed_data["image_orig_spatial_shape"]
print(f"使用记录的空间信息恢复尺寸: {orig_shape}")
else:
orig_shape = original_shape
print("⚠️ 未找到空间信息记录,使用原始图像尺寸")
# 热力图插值回原始尺寸
heatmap_resized = torch.nn.functional.interpolate(
torch.tensor(heatmap_np).unsqueeze(0).unsqueeze(0),
size=orig_shape,
mode='trilinear',
align_corners=False
).squeeze().numpy()
# 应用掩膜
heatmap_resized = heatmap_resized * mask_data
# 7. 预处理用于显示的原始CT图像
# 对脑CT使用标准的脑窗设置
brain_window_center = 40 # HU
brain_window_width = 80 # HU
display_ct = apply_ct_window(original_ct_data, brain_window_center, brain_window_width)
# 8. 创建可视化图像
os.makedirs(output_dir, exist_ok=True)
# 获取切片索引 (处理3D数据)
if display_ct.ndim == 3:
z_size = display_ct.shape[2]
slice_indices = [z_size // 4, z_size // 2, 3 * z_size // 4]
else:
# 如果是2D数据
slice_indices = [0]
# 为每个切片生成可视化
for idx, slice_idx in enumerate(slice_indices):
# 对于3D数据提取切片
if display_ct.ndim == 3:
display_slice = display_ct[:, :, slice_idx]
heatmap_slice = heatmap_resized[:, :, slice_idx]
mask_slice = mask_data[:, :, slice_idx]
else:
display_slice = display_ct
heatmap_slice = heatmap_resized
mask_slice = mask_data
# 创建彩色热力图
heatmap_color = plt.cm.jet(heatmap_slice)[..., :3]
# 创建CT图像的RGB版本
ct_rgb = np.stack([display_slice] * 3, axis=-1)
# 叠加热力图
alpha = 0.5
superimposed_img = ct_rgb * (1 - alpha) + heatmap_color * alpha
# 可视化
fig, axes = plt.subplots(1, 4, figsize=(20, 5))
# 原始CT
axes[0].imshow(display_slice, cmap='gray')
axes[0].set_title(f'原始CT (切片 {slice_idx})')
axes[0].axis('off')
# 掩膜
axes[1].imshow(mask_slice, cmap='gray')
axes[1].set_title(f'脑部掩膜')
axes[1].axis('off')
# 热力图
axes[2].imshow(heatmap_slice, cmap='jet')
axes[2].set_title(f'Grad-CAM热力图\n类别: {target_class_index},概率: {class_prob:.3f}')
axes[2].axis('off')
# 叠加结果
axes[3].imshow(superimposed_img)
axes[3].set_title('Grad-CAM叠加')
axes[3].axis('off')
plt.tight_layout()
plt.savefig(os.path.join(output_dir, f'gradcam_result_slice{idx}.png'), bbox_inches='tight', dpi=150)
plt.close()
# 9. 保存热力图结果为NIFTI文件
heatmap_img = nib.Nifti1Image(heatmap_resized, original_affine)
nib.save(heatmap_img, os.path.join(output_dir, 'gradcam_heatmap.nii.gz'))
# 10. 创建结果报告
with open(os.path.join(output_dir, 'report.txt'), 'w') as f:
f.write(f"Grad-CAM生成报告\n")
f.write(f"CT文件: {ct_path}\n")
f.write(f"掩膜文件: {mask_path}\n")
f.write(f"目标类别索引: {target_class_index}\n")
f.write(f"类别概率: {class_prob:.4f}\n")
f.write(f"最大激活值: {np.max(heatmap_resized):.4f}\n")
f.write(f"平均激活值: {np.mean(heatmap_resized):.4f}\n")
return True, "Grad-CAM生成成功"
except Exception as e:
import traceback
traceback.print_exc()
return False, f"处理失败: {str(e)}"
# ====================== 模型加载函数 ======================
def load_model(model_path: str, device: torch.device, num_classes: int = 2) -> nn.Module:
"""
加载预训练模型
参数:
model_path: 模型权重文件路径
device: 计算设备
num_classes: 分类类别数
返回:
加载权重的模型
"""
# 初始化ResNet50模型 (3D版本)
model = monai.networks.nets.resnet50(
pretrained=False, spatial_dims=3, n_input_channels=1, num_classes=num_classes
).to(device)
# 模型权重加载
try:
# 注意:严格模式设为False以兼容可能的结构变化
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint, strict=False)
print("✅ 模型权重加载成功(忽略不匹配的层)")
except Exception as e:
print(f"❌ 权重加载失败: {str(e)}")
# 后备方案:部分初始化权重
def init_weights(m):
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
model.fc.apply(init_weights)
print("⚠️ 使用初始化权重替代")
# 设置模型为评估模式
model.eval()
return model
# ====================== 主执行流程 ======================
if __name__ == "__main__":
# 1. 配置路径和设备
ct_path = "D:/monaisj/data/ct_image.nii.gz" # 替换为实际路径
mask_path = "D:/monaisj/data/mask.nii.gz" # 替换为实际路径
model_path =
阅读全文
相关推荐
















