Swin Transformer Implementation
### How Swin Transformer Is Implemented
Swin Transformer is a vision Transformer built on a hierarchical structure and a shifted-window mechanism; the design aims to improve model performance while keeping computation efficient. The core components of the implementation, together with code examples, are described below.
#### Core Modules
1. **Window Partition and Window Reverse**
These two functions split the input feature map into non-overlapping windows and reassemble it back into its original shape after processing. They are one of the key operations behind the shifted-window mechanism in Swin Transformer[^2].
```python
import torch


def window_partition(x, window_size):
    """
    Split a feature map into non-overlapping windows.

    Args:
        x: input tensor of shape (B, H, W, C)
        window_size: window size
    Returns:
        windows: tensor of shape (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Reassemble windows back into the full feature map.

    Args:
        windows: window tensor of shape (num_windows*B, window_size, window_size, C)
        window_size: window size
        H, W: original height and width
    Returns:
        x: reassembled feature map of shape (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
```
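As a quick sanity check, the two functions are exact inverses of each other. The small example below (shapes chosen only for illustration, matching a typical Swin-T stage-1 feature map) verifies the round trip:
```python
x = torch.randn(2, 56, 56, 96)            # (B, H, W, C), e.g. stage-1 features of Swin-T
windows = window_partition(x, window_size=7)
print(windows.shape)                      # torch.Size([128, 7, 7, 96]); 2 * (56 // 7) ** 2 = 128 windows
x_back = window_reverse(windows, window_size=7, H=56, W=56)
print(torch.equal(x, x_back))             # True: partition followed by reverse is lossless
```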
2. **Shifted Window Mechanism**
The shifted-window mechanism lets information flow between neighbouring windows, strengthening local modelling. It involves cyclically shifting the window positions and generating an attention mask that blocks information flow across regions that do not belong to the same window[^1].
```python
def create_mask(shift_size, input_resolution, window_size):
    """
    Build the attention mask needed for shifted-window attention.

    Args:
        shift_size: shift offset
        input_resolution: input resolution (height, width)
        window_size: window size
    Returns:
        attn_mask: attention mask of shape (num_windows, window_size*window_size, window_size*window_size)
    """
    H, W = input_resolution
    img_mask = torch.zeros((1, H, W, 1))  # 1, H, W, 1
    h_slices = (slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None))
    w_slices = (slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None))
    # Label each of the nine shifted regions with a distinct index
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1
    mask_windows = window_partition(img_mask, window_size)  # nW, ws, ws, 1
    mask_windows = mask_windows.view(-1, window_size * window_size)
    # Positions with different region labels must not attend to each other
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
    return attn_mask
```
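For intuition, here is an illustrative call using the settings commonly quoted for Swin-T (56 × 56 feature map, window size 7, shift 3); the concrete numbers are only an example:
```python
# 56 is divisible by 7, so the mask covers an 8 x 8 grid of windows.
attn_mask = create_mask(shift_size=3, input_resolution=(56, 56), window_size=7)
print(attn_mask.shape)  # torch.Size([64, 49, 49]): one (ws*ws, ws*ws) mask per window
# Only windows in the bottom row and right column of the grid mix pixels from
# different regions after the cyclic shift; their cross-region entries are -100.0,
# which effectively zeroes those attention weights after the softmax.
print((attn_mask != 0).flatten(1).any(dim=1).sum().item())  # 15
```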
3. **Swin Transformer Block**
A single Swin Transformer block combines window-based multi-head self-attention (MSA), a feed-forward network (FFN), and residual connections, forming the basic unit for hierarchical representation learning[^3].
```python
import torch.nn as nn
import torch.nn.functional as F

# WindowAttention, DropPath and Mlp are assumed to be defined elsewhere,
# e.g. copied from the official Swin Transformer repository (DropPath and Mlp
# are also available from timm.models.layers).


class SwinTransformerBlock(nn.Module):
    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must be smaller than window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            # Build the mask for the padded resolution so it matches the windows
            # produced in forward()
            Hp = (input_resolution[0] + window_size - 1) // window_size * window_size
            Wp = (input_resolution[1] + window_size - 1) // window_size * window_size
            attn_mask = create_mask(self.shift_size, input_resolution=(Hp, Wp),
                                    window_size=self.window_size)
        else:
            attn_mask = None
        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        # x: (B, H*W, C) with (H, W) = self.input_resolution
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # Pad the feature map to a multiple of the window size
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # Cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # Partition windows and apply window attention
        x_windows = window_partition(shifted_x, self.window_size)               # nW*B, ws, ws, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, ws*ws, C
        attn_windows = self.attn(x_windows, mask=self.attn_mask)                # nW*B, ws*ws, C
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)      # B, Hp, Wp, C

        # Reverse the cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        # Remove padding and restore the (B, H*W, C) layout
        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()
        x = x.view(B, H * W, C)

        # Residual connections around attention and MLP
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
```
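The block relies on `WindowAttention`, `DropPath`, and `Mlp`, which are not shown in this post; in the official repository they sit in the same file, and `DropPath`/`Mlp` are also provided by `timm.models.layers`. The sketch below is a deliberately simplified stand-in for `WindowAttention` (it omits the relative position bias used in the paper) so the block above can be exercised end to end; it is an illustrative assumption, not the original implementation:
```python
import torch
import torch.nn as nn


class WindowAttention(nn.Module):
    """Simplified window attention: multi-head self-attention inside each window,
    with an optional shifted-window mask but without the relative position bias."""

    def __init__(self, dim, window_size, num_heads, qkv_bias=True,
                 attn_drop=0., proj_drop=0.):
        super().__init__()
        self.window_size = window_size  # unused here; the full version needs it for the relative position bias
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, mask=None):
        # x: (num_windows*B, N, C) with N = window_size * window_size
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]               # each: (B_, heads, N, head_dim)
        attn = (q * self.scale) @ k.transpose(-2, -1)  # (B_, heads, N, N)
        if mask is not None:                           # mask: (nW, N, N)
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = self.attn_drop(attn.softmax(dim=-1))
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        return self.proj_drop(self.proj(x))
```
With this stand-in plus `DropPath` and `Mlp` taken from timm (an assumption of this sketch, not part of the original post), a forward pass through the block looks like:
```python
from timm.models.layers import DropPath, Mlp  # assumption: timm is installed

block = SwinTransformerBlock(dim=96, input_resolution=(56, 56), num_heads=3,
                             window_size=7, shift_size=3)
x = torch.randn(2, 56 * 56, 96)   # (B, H*W, C)
print(block(x).shape)             # torch.Size([2, 3136, 96])
```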
#### Summary
The code above covers the core building blocks of Swin Transformer: window partition and reverse, the shifted-window mechanism, and a single Transformer block. Together, these modules form the efficient architecture that allows Swin Transformer to perform well across a wide range of computer vision tasks.