Supervised Contrastive Learning代码
时间: 2025-06-23 20:25:28 浏览: 27
### Supervised Contrastive Learning 的代码实现
监督对比学习(Supervised Contrastive Learning, SCL)是一种改进特征表示的方法,它通过优化样本之间的相对距离来增强模型的学习能力。下面是一个基于PyTorch的SCL实现例子[^1]。
#### 主要组件定义
为了构建一个完整的SCL训练流程,需要定义几个核心部分:
- **数据加载器(DataLoader)**:用于准备输入数据集。
- **损失函数(Loss Function)**:计算不同类别间样本的距离差异。
- **网络结构(Network Architecture)**:设计神经网络架构以提取有用的特征向量。
- **优化器(Optimizer)**:更新参数使目标最小化。
#### 数据预处理与加载
首先设置好数据管道,这里假设已经有一个自定义的数据集类`CustomDataset`实现了必要的接口方法如`__len__()`, `__getitem__()`等。
```python
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor()
])
train_dataset = CustomDataset(root='./data', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
```
#### 定义损失函数
接着创建一个专门针对SCL定制化的损失模块——`SupConLoss`:
```python
class SupConLoss(nn.Module):
"""Supervised Contrastive Loss"""
def __init__(self, temperature=0.07, base_temperature=0.07):
super().__init__()
self.temperature = temperature
self.base_temperature = base_temperature
def forward(self, features, labels=None, mask=None):
device = (torch.device('cuda') if features.is_cuda else torch.device('cpu'))
# 计算相似度矩阵
anchor_dot_contrast = torch.div(
torch.matmul(features, features.T),
self.temperature)
logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
logits = anchor_dot_contrast - logits_max.detach()
# 构建掩码
if labels is not None:
mask = torch.eq(labels[:, None], labels[None, :]).float().to(device)
pos_mask = mask - torch.eye(mask.shape[0], dtype=torch.float).to(device)
neg_mask = 1 - mask
log_prob = logits - torch.log(torch.exp(logits).sum(dim=-1, keepdim=True))
mean_log_prob_pos = (pos_mask * log_prob).sum(-1) / pos_mask.sum(-1)
loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos
loss = loss.mean()
return loss
```
#### 创建并初始化模型
选择合适的骨干网作为基础编码器,并将其嵌入到更复杂的体系结构中去适应特定的任务需求。
```python
import timm
encoder = timm.create_model('resnet50', pretrained=True, num_classes=0)
model = nn.Sequential(encoder, nn.Linear(in_features=2048, out_features=128)) # 假设最后层有2048个单元格
criterion = SupConLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
```
#### 开始训练过程
利用上述组件完成整个迭代循环,在每轮结束时评估性能指标变化情况以便调整超参配置。
```python
for epoch in range(num_epochs):
model.train()
running_loss = []
for images, targets in train_loader:
optimizer.zero_grad()
bsz = images.size()[0]
f1, f2 = [model(img.cuda()) for img in [images, images]] # 获取两个视图下的特征表达
features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)
features = F.normalize(features, p=2, dim=2)
features = torch.reshape(features, (-1, features.shape[-1]))
loss = criterion(features, targets.repeat(2))
loss.backward()
optimizer.step()
running_loss.append(loss.item())
avg_train_loss = sum(running_loss)/len(running_loss)
print(f'Epoch [{epoch}/{num_epochs}], Train Loss: {avg_train_loss:.4f}')
```
此段代码展示了如何在一个简单的场景下应用监督对比学习算法进行图像分类任务中的特征提取工作。当然实际操作过程中还需要考虑更多细节问题比如正则项加入、早停机制设定以及多GPU支持等等。
阅读全文
相关推荐



















