swin transformer掩码
时间: 2025-02-17 11:17:19 浏览: 93
### Swin Transformer 中的掩码机制解析
在 Swin Transformer 的架构设计中,为了有效处理局部特征并减少计算复杂度,引入了基于移位窗口的操作。每个窗口内的自注意力操作仅限于该窗口内部的 token 之间,这使得不同窗口之间的交互被阻断。
当执行移位窗口划分时,原始图像会被分割成不重叠的窗口。对于标准窗口划分,在奇数层中保持不变;而在偶数层,则通过移动窗口位置来创建新的窗口组合方式[^1]。这种交替模式确保了相邻窗口间的信息流动。
#### 掩码生成过程
由于窗口间的 token 不应相互影响,因此需要构建相应的 attention mask 来屏蔽掉不属于当前窗口范围内的元素:
1. **初始化 Mask**: 对于每一个窗口大小 \(7 \times 7\) 或其他指定尺寸,会预先定义好一个全零矩阵作为基础模板。
2. **填充 Mask 值**: 当应用移位窗口策略后,某些原本属于同一窗口但在新布局下跨越边界的区域应当被标记出来。这些超出边界的部分将在后续计算过程中被忽略,即设置对应的 mask 项为负无穷大(-inf),从而让 softmax 函数将其权重压缩至接近0。
3. **应用于 Attention 计算**: 在多头自注意模块 Multi-head Self-Attention(MSA) 内部,输入序列经过线性变换投影得到 QKV 后,attention score 将乘上之前准备好的 mask 矩阵。这样可以保证只有来自相同窗口内部节点才能互相作用,而跨窗连接则完全失效。
```python
import torch
from einops import rearrange, repeat
def create_attn_mask(window_size, shift_size):
H, W = window_size
img_mask = torch.zeros((1, H, W, 1))
h_slices = (slice(0, -shift_size),
slice(-shift_size, None))
w_slices = (slice(0, -shift_size),
slice(-shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
# Convert to mask matrix used during MHA computation.
mask_windows = rearrange(img_mask.view(1, H//window_size[0], window_size[0], W//window_size[1], window_size[1]), 'n wh ws1 ww ws2 -> n (wh ww) (ws1 ws2)')
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float("-inf")).masked_fill(attn_mask == 0, float(0.0))
return attn_mask
```
此代码片段展示了如何根据给定的 `window_size` 和 `shift_size` 创建用于遮蔽非本地化token间关联性的 attention mask 。具体来说,先构造了一个二维平面内各像素所属分区编号的地图 (`img_mask`) ,再依据其转换成为适用于批量运算的形式(`mask_windows`) 并最终形成完整的 attention mask 表达形式。
阅读全文
相关推荐


















