clip模型如何加入自注意力、交叉注意力、因果注意力
时间: 2025-05-20 09:08:47 浏览: 21
### CLIP中的自注意力、交叉注意力和因果注意力机制
CLIP(Contrastive Language–Image Pre-training)是一种多模态预训练模型,它通过联合优化图像和文本嵌入空间来学习跨模态表示。在CLIP中,Transformer架构被广泛用于处理序列数据,而自注意力(Self-Attention)、交叉注意力(Cross-Attention)以及因果注意力(Causal Attention)则是其核心组成部分。
#### 自注意力(Self-Attention)
自注意力机制允许模型在同一序列的不同位置之间建立依赖关系。对于CLIP而言,在编码器部分,无论是视觉特征还是文本特征,都可能采用基于Transformer的结构来进行建模。具体来说,自注意力计算如下:
给定输入序列 \(X \in R^{n \times d}\),其中 \(n\) 是序列长度,\(d\) 是维度大小,则可以通过线性变换得到查询向量 \(Q\)、键向量 \(K\) 和值向量 \(V\):
\[ Q = XW_Q, K = XW_K, V = XW_V \]
接着,利用这些矩阵计算注意力权重并加权求和获得输出:
\[ A_{self} = softmax\left(\frac{QK^T}{\sqrt{d_k}}\right)V \]
这种形式使得每个位置能够关注到整个序列的信息[^1]。
```python
import torch
from torch import nn
class SelfAttention(nn.Module):
def __init__(self, dim, heads=8, dropout=0.1):
super().__init__()
self.heads = heads
self.scale = (dim // heads) ** -0.5
self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
self.to_out = nn.Sequential(
nn.Linear(dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
dots = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
attn = dots.softmax(dim=-1)
out = torch.einsum('b i j, b j d -> b i d', attn, v)
return self.to_out(out)
```
#### 交叉注意力(Cross-Attention)
当涉及到两个不同类型的序列时,比如文本和图像之间的交互,就需要用到交叉注意力。在这种情况下,一个序列作为查询(Query),另一个序列则提供键(Key)与值(Value)。例如,在CLIP中,可以将文本看作查询源,而图像是键值对的目标域;反之亦然。
假设我们有两组不同的输入序列 \(A\) 和 \(B\),那么交叉注意力建立了它们间的联系:
\[ A_{cross}(A,B) = softmax\left(\frac{QA_B^T}{\sqrt{d_k}}\right)V_B \]
这里 \(Q_A,K_B,V_B\) 分别代表来自各自领域经过投影后的张量。
```python
class CrossAttention(SelfAttention):
def forward(self, x, context=None):
h = self.heads
q, _, _ = self.to_qkv(x).chunk(3, dim=-1)
context = default(context, x)
c = self.to_kv(context)
k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), c.chunk(2, dim=-1))
q = rearrange(q, 'b n (h d) -> b h n d', h=h)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = dots.softmax(dim=-1)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
```
#### 因果注意力(Causal Attention)
为了支持生成任务或者时间顺序敏感的应用场景,引入了因果约束条件下的注意力机制——即只允许当前时刻之前的数据参与运算。这通常应用于自然语言生成等领域。实现上只需修改掩码操作即可阻止未来信息泄露。
如果定义了一个下三角形布尔型掩码矩阵 Mask ,尺寸为 [Seq_Len × Seq_Len ] 并且满足仅当 row ≤ col 时取 True 值 。 那么就可以很容易地将其融入标准 Softmax 函数前 :
\[ Dots' = Dots + (-Inf)*(~Mask)\]
这样就实现了所谓的 Causal-Masked Attention .
```python
def causal_attention_mask(size):
mask = (torch.triu(torch.ones(size, size)) == 0).transpose(0, 1)
masked_fill_value = float('-inf')
return mask.masked_fill(mask==False,masked_fill_value)
class CausalSelfAttention(SelfAttention):
def __init__(self,*args,**kwargs):
super(CausalSelfAttention,self).__init__(*args,**kwargs)
def forward(self,x):
B,T,D=x.size()
q,k,v=self.to_qkv(x).split(D,dim=-1)
att=torch.matmul(q,k.transpose(-2,-1))/math.sqrt(k.size(-1))
# Apply the causal mask to prevent future information leakage.
mask=causal_attention_mask(T)[None,None,:,:].to(att.device)
att+=mask[:,:,:att.shape[-2],:att.shape[-1]]
att=F.softmax(att,dim=-1)
y=torch.matmul(att,v)
return self.to_out(y)
```
阅读全文
相关推荐







