【机器学习】 详解Vision Transformer
时间: 2025-05-12 12:40:36 浏览: 17
### Vision Transformer 的模型架构与原理
Vision Transformer (ViT) 是一种基于 Transformer 架构的深度学习方法,用于处理图像分类和其他计算机视觉任务。它通过将图像分割成固定大小的小块(patches),并将其映射到一维向量序列来模拟自然语言处理中的词嵌入过程[^2]。
#### 图像预处理
在 ViT 中,输入图像是被划分为多个不重叠的 patches,这些 patches 被视为 token 并进一步线性投影为高维空间中的表示向量。为了保持位置信息,在此过程中会加入可训练的位置编码[^3]。
```python
import torch
from einops import rearrange, repeat
def patch_and_flatten(image_tensor, patch_size=16):
"""
将图片切分成patch,并展平每个patch。
参数:
image_tensor: 输入图像张量,形状为 [batch_size, channels, height, width]
patch_size: 单个patch的高度和宽度
返回:
展开后的patch张量,形状为 [batch_size * num_patches, patch_dim]
"""
batch_size, _, img_height, img_width = image_tensor.shape
assert img_height % patch_size == 0 and img_width % patch_size == 0, \
f"Image dimensions must be divisible by the patch size {patch_size}"
# 切分图像为 patches
patches = rearrange(
image_tensor,
'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
p1=patch_size,
p2=patch_size
)
return patches
```
上述代码展示了如何利用 `einops` 库高效地完成图像切割操作。
#### Transformer Encoder 结构
Transformer 编码器由多层组成,每一层都包含了自注意力机制(Self-Attention Mechanism)以及前馈神经网络(Feed Forward Network)。这种设计允许模型捕捉全局依赖关系而不仅仅局限于局部特征提取[^1]。
#### 自注意力机制
自注意力模块计算的是不同 tokens 之间的相互作用强度,从而动态调整权重分配给各个部分的重要性程度。其核心公式如下所示:
\[
\text{Attention}(Q,K,V)= \text{softmax}(\frac{QK^T}{\sqrt{d_k}}) V
\]
其中 \( Q \),\( K \), 和 \( V \) 分别代表查询矩阵(Query Matrix),键值(Key Matrix),和值(Value Matrix)。
#### 训练策略
由于传统卷积神经网络(Convolutional Neural Networks,CNNs)擅长从小规模数据集中泛化良好特性,因此当直接应用标准 Transformers 处理大规模未标注数据时可能会遇到困难。为此研究者提出了多种改进方案比如引入更强正则项或者采用混合模式(hybrid models)等方式提升性能表现。
#### 实现细节
以下是使用 PyTorch 实现的一个简单版本 Vision Transformer:
```python
class MultiHeadedSelfAttention(torch.nn.Module):
def __init__(self,dim,num_heads=8):
super().__init__()
self.num_attention_heads=num_heads
head_dim=int(dim/num_heads)
self.scale=head_dim**(-0.5)
self.qkv=torch.nn.Linear(dim,dim*3,bias=False)
self.proj_out=torch.nn.Linear(dim,dim)
def forward(self,x):
q,k,v=self.qkv(x).chunk(3,-1)
q=q.reshape(q.size(0),q.size(1),self.num_attention_heads,-1).permute(0,2,1,3)
k=k.reshape(k.size(0),k.size(1),self.num_attention_heads,-1).permute(0,2,1,3)
v=v.reshape(v.size(0),v.size(1),self.num_attention_heads,-1).permute(0,2,1,3)
energy=([email protected](-2,-1))*self.scale
attention_distribution=energy.softmax(dim=-1)
out=(attention_distribution @ v).transpose(2,1).reshape_as(x)
return self.proj_out(out)
class TransformerBlock(torch.nn.Module):
def __init__(self,dim,mlp_ratio=4.,dropout_rate=0.,num_heads=8):
super().__init__()
hidden_features=int(mlp_ratio*dim)
self.norm1=torch.nn.LayerNorm(dim)
self.attn=MultiHeadedSelfAttention(dim=dim,num_heads=num_heads)
self.drop_path=torch.nn.Dropout(dropout_rate)if dropout_rate>0 else torch.nn.Identity()
self.norm2=torch.nn.LayerNorm(dim)
self.mlp=torch.nn.Sequential(
torch.nn.Linear(dim,hidden_features),
torch.nn.GELU(),
torch.nn.Dropout(dropout_rate),
torch.nn.Linear(hidden_features,dim),
torch.nn.Dropout(dropout_rate))
def forward(self,x):
x=x+self.drop_path(self.attn(self.norm1(x)))
x=x+self.drop_path(self.mlp(self.norm2(x)))
return x
class VisionTransformer(torch.nn.Module):
"""Basic Vision Transformer architecture."""
def __init__(
self,image_size=224,patch_size=16,in_chans=3,n_classes=1000,
embed_dim=768,depth=12,num_heads=12,mlp_ratio=4.,
dropout_rate=0.
):
super(VisionTransformer,self).__init__()
n_patches=(image_size//patch_size)**2
self.patch_embedding=torch.nn.Conv2d(in_channels=in_chans,out_channels=embed_dim,kernel_size=patch_size,stride=patch_size)
self.position_embeddings=torch.nn.Parameter(torch.zeros(1,n_patches+1,embed_dim))
self.cls_token=torch.nn.Parameter(torch.zeros(1,1,embed_dim))
self.dropout=torch.nn.Dropout(p=dropout_rate)
dpr=[x.item()forxin torch.linspace(0,dropout_rate,depth)]
self.blocks=torch.nn.ModuleList([
TransformerBlock(embed_dim=embed_dim,mlp_ratio=mlp_ratio,dropout_rate=dpr[i],num_heads=num_heads)
for iin range(depth)])
self.norm=torch.nn.LayerNorm(embed_dim)
self.head=torch.nn.Linear(embed_dim,n_classes)
def forward(self,x):
B=C,H,W=x.shape[0],x.shape[-2],x.shape[-1]
x=self.patch_embedding(x).flatten(start_dim=2,end_dim=3).transpose(1,2)#B,N,D
cls_tokens=self.cls_token.expand(B,-1,-1)
x=torch.cat((cls_tokens,x),dim=1)
x+=self.position_embeddings[:,:x.size(1),:]
x=self.dropout(x)
for blk in self.blocks:x=blk(x)
x=self.norm(x)
return self.head(x[:,0])
```
阅读全文
相关推荐















