vision transformer模型图
时间: 2025-01-03 17:34:30 浏览: 53
### Vision Transformer (ViT) 模型架构概述
Vision Transformer (ViT) 将图像分割成固定大小的多个patch序列,这些patch被线性嵌入并附加位置编码后输入到标准的Transformer Encoder中[^1]。整个过程保持了原始Transformer的设计理念,即通过自注意力机制来捕捉全局依赖关系。
#### 图像处理流程
- **Patch Embedding**: 输入图像首先被划分为不重叠的小块(patch),每个patch会被展平成一维向量并通过线性变换映射至指定维度的空间内。
- **Position Encoding**: 为了保留空间信息,在得到的patch embedding上加上可学习的位置编码。
```python
class PatchEmbed(nn.Module):
""" Image to Patch Embedding """
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
num_patches = (img_size // patch_size) * (img_size // patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
B, C, H, W = x.shape
x = self.proj(x).flatten(2).transpose(1, 2)
return x
```
#### 主干网络 - Transformer Encoders
核心部分由一系列堆叠的标准Transformer Encoder组成,每一层都包含了多头自注意模块(Multi-head Self-Attention Layer) 和前馈神经网络(Feed Forward Network)[^2]:
```python
import torch.nn as nn
from transformers import BertConfig, BertModel
config = BertConfig(
hidden_size=embed_dim,
num_hidden_layers=num_layers,
num_attention_heads=num_heads,
intermediate_size=intermediate_size,
)
transformer_encoder = BertModel(config)
```
#### 输出阶段
最终输出通常会经过一层全连接层转换为目标类别数目的logits值用于分类任务。对于其他下游任务,则可以根据需求调整最后一层的具体实现方式[^3]。

阅读全文
相关推荐


















