Center loss-pytorch代码详解

该博客介绍了PyTorch中实现中心损失函数的代码,用于深度人脸识别。中心损失通过最小化每个类别的特征向量到其类中心的距离,帮助学习区分性强的特征。关键步骤包括计算特征和中心之间的欧氏距离,以及应用掩码保留对应类别的距离。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

首先上github代码:https://github.com/KaiyangZhou/pytorch-center-loss

主要代码如下:

class CenterLoss(nn.Module):
    """Center loss.
    
    Reference:
    Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
    
    Args:
        num_classes (int): number of classes.
        feat_dim (int): feature dimension.
    """
    def __init__(self, num_classes=10, feat_dim=2, use_gpu=True):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.use_gpu = use_gpu

        if self.use_gpu:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
        else:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))

    def forward(self, x, labels):
        """
        Args:
            x: feature matrix with shape (batch_size, feat_dim).
            labels: ground truth labels with shape (batch_size).
        """
        batch_size = x.size(0)
        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                  torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        distmat.addmm_(1, -2, x, self.centers.t())

        classes = torch.arange(self.num_classes).long()
        if self.use_gpu: classes = classes.cuda()
        labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
        mask = labels.eq(classes.expand(batch_size, self.num_classes))

        dist = distmat * mask.float()
        loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size

        return loss

参数:

num_classes:  数据集类别数

feat_dim: 特征向量的维度

forward部分代码解析:

这里通过举例来说明代码,假设num_classes=6,即标签从0-5。

当前mini-batch中batch size=3,label=0,4,2,feat_dim=5

x: [B, feat_dim]=[3, 5]

使用S0, S4,S2来表示输入x中的三个特征,其中Si维度是5:

x=\begin{pmatrix} S_0 \\ S_4 \\ S_2 \end{pmatrix}

同样,centers:[num_classes, feat_dim]=[6,5]可表示为,其中Ci的维度是5:

centers=\begin{pmatrix} C_0 \\ C_1 \\ C_2 \\ C_3 \\ C_4 \\ C_5 \end{pmatrix}

此处代码运行得到的结果是:

distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                  torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()

该行代码首先对x逐元素平方,再求和,左后再expand到[B,num_classes]=[3,6]维度:

S_0=(S_{00}, S_{01},S_{02},S_{03},S_{04}, S_{05})

S_0^2=S_{00}^2+S_{01}^2+...+S_{05}^2

对centers的操作也同理,最后,得到的distmat为:

distmat=\begin{pmatrix} C_0^2+S_0^2& C_1^2+S_0^2& C_2^2+S_0^2& C_3^2+S_0^2& C_4^2+S_0^2& C_5^2+S_0^2\\ C_0^2+S_4^2& C_1^2+S_4^2& C_2^2+S_4^2& C_3^2+S_4^2& C_4^2+S_4^2& C_5^2+S_4^2\\ C_0^2+S_2^2& C_1^2+S_2^2& C_2^2+S_2^2& C_3^2+S_2^2& C_4^2+S_2^2& C_5^2+S_2^2\\ \end{pmatrix}

distmat.addmm_(1, -2, x, self.centers.t())

上述代码运算为:distmat=1\cdot distmat-2\cdot x\cdot centers^T,即:

distmat=\begin{pmatrix} (C_0-S_0)^2& (C_1-S_0)^2& (C_2-S_0)^2&(C_3-S_0)^2& (C_4-S_0)^2& (C_5-S_0)^2\\ (C_0-S_4)^2& (C_1-S_4)^2& (C_2-S_4)^2&(C_3-S_4)^2& (C_4-S_4)^2& (C_5-S_4)^2\\ (C_0-S_2)^2& (C_1-S_2)^2& (C_2-S_2)^2&(C_3-S_2)^2& (C_4-S_2)^2& (C_5-S_2)^2 \end{pmatrix}

classes = torch.arange(self.num_classes).long()
labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
mask = labels.eq(classes.expand(batch_size, self.num_classes))

该段代码意义如下:

dist = distmat * mask.float()
loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size

 dist*mask恰好保留了输入x三个特征与对应类中心的距离平方,最后的距离之和为:

dist=\frac{1}{3}[(C_0-S_0)^2+(C_4-S_4)^2+(C_2-S_2)^2]

最后,clamp使其在[1e-12,1e+12]的范围内

CenterLoss是一种用于增强深度学习模型分类性能的损失函数,它通过学习每个类别的中心点来减小样本与其对应类别中心之间的距离。在PyTorch中,有不同的实现方式。根据提供的引用,可以看到有两种不同的实现方式。 引用展示了一种简单直观的实现方式,定义了一个名为CenterLoss2的类,其中包含了一个num_class x num_feature维的参数centers,代表每个类别的中心点。在forward方法中,通过计算样本与其对应类别中心之间的距离dist,然后对dist进行限制和平均操作,得到最终的损失值loss。 引用展示了另一种实现方式,其中通过使用一个mask来选择与每个样本对应的中心点,然后计算样本与其对应类别中心之间的距离dist,并对dist进行平均操作,得到最终的损失值loss。 此外,引用展示了计算样本与每个中心点之间距离的方法,通过对样本和中心点进行平方和操作,得到一个距离矩阵distmat。 因此,根据不同的实现方式和计算方法,可以使用CenterLoss来优化深度学习模型的分类性能。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* [center loss pytorch实现总结](https://blog.csdn.net/qq_45759229/article/details/126917939)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v92^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"] - *2* *3* [【LossCenter loss代码详解pytorch)](https://blog.csdn.net/m0_51358406/article/details/122312950)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v92^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]
评论 15
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值