VIT
时间: 2025-05-29 22:04:19 浏览: 17
### VIT 模型概述
Vision Transformer (ViT) 是一种基于Transformer架构的视觉模型,最初由 Google 提出。它的核心思想是将图像分割为一系列固定的 patches,并将其视为 token 序列输入到 Transformer 中处理[^3]。
#### 输入处理
在 ViT 的设计中,输入图像是被划分为多个固定大小的小块(patches),这些小块会被展平并映射为一维向量,随后通过线性投影层转换为 patch embeddings[^3]。为了保留空间信息,在此过程中还会加入位置编码(Positional Encoding)。此外,还引入了一个特殊的类别标记(Class Token),用于最终的分类任务。
#### 架构组成
ViT 主要由以下几个部分构成:
1. **Patch Embedding**: 图像被切分成若干个不重叠的小块,每个小块经过线性变换成为高维度向量。
2. **Positional Encoding**: 添加的位置编码有助于模型理解各个 patch 在原始图像中的相对关系。
3. **Transformer Encoder Layers**: 多层堆叠的标准 Transformer 编码器结构负责捕捉全局依赖性和提取高级特征[^4]。
4. **MLP Head**: 经过所有 transformer 层之后,取出 class token 对应的部分送入全连接层完成最后的任务预测,比如分类[^4]。
#### 微调与应用
当使用预训练好的 ViT 模型执行特定领域内的任务时,通常会采用微调策略。这包括保持大部分原有参数不变而仅调整新增加的最后一层或多层权重以适配目标任务需求[^1]。例如,在进行图像分类任务时可以替换掉原有的头部网络换成适合自己数据集数量级的新一层,并重新训练这部分参数[^2]。
以下是实现一个简单的基于 PyTorch 的 CIFAR-10 数据集上的 ViT 分类的例子:
```python
import torch
from transformers import ViTForImageClassification
import os
import torch.nn as nn
model = ViTForImageClassification.from_pretrained('/path/to/pretrained/vit/')
num_classes = 10 # For CIFAR-10 dataset
model.classifier = nn.Linear(model.config.hidden_size, num_classes)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
save_dir = './saved_models/'
if not os.path.exists(save_dir):
os.mkdir(save_dir)
```
阅读全文
相关推荐

















