机器学习,详解vision transformer
时间: 2025-06-30 09:16:58 浏览: 8
### ViT 的工作原理
Vision Transformer (ViT) 是一种将自然语言处理中广泛使用的 Transformer 模型应用于图像任务的架构。其核心思想是将图像分割为小块(patches),并将这些块转换为向量序列,然后输入到 Transformer 编码器中进行处理[^1]。
#### 图像分块与嵌入(Image Patching and Embedding)
在 ViT 中,图像被划分为固定大小的小块(例如 16x16 像素)。每个图像块通过一个线性变换(通常是一个卷积层)映射到一个低维向量空间中,形成所谓的“patch embeddings”。这种处理方式类似于 NLP 中的词嵌入(word embedding)过程。此外,为了保留位置信息,ViT 还引入了一个可学习的位置编码(positional encoding),并与 patch embeddings 相加后输入到 Transformer 编码器中。
```python
import torch
import torch.nn as nn
class PatchEmbed(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.n_patches = (img_size // patch_size) ** 2
self.proj = nn.Conv2d(
in_chans,
embed_dim,
kernel_size=patch_size,
stride=patch_size
)
def forward(self, x):
x = self.proj(x) # (n_samples, embed_dim, n_patches**0.5, n_patches**0.5)
x = x.flatten(2) # (n_samples, embed_dim, n_patches)
x = x.transpose(1, 2) # (n_samples, n_patches, embed_dim)
return x
```
#### Transformer 编码器(Transformer Encoder)
ViT 使用标准的 Transformer 编码器结构,由多头自注意力机制(Multi-Head Self-Attention, MHSA)和前馈网络(Feed-Forward Network, FFN)组成。MHSA 允许模型关注图像中的不同区域并捕捉长距离依赖关系。FFN 则用于对每个 token 进行非线性变换。此外,ViT 在输入序列中添加了一个特殊的分类 token([CLS] token),该 token 最终会被用于图像分类任务。
```python
class TransformerEncoder(nn.Module):
def __init__(self, embed_dim=768, depth=12, num_heads=8, mlp_ratio=4., qkv_bias=False, drop_rate=0.):
super().__init__()
self.blocks = nn.ModuleList([
Block(embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate)
for _ in range(depth)])
def forward(self, x):
for block in self.blocks:
x = block(x)
return x
class Block(nn.Module):
def __init__(self, embed_dim=768, num_heads=8, mlp_ratio=4., qkv_bias=False, drop=0.):
super().__init__()
self.norm1 = nn.LayerNorm(embed_dim)
self.attn = Attention(embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=drop, proj_drop=drop)
self.norm2 = nn.LayerNorm(embed_dim)
mlp_hidden_dim = int(embed_dim * mlp_ratio)
self.mlp = Mlp(in_features=embed_dim, hidden_features=mlp_hidden_dim, act_layer=nn.GELU, drop=drop)
def forward(self, x):
x = x + self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x
```
#### 分类头(Classification Head)
经过 Transformer 编码器处理后,[CLS] token 的输出将被送入一个全连接层(有时包含多个层),以生成最终的分类结果。这一部分通常被称为分类头(classification head)。
```python
class ClassificationHead(nn.Module):
def __init__(self, embed_dim=768, num_classes=1000):
super().__init__()
self.head = nn.Linear(embed_dim, num_classes)
def forward(self, x):
return self.head(x[:, 0])
```
### 实现细节
除了上述基本结构外,ViT 还有一些重要的实现细节需要注意:
- **预训练与微调**:由于 ViT 需要大量数据才能发挥最佳性能,因此通常会在大规模数据集(如 ImageNet)上进行预训练,然后再在特定任务上进行微调。
- **位置编码**:ViT 使用可学习的位置编码来保留图像块的空间位置信息。这种编码可以是绝对位置编码或相对位置编码。
- **混合架构**:某些变体(如 Convolutional Vision Transformer, CvT)结合了 CNN 和 Transformer 的优点,在早期阶段使用卷积层提取局部特征,再用 Transformer 处理全局信息。
###
阅读全文
相关推荐















