infonce损失
时间: 2025-08-08 10:41:41 浏览: 3
InfoNCE(Noise Contrastive Estimation)损失是一种广泛应用于对比学习中的目标函数,用于训练模型区分正样本与负样本[^1]。它通过最大化正样本对之间的相似度同时最小化负样本对之间的相似度来实现这一目标。在自然语言处理、计算机视觉等领域中,InfoNCE损失被用来提高模型的表示能力。
### InfoNCE Loss 定义
给定一个正样本对 $(q, k^+)$ 和一组负样本 $\{k^-\}_{i=1}^{N}$, 其中 $q$ 是查询向量,$k^+$ 是对应的正键向量,而 $\{k^-\}_{i=1}^{N}$ 表示从数据集中随机选取的一组负键向量。InfoNCE损失可以形式化为:
$$
\mathcal{L}_{\text{InfoNCE}} = -\log \frac{\exp(\text{sim}(q, k^+)/\tau)}{\sum_{j=0}^{N}\exp(\text{sim}(q, k_j)/\tau)}
$$
这里,$\text{sim}(\cdot,\cdot)$ 通常指的是余弦相似度或其他形式的点积相似度;$\tau$ 是温度参数,用来控制分布的锐利程度;分母中的求和包括了一个正样本项 $k_0=k^+$ 和所有负样本项。
### 实现方式
下面是一个使用PyTorch框架实现的简单版本的InfoNCE损失函数:
```python
import torch
import torch.nn.functional as F
def info_nce_loss(anchor, positive, negatives, temperature=0.1):
"""
anchor: Tensor of shape (batch_size, dim)
positive: Tensor of shape (batch_size, dim)
negatives: Tensor of shape (num_negatives, dim)
"""
# Normalize the vectors to unit length
anchor_norm = F.normalize(anchor, p=2, dim=-1)
positive_norm = F.normalize(positive, p=2, dim=-1)
negatives_norm = F.normalize(negatives, p=2, dim=-1)
# Compute similarities
pos_sim = (anchor_norm * positive_norm).sum(-1) / temperature
neg_sims = (anchor_norm.unsqueeze(1) * negatives_norm).sum(-1) / temperature
# Concatenate positive and negative similarities
logits = torch.cat([pos_sim.unsqueeze(1), neg_sims], dim=1)
# Labels are zero indicating that first element in each row is the positive sample
labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
# Cross entropy loss
loss = F.cross_entropy(logits, labels)
return loss
```
在这个例子中,`anchor`, `positive`, 和 `negatives` 分别代表锚点、正样本以及负样本的嵌入向量。这个函数首先会对这些向量进行归一化处理,然后计算它们之间的相似度,并基于此构建出用于分类任务的logits。最后,利用交叉熵损失函数来计算最终的损失值。
请注意,在实际应用中可能需要根据具体的任务需求调整上述代码,比如如何采样负样本等。
阅读全文
相关推荐


















