Traceback (most recent call last): File "D:\CNN\多标签分类\qypt.py", line 299, in <module> window = ModelComparisonApp() ^^^^^^^^^^^^^^^^^^^^ File "D:\CNN\多标签分类\qypt.py", line 104, in __init__ self.model_wrapper = ModelWrapper() ^^^^^^^^^^^^^^ File "D:\CNN\多标签分类\qypt.py", line 29, in __init__ "ResNet50": self._init_resnet50(), ^^^^^^^^^^^^^^^^^^^^^ File "D:\CNN\多标签分类\qypt.py", line 45, in _init_resnet50 model.load_state_dict(torch.load('resnet_model.pth', map_location='cpu')) File "C:\Users\w2514\AppData\Roaming\Python\Python312\site-packages\torch\nn\modules\module.py", line 2581, in load_state_dict raise RuntimeError( RuntimeError: Error(s) in loading state_dict for ResNet: Missing key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "layer1.0.conv1.weight", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.runni
时间: 2025-05-25 14:50:53 浏览: 30
### 解决方案
当遇到 `RuntimeError: Error(s) in loading state_dict for ResNet` 错误时,通常是由于以下几个原因造成的:
#### 1. **缺失键(Missing Keys)**
如果模型的架构定义与保存的状态字典不一致,则会出现缺少某些参数的情况。例如,在加载预训练权重时,可能未初始化部分层。
解决方案可以分为以下几种情况:
- 如果是自定义网络扩展了原始的 ResNet 架构,需手动初始化这些新增加的部分。
- 使用严格模式 (`strict=False`) 来忽略那些不存在的键[^1]。
```python
resnet.load_state_dict(torch.load(PHOTO_RESNET, map_location=torch.device('cpu')), strict=False)
```
#### 2. **数据并行模块的影响**
在使用 `torch.nn.DataParallel` 训练模型后,会自动给所有的张量名称前加上 `"module."` 前缀。因此,直接加载这样的状态字典到单 GPU 或 CPU 上运行的模型可能会失败。
可以通过重命名键来移除多余的前缀[^4]:
```python
checkpoint = torch.load(PHOTO_RESNET, map_location=torch.device('cpu'))
new_checkpoint = {k.replace('module.', ''): v for k, v in checkpoint.items()}
resnet.load_state_dict(new_checkpoint)
```
#### 3. **全连接层尺寸不匹配**
另一个常见问题是最后分类器层(即 `fc` 层)的输入输出维度不同步。比如原模型可能是针对 ImageNet 数据集设计的,其最后一层有 1000 类别;而当前任务只需要少量类别。
此时应重新定义该层,并确保新模型结构与旧权重兼容[^5]:
```python
class CustomResNet(nn.Module):
def __init__(self, num_classes=6):
super(CustomResNet, self).__init__()
base_model = models.resnet50(pretrained=True)
# 替换最后一层为新的分类头
in_features = base_model.fc.in_features
base_model.fc = nn.Linear(in_features, num_classes)
self.model = base_model
def forward(self, x):
return self.model(x)
custom_resnet = CustomResNet(num_classes=6).to(device)
custom_resnet.load_state_dict(torch.load(PHOTO_RESNET, map_location=device), strict=False)
```
以上方法能够有效处理大多数情况下关于 `state_dict` 加载的问题。
---
###
阅读全文
相关推荐



















