vgg16代码实现图像分类pytorch
时间: 2025-07-04 17:17:32 浏览: 0
在 PyTorch 中实现基于 VGG16 的图像分类代码,可以利用 `torchvision.models` 模块中预定义的 VGG16 模型,并根据具体任务进行微调(fine-tuning)。以下是一个完整的实现流程和示例代码:
### 数据准备
使用 CIFAR-10 数据集作为输入数据源,CIFAR-10 包含 10 个类别的彩色图像,每个类别有 6000 张 32x32 的图像。首先对数据进行预处理和加载:
```python
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
# 数据预处理
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
])
# 加载 CIFAR-10 数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
```
### 模型构建与迁移学习
PyTorch 提供了预训练的 VGG16 模型,可以直接使用并修改最后的全连接层以适配当前的任务。
```python
# 加载预训练的 VGG16 模型
model = models.vgg16(pretrained=True)
# 修改最后一层全连接层以适配 CIFAR-10 的 10 个类别
model.classifier[6] = torch.nn.Linear(4096, 10)
# 使用 GPU 如果可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
```
### 训练模型
定义损失函数和优化器,然后训练模型。
```python
import torch.optim as optim
import torch.nn as nn
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# 训练模型
for epoch in range(5): # 可根据需要调整 epochs 数量
running_loss = 0.0
for inputs, labels in train_loader:
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f'Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}')
```
### 测试模型
测试模型的准确率。
```python
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in test_loader:
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f'Test Accuracy: {100 * correct / total}%')
```
### 注意事项
- **模型性能**:如果希望提升性能,可以调整超参数,如学习率、批量大小,或者增加训练的轮次。
- **扩展性**:此代码可以轻松扩展到其他数据集(如 ImageNet)或网络结构(如 ResNet、Inception 等)[^1]。
阅读全文
相关推荐

















