Interpreting the Supervised Contrastive Learning Loss
### Supervised Contrastive Learning Loss Function Explanation
Supervised contrastive learning improves representation learning by using labels to pull embeddings of same-class samples together while pushing embeddings of different-class samples apart. To this end, it defines a loss function directly over these pairwise relationships in the embedding space.
#### Definition and Objective
A key aspect lies in defining positive pairs as samples that share the same label and negative pairs as those that do not. The objective is to minimize distances between embeddings of positive pairs while maximizing the separation between embeddings of negative pairs[^1].
To achieve this goal, one variant employs two margin parameters, \( \lambda_1 \) and \( \lambda_2 \), and aggregates the per-negative losses for each anchor-positive pair with a maximum or minimum operation during training; a minimal sketch of this idea appears below.
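As an illustration of such a margin-based formulation (not the SupCon loss derived below), here is a minimal sketch; the Euclidean distance, the hardest-negative aggregation, and the names `margin_pos`/`margin_neg` standing in for \( \lambda_1 \)/\( \lambda_2 \) are assumptions for the example.
```python
import torch
import torch.nn.functional as F

def margin_contrastive_loss(anchor, positive, negatives,
                            margin_pos=0.2, margin_neg=1.0):
    """Hedged sketch of a margin-based contrastive loss.

    anchor, positive: (dim,) embeddings; negatives: (n_neg, dim).
    Pulls the positive to within `margin_pos` of the anchor and pushes
    every negative at least `margin_neg` away, keeping only the hardest
    (maximum-violation) negative per anchor-positive pair.
    """
    d_pos = F.pairwise_distance(anchor.unsqueeze(0), positive.unsqueeze(0))
    d_neg = torch.cdist(anchor.unsqueeze(0), negatives).squeeze(0)  # (n_neg,)
    pos_term = torch.clamp(d_pos - margin_pos, min=0).squeeze()
    neg_term = torch.clamp(margin_neg - d_neg, min=0).max()  # hardest negative
    return pos_term + neg_term
```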
#### Mathematical Formulation
Given an anchor sample \( x_i \) along with its corresponding set of positives \( P(i) \) and negatives \( N(i) \) (all other samples in the batch):
\[ L_{\text{contrast}}(i) = \frac{-1}{|P(i)|} \sum_{j \in P(i)} \log \frac{e^{s(x_i, x_j)/\tau}}{\sum_{k \in P(i) \cup N(i)} e^{s(x_i, x_k)/\tau}} \]
Where:
- \( s(\cdot,\cdot) \): Similarity metric such as cosine similarity (a dot product of L2-normalized embeddings in the implementation below).
- \( \tau \): Temperature parameter controlling the sharpness of the resulting distribution.
- \( |P(i)| \): Number of positives for anchor \( i \); averaging over them keeps the loss comparable across anchors with different numbers of positives.
- The denominator runs over every sample in the batch except the anchor itself, so the self-contrast term \( k = i \) is excluded.
This formulation ensures that, for each anchor, samples with the same label receive higher normalized scores than samples with different labels, where the normalization is a softmax over the exponentiated similarities scaled by the temperature \( \tau \)[^1].
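To make the formula concrete, here is a minimal sketch that evaluates the loss for a single anchor on a toy batch; the embeddings, labels, and sizes are assumptions for the example.
```python
import torch
import torch.nn.functional as F

# Toy batch: 4 L2-normalized embeddings, anchor is index 0.
emb = F.normalize(torch.tensor([[1.0, 0.0],
                                [0.9, 0.1],
                                [0.0, 1.0],
                                [0.1, 0.9]]), dim=1)
labels = torch.tensor([0, 0, 1, 1])
tau = 0.07
i = 0  # anchor index

sim = emb @ emb.T / tau                           # s(x_i, x_k) / tau for all pairs
pos = (labels == labels[i]).clone()
pos[i] = False                                    # P(i): same label, excluding the anchor
others = torch.ones_like(pos)
others[i] = False                                 # P(i) ∪ N(i): everyone but the anchor

log_prob = sim[i] - torch.log(torch.exp(sim[i][others]).sum())
loss_i = -log_prob[pos].mean()                    # -1/|P(i)| * sum over positives
print(loss_i)
```
The reference-style implementation below computes the same quantity in batched form for all anchors and all augmented views at once.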
```python
import torch
from torch import nn
class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf."""
    def __init__(self, temperature=0.07, contrast_mode='all', base_temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def forward(self, features, labels=None, mask=None):
        """Compute the loss for `features` of shape [batch_size, n_views, dim]."""
        device = (torch.device('cuda')
                  if features.is_cuda
                  else torch.device('cpu'))
        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            # Unsupervised (SimCLR-style) fallback: only other views of the
            # same image count as positives.
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            # mask[i, j] = 1 if samples i and j share a label (positive pair).
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        # Flatten the view dimension: [batch_size * n_views, dim].
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # Compute logits: s(x_i, x_k) / tau for all anchor-contrast pairs.
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # For numerical stability, subtract the per-row maximum before exponentiating.
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # Tile mask to cover all views.
        mask = mask.repeat(anchor_count, contrast_count)
        # Mask out self-contrast cases (the k = i term in the denominator).
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # Compute log-probabilities: log-softmax over all samples except self.
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # Mean of log-likelihood over positives: the (1/|P(i)|) * sum term.
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

        # Loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()
        return loss
```
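For context, a minimal usage sketch follows; the feature shape `[batch_size, n_views, dim]`, the L2 normalization of the projections, and the example sizes are assumptions based on the reference implementation.
```python
import torch
import torch.nn.functional as F

# Hypothetical example: 8 images, 2 augmented views each, 128-dim projections.
batch_size, n_views, dim = 8, 2, 128
features = F.normalize(torch.randn(batch_size, n_views, dim), dim=-1)
labels = torch.randint(0, 4, (batch_size,))   # 4 classes

criterion = SupConLoss(temperature=0.07)
loss = criterion(features, labels=labels)     # supervised mode
loss_unsup = criterion(features)              # SimCLR-style fallback (no labels)
print(loss.item(), loss_unsup.item())
```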