swin transformer block 计算公式
时间: 2025-01-24 15:05:50 浏览: 59
### Swin Transformer Block 的计算公式
Swin Transformer 中的每一个 block 是由多个子模块构成,这些子模块共同协作来实现高效的特征提取。具体来说,一个标准的 Swin Transformer block 主要包含两个核心部分:多头窗口自注意力(W-MSA 或 SW-MSA)和一个多层感知机(MLP)。此外,还加入了残差连接以及 LayerNorm 层以促进训练稳定性。
#### 多头窗口自注意力 (Multi-head Window Self-attention, W-MSA)
对于输入张量 \(X \in R^{H\times W \times C}\),其中 H 和 W 表示高度和宽度,C 代表通道数:
1. 首先将输入划分为不重叠的局部窗口;
2. 对于每个窗口内的 token 序列执行常规的多头自注意操作;
\[ Q = XW_Q,\ K=XW_K,\ V=XW_V \]
这里 \(Q\),\(K\) 和 \(V\) 分别表示查询、键和值矩阵,而 \(W_Q\), \(W_K\)\(W_V\) 则是可学习参数权重矩阵[^1]。
接着定义相对位置偏置项 B:
\[ A_{ij}=\frac{Q_i(K_j)^T}{\sqrt{d_k}}+B_{i-j}, i,j=1,...,n \]
最后得到加权求和后的输出 O :
\[ Z = softmax(A)V \]
为了提升表达能力并减少冗余信息的影响,在奇数层会采用移位窗口策略(SW-MSA)[^4]。
#### MLP 子层
紧接着上述自注意力机制之后的是两层全连接网络组成的前馈神经网络 FFN ,其形式如下所示:
\[ GELU(W_1Z+b_1) \]
```python
import torch.nn as nn
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.drop = nn.Dropout(drop)
...
```
整个block 结构还包括了Layer Normalization 和 Skip Connection 来增强梯度传播效果[^3]。
阅读全文
相关推荐


















