Iwin Transformer模块
论文地址: https://arxiv.org/abs/2507.18405
讲解视频: B站【大嘴带你水论文】
深度交流Q群:607618759
详细代码在文章下面~~
1、作用
Iwin Transformer是一个革命性的无位置编码层次化视觉Transformer模块,专门解决传统Swin Transformer需要两个连续块才能实现全局信息交换的限制。该模块通过创新的交错窗口注意力机制和深度可分离卷积的协同工作,在单个Transformer块内就能实现全局信息交换,计算成本仅为两个连续Swin块的一半。Iwin Transformer可以直接从低分辨率微调到高分辨率而不会出现性能下降,在ImageNet-1K上达到87.4%的top-1准确率,同时支持无缝替换生成模型中的标准注意力模块。
2、机制
1、交错窗口注意力:
Iwin的核心创新是交错窗口注意力机制,通过Reshape-Transpose-Reshape(RTR)操作系统性地重新排列特征序列为交错模式。与标准窗口注意力将特征图均匀分割为不重叠窗口不同,IWA在窗口分割前重新排列特征图,使来自不同区域的token被分组到同一窗口中进行注意力计算。这种设计确保了具有相同(i mod Hg, j mod Wg)的token被分组到同一窗口,实现了跨区域的直接信息交换。
2、深度可分离卷积协同:
深度可分离卷积与交错窗口注意力形成互补的信息处理单元。卷积建立了某些通过交错窗口注意力无法建立的token连接,同时提供隐式位置信息。这种混合设计不仅增强了特征表示,还减少了对显式位置编码的依赖,使Iwin成为无位置编码的Transformer。
3、独特优势
1、单块全局信息交换:
与Swin Transformer需要两个连续块(常规窗口+移位窗口)才能实现全局信息交换不同,Iwin Transformer在单个块内就能实现全局信息交换。这种设计消除了复杂的掩码操作,减少了计算冗余,使得相同的全局感受野只需要约一半的计算成本,特别适合高分辨率视觉应用。
2、无位置编码的可扩展性:
Iwin Transformer完全不需要显式位置编码,确保了在不同输入分辨率下的强大可扩展性,不会出现性能下降。这克服了Swin Transformer v2中需要复杂的对数空间连续位置偏置(Log-CPB)来处理高分辨率输入的限制。模型可以直接从224²预训练状态微调到384²、512²甚至1024²分辨率。
3、生成模型友好设计:
由于Iwin的统一模块可以作为独立模块无缝替换标准注意力模块,而不影响后续与文本条件的交叉注意力操作,这使得Iwin特别适合现代文本到图像扩散模型。相比之下,Swin由于其两块依赖性,在AIGC时代的文本条件注入方面存在明显劣势。
4、卓越的跨分辨率性能:
Iwin模型展现出优异的跨分辨率微调能力。例如,Iwin-S从224²微调到384²可获得0.9%的准确率提升(从83.4%到84.3%),微调到512²可获得1.0%的提升(到84.4%)。这种能力归功于交错窗口注意力和深度可分离卷积的协同作用,使模型能够享受全局视野而不依赖位置编码。
5、理论保证的全局连通性:
Iwin提供了严格的数学证明,当满足KM ≥ max(H,W)条件时,任意两个位置(i₁,j₁)和(i₂,j₂)之间都存在信息交换路径。这种理论保证确保了模型的全局建模能力,为其优异性能提供了坚实的理论基础。
4、代码
gitcode地址:https://gitcode.com/2301_80107842/research
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
def rearrange(x, Hgroup, Wgroup):
    """Interleave a (B, H, W, C) feature map along height, then width.

    This is the Reshape-Transpose-Reshape (RTR) operation: rows (and then
    columns) that are H // Hgroup (resp. W // Wgroup) apart are brought next
    to each other, so a subsequent window partition groups tokens from
    distant regions into the same window. Requires H % Hgroup == 0 and
    W % Wgroup == 0 — TODO confirm callers guarantee divisibility.
    """
    B, H, W, C = x.shape
    # Height: split H into (H // Hgroup, Hgroup), swap the two factors.
    rows = x.reshape(B, -1, Hgroup, W, C).transpose(1, 2)
    x = rows.reshape(B, -1, W, C)
    # Width: same factor swap along the width axis.
    cols = x.reshape(B, H, -1, Wgroup, C).transpose(2, 3)
    return cols.reshape(B, H, -1, C)
def restore(x, Hgroup, Wgroup):
    """Undo `rearrange`: map an interleaved (B, H, W, C) tensor back to its
    original layout.

    Applies the inverse factor swaps in reverse order (width first, then
    height), so restore(rearrange(x, a, b), a, b) == x.
    """
    B, H, W, C = x.shape
    # Width: invert the (W // Wgroup, Wgroup) factor swap.
    x = x.reshape(B, H, Wgroup, -1, C).transpose(2, 3).reshape(B, H, -1, C)
    # Height: invert the (H // Hgroup, Hgroup) factor swap.
    x = x.reshape(B, Hgroup, -1, W, C).transpose(1, 2).reshape(B, -1, W, C)
    return x
def window_partition(x, window_size):
    """Split a (B, H, W, C) feature map into non-overlapping square tiles.

    Returns a tensor of shape (num_windows * B, window_size, window_size, C);
    windows are ordered row-major within each batch element. H and W must be
    divisible by window_size.
    """
    B, H, W, C = x.shape
    nH, nW = H // window_size, W // window_size
    tiles = x.view(B, nH, window_size, nW, window_size, C)
    tiles = tiles.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiles.view(-1, window_size, window_size, C)
def window_reverse(windows, window_size, H, W):
    """Merge window tiles back into a full feature map.

    Inverse of `window_partition`: takes (num_windows * B, ws, ws, C) tiles
    and reassembles them into a (B, H, W, C) map.
    """
    # Recover the batch size from the total tile count (integer arithmetic).
    B = windows.shape[0] * window_size * window_size // (H * W)
    grid = windows.view(B, H // window_size, W // window_size,
                        window_size, window_size, -1)
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(B, H, W, -1)
class WindowAttention(nn.Module):
    """Window-based multi-head self-attention for interleaved windows.

    Operates on already-partitioned windows of N = window_size**2 tokens.
    No positional encoding of any kind is used here (Iwin relies on the
    depthwise convolution branch for implicit position information).

    Args:
        dim: total embedding dimension (must be divisible by num_heads).
        window_size: side length of each attention window.
        num_heads: number of attention heads.
        qkv_bias: whether the QKV projection has a bias term.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability on the output projection.
    """
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5  # standard 1/sqrt(d_k) attention scaling
        # Fused projection producing Q, K and V in one matmul.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply windowed multi-head self-attention.

        Args:
            x: input features of shape (num_windows * B, N, C).
        Returns:
            Output features of shape (num_windows * B, N, C).
        """
        B_, N, C = x.shape
        # Project and split into per-head Q, K, V: each (B_, num_heads, N, head_dim).
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        if hasattr(F, 'scaled_dot_product_attention'):
            # PyTorch >= 2.0 fused (Flash/mem-efficient) kernel.
            # FIX: do not pass `scale=` — that kwarg only exists in newer
            # PyTorch releases while this hasattr() guard already passes on
            # 2.0 (TypeError there). SDPA's default scale is 1/sqrt(head_dim),
            # which equals self.scale, so behavior is unchanged.
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            # Reference implementation: scaled dot-product + softmax.
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = self.softmax(attn)
            attn = self.attn_drop(attn)
            x = attn @ v
        # Merge heads back into the channel dimension and project out.
        x = x.transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class DepthwiseConv2d(nn.Module):
    """Depthwise 2D convolution (one filter per channel, groups == channels).

    Note: despite the original "separable" wording, there is no pointwise
    (1x1) stage here — this is the depthwise part only. In Iwin it supplies
    local token mixing and implicit positional cues.
    """
    def __init__(self, dim, kernel_size=3, stride=1, padding=1):
        super().__init__()
        # groups=dim restricts each output channel to its own input channel.
        self.depthwise = nn.Conv2d(dim, dim, kernel_size, stride, padding, groups=dim)
        self.dim = dim
        self.kernel_size = kernel_size

    def forward(self, x):
        # x: (B, C, H, W) -> (B, C, H', W')
        return self.depthwise(x)
class IwinTransformerBlock(nn.Module):
    """Iwin Transformer block.

    Combines interleaved window attention (cross-region token mixing within
    a single block) with an optional depthwise-convolution branch (local
    mixing and implicit positional information), followed by an MLP, each
    wrapped in a pre-norm residual connection. Assumes H and W are divisible
    by the (possibly clamped) window size — no padding is applied; TODO
    confirm callers guarantee this.
    """
    def __init__(self, dim, input_resolution, num_heads, window_size=7, mlp_ratio=4.,
                 qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution  # (H, W)
        self.num_heads = num_heads
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio
        # If the window would exceed the input, attend over the whole map.
        if min(self.input_resolution) <= self.window_size:
            self.window_size = min(self.input_resolution)
        # Pre-norm + interleaved window attention.
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=self.window_size, num_heads=num_heads,
            qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        # Depthwise convolution branch, only when there is more than one
        # window (a single window already has a global receptive field).
        if max(self.input_resolution) / self.window_size <= 1:
            self.dwconv = None
        else:
            self.dwconv = DepthwiseConv2d(dim, kernel_size=3, padding=1)
        # Stochastic depth (identity when drop_path == 0) and MLP.
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        """Forward pass.

        Args:
            x: input features of shape (B, L, C) with L == H * W.
        Returns:
            Output features of shape (B, L, C).
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "输入特征尺寸错误"
        # Keep the pre-norm input for the residual connection.
        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)
        # Depthwise convolution branch (if present). Note it consumes the
        # normalized input in parallel with attention, not its output.
        if self.dwconv is not None:
            x_conv = x.permute(0, 3, 1, 2).contiguous()  # (B, C, H, W) for Conv2d
            x_conv = self.dwconv(x_conv).permute(0, 2, 3, 1)  # back to (B, H, W, C)
        # Interleaved window attention branch.
        # Step 1: RTR rearrangement — groups tokens that are H // ws (resp.
        # W // ws) apart so one window spans distant regions.
        x = rearrange(x, H // self.window_size, W // self.window_size)
        # Step 2: split into windows.
        x_windows = window_partition(x, self.window_size)  # (nW*B, window_size, window_size, C)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # (nW*B, N, C)
        # Step 3: attention within each (interleaved) window.
        attn_windows = self.attn(x_windows)  # (nW*B, N, C)
        # Step 4: merge windows back into a feature map.
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        x = window_reverse(attn_windows, self.window_size, H, W)  # (B, H, W, C)
        # Step 5: undo the interleaving to restore spatial order.
        x = restore(x, H // self.window_size, W // self.window_size)  # (B, H, W, C)
        # Sum the attention and convolution branches.
        if self.dwconv is not None:
            x = x + x_conv
        # Back to sequence layout.
        x = x.view(B, H * W, C)
        # First residual connection (with stochastic depth).
        x = shortcut + self.drop_path(x)
        # MLP branch with the second residual connection.
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> dropout ->
    Linear -> dropout.

    hidden_features and out_features default to in_features when omitted.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # Expand, activate, regularize, then project back.
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class DropPath(nn.Module):
    """Stochastic depth: during training, zero the whole residual branch for
    a random subset of samples and rescale the kept ones by 1/keep_prob so
    the expected output is unchanged. Identity in eval mode or when
    drop_prob == 0.

    Args:
        drop_prob: probability of dropping a sample's path. FIX: the original
            default of None crashed (`1 - None`) the first time the module
            ran in training mode; None is still accepted for backward
            compatibility and treated as 0.
    """
    def __init__(self, drop_prob=0.):
        super().__init__()
        self.drop_prob = 0. if drop_prob is None else drop_prob

    def forward(self, x):
        # Fast path: nothing to drop.
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over all remaining dims.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # -> 0/1 mask with P(1) == keep_prob
        # Rescale survivors to keep the expectation equal to x.
        return x.div(keep_prob) * random_tensor
# 测试代码
# Smoke test: run one block forward and report tensor sizes / param count.
if __name__ == '__main__':
    # Input is (batch, sequence length, embed dim); a 56x56 feature map
    # flattens to 56 * 56 = 3136 tokens.
    H, W = 56, 56
    input_tensor = torch.rand(2, H * W, 96)  # (B, L, C)
    # Build a single Iwin Transformer block at Swin-T-like settings.
    model = IwinTransformerBlock(
        dim=96,
        input_resolution=(H, W),
        num_heads=3,
        window_size=7
    )
    # Forward pass; output shape must match the input shape.
    output = model(input_tensor)
    print('输入尺寸:', input_tensor.size())
    print('输出尺寸:', output.size())
    print('参数数量:', sum(p.numel() for p in model.parameters()))