transformer五种花分类
时间: 2025-05-13 16:40:29 浏览: 25
### 使用 Transformer 模型实现五类花卉分类
Transformer 是一种基于自注意力机制的架构,最初设计用于自然语言处理任务。然而,在图像领域中,Vision Transformers (ViTs) 和其变体(如 Swin Transformer、CvT 等)已经证明了它们的有效性。以下是使用 PyTorch 实现五类花卉分类的具体方法。
#### 数据准备
为了训练一个能够区分五种不同类别花卉的模型,需要准备好相应的数据集。通常情况下,可以采用公开的数据集,例如 Oxford Flowers 102 或者其他包含多种花卉的小规模数据集[^4]。如果目标仅限于五个类别,则可以从原始数据集中筛选出对应的子集。
```python
import torch
from torchvision import datasets, transforms
data_transforms = {
'train': transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'val': transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
}
image_datasets = {x: datasets.ImageFolder(root=f'data/{x}', transform=data_transforms[x]) for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=32, shuffle=True, num_workers=4) for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
```
上述代码展示了如何加载并预处理图像数据,确保输入适合后续的 Transformer 架构[^2]。
#### 模型定义
接下来定义一个基于 Transformer 的模型结构。这里推荐使用 Swin Transformer,因为它具有分层的设计特点,更适合高分辨率图像的任务需求。
```python
import timm
model = timm.create_model('swin_tiny_patch4_window7_224', pretrained=True)
for param in model.parameters():
param.requires_grad = False
num_ftrs = model.head.in_features
model.head = torch.nn.Linear(num_ftrs, len(class_names))
```
通过 `timm` 库可以直接获取预先训练好的 Swin Transformer 模型,并调整最后一层全连接层以适应新的分类数目。
#### 训练过程
设置好损失函数与优化器之后即可开始训练循环:
```python
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
best_acc = 0.0
for epoch in range(num_epochs):
print(f'Epoch {epoch}/{num_epochs - 1}')
for phase in ['train', 'val']:
if phase == 'train':
model.train()
else:
model.eval()
running_loss = 0.0
running_corrects = 0
for inputs, labels in dataloaders[phase]:
inputs = inputs.to(device)
labels = labels.to(device)
with torch.set_grad_enabled(phase == 'train'):
outputs = model(inputs)
_, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels)
if phase == 'train':
optimizer.zero_grad()
loss.backward()
optimizer.step()
running_loss += loss.item() * inputs.size(0)
running_corrects += torch.sum(preds == labels.data)
if phase == 'train':
scheduler.step()
epoch_loss = running_loss / dataset_sizes[phase]
epoch_acc = running_corrects.double() / dataset_sizes[phase]
print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
if phase == 'val' and epoch_acc > best_acc:
best_acc = epoch_acc
return model
model_ft = train_model(model, criterion, optimizer, scheduler, num_epochs=25)
```
此部分实现了标准的监督学习流程,包括前向传播计算预测值、反向传播更新参数以及验证阶段评估性能指标等操作。
---
#### 结果保存与部署
完成训练后可将最终版本导出以便未来重复利用或者在线服务调用:
```python
torch.save(model.state_dict(), './flower_classifier.pth')
```
当再次运行推理时只需重新实例化相同的网络拓扑再载入权重文件即可恢复完整的功能状态.
---
阅读全文
相关推荐
















