pycharm加载VGG
时间: 2025-01-22 10:35:25 浏览: 76
### 如何在 PyCharm 中加载 VGG 模型进行深度学习开发
#### 准备工作环境
为了确保可以在 PyCharm 中顺利运行代码,需先安装必要的库。可以通过命令行工具 pip 安装这些依赖项:
```bash
pip install torch torchvision matplotlib pandas numpy
```
#### 创建并配置 Python 虚拟环境
建议创建一个新的虚拟环境来管理项目的依赖关系。这有助于避免不同项目之间的版本冲突。
#### 编写加载 VGG 模型的脚本
下面是一个完整的例子,展示了如何在 PyCharm 中加载预训练好的 VGG16 模型,并对其进行适当调整以便用于特定的任务[^2]。
```python
import torch
from torchvision import models, transforms
from torch.nn import functional as F
from PIL import Image
import os
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
# 加载本地保存的权重文件路径
model_path = './checkpoints/vgg16-397923af.pth'
# 初始化未预先训练过的VGG16模型结构
model = models.vgg16(pretrained=False)
# 将之前下载好的参数加载到新初始化的网络中去
if os.path.exists(model_path):
model.load_state_dict(torch.load(model_path, map_location=device))
else:
raise FileNotFoundError("Model weights not found.")
# 如果需要冻结部分层,则可以这样做
for param in model.features.parameters():
param.requires_grad = False
# 对分类器的最后一层做修改以适应新的类别数量
num_features = model.classifier[6].in_features
features = list(model.classifier.children())[:-1] # 去掉最后一个线性层
features.extend([torch.nn.Linear(num_features, len(class_names))]) # 添加自定义输出大小的新线性层
model.classifier = torch.nn.Sequential(*features)
# 移动模型至GPU/CPU上执行计算
model.to(device)
def preprocess_image(image_path):
"""图像前处理"""
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
img = Image.open(image_path)
img_t = transform(img)
batch_t = torch.unsqueeze(img_t, 0)
return batch_t
test_img = "./path_to_your_test_image.jpg" # 替换成实际测试图片的位置
input_tensor = preprocess_image(test_img).to(device)
output = model(input_tensor)
probabilities = F.softmax(output, dim=1)[0]
predicted_class_idx = probabilities.argmax().item()
confidence_score = probabilities[predicted_class_idx].item()
print(f"The predicted class index is {predicted_class_idx}, with confidence score of {confidence_score:.4f}.")
```
此段代码实现了如下功能:
- 冻结特征提取部分(`model.features`)中的所有参数,防止其更新;仅允许最后几层参与反向传播优化过程
- 更改原有 `classifier` 层的设计,使其能够针对具体应用场景输出正确的类数目
- 提供了一个简单的函数来进行单张图片预测展示
阅读全文
相关推荐


















