关于focal loss和分类任务中的hard negative(positive)mining

本文探讨了深度学习中解决类别不平衡问题的策略,特别是针对目标检测和跟踪任务。focal loss作为一种hard sample mining技术,通过调整样本权重,解决了在交叉熵损失函数中hard samples被easy samples淹没的问题。focal loss通过引入调节因子,对困难负样本赋予更大权重,从而提高分类器的性能。此外,文章引用了多个跟踪和检测领域的研究,进一步阐述了focal loss在实际应用中的作用。

深度学习,数据是关键。

在训练一个分类器的时候,对数据的要求是class balance,即不同标签的样本量都要充足且相仿。然而,这个要求在现实应用中往往很难得到保证。

下面我以基于检测的单目标跟踪举例分析这个问题。

visual object tracking是在一段视频中跟踪一个特定目标。常见的方法有one-stage regression(比如correlation filter tracking)和two-stage classifcation tracking。这里我们只关注后者。two stages分别是首先在上一帧视频中目标的跟踪位置周围采样得到一堆target candidates(这与two-stageRCNN系列检测器的proposal生成是一样的意思);在第二个stage,就要使用训练所得的一个classfier来进行前景or背景的而分类

训练这个第二stage中的classfier就面临这class imbalance的问题(在two-stage检测器中同理。这就可以推广为,凡是在一个estimated bounding box周围随机采集正负样本bounding box时,都会有imbalance的问题。),即严格意义上的正样本只有一个,即the estimated bounding box,而负样本则可以是这一帧图像上除了正样本bbox之外的所有bbox。为了放宽这一要求,采用bbox IoU thresholding的方法来使得那些与正样本bbox overlap足够大的bbox也被认定为是正样本呢,多么无奈的妥协。但即便是这样,依旧存在严重的imbalance。

为了解决这种imbalance,在训练一个分类器(用cross-entropy loss)的时候,就需要设计一些hard (postive/negative)samples mining [1-5]

Focal Loss是一种针对类别不平衡的损失函数,可以在训练过程中减少易分类样本的权重,从而提高模型对难分类样本的关注度。以下是将SSD的损失函数改成focal loss的代码: ```python import torch import torch.nn as nn class FocalLoss(nn.Module): def __init__(self, alpha=0.25, gamma=2, logits=True, reduction='mean'): super(FocalLoss, self).__init__() self.alpha = alpha self.gamma = gamma self.logits = logits self.reduction = reduction def forward(self, inputs, targets): if self.logits: BCE_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction='none') else: BCE_loss = nn.functional.binary_cross_entropy(inputs, targets, reduction='none') pt = torch.exp(-BCE_loss) F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss if self.reduction == 'mean': return torch.mean(F_loss) elif self.reduction == 'sum': return torch.sum(F_loss) else: return F_loss class MultiBoxLoss(nn.Module): def __init__(self, num_classes, overlap_thresh, prior_for_matching, bkg_label, neg_mining, neg_pos, neg_overlap, encode_target, use_gpu=True): super(MultiBoxLoss, self).__init__() self.use_gpu = use_gpu self.num_classes = num_classes self.threshold = overlap_thresh self.background_label = bkg_label self.encode_target = encode_target self.use_prior_for_matching = prior_for_matching self.do_neg_mining = neg_mining self.negpos_ratio = neg_pos self.neg_overlap = neg_overlap self.variance = [0.1, 0.2] self.focal_loss = FocalLoss() def forward(self, predictions, targets): loc_data, conf_data, prior_data = predictions num = loc_data.size(0) num_priors = prior_data.size(0) loc_t = torch.Tensor(num, num_priors, 4) conf_t = torch.LongTensor(num, num_priors) for idx in range(num): truths = targets[idx][:, :-1].data labels = targets[idx][:, -1].data defaults = prior_data.data match(self.threshold, truths, defaults, self.variance, labels, loc_t, conf_t, idx) if self.use_gpu: loc_t = loc_t.cuda() conf_t = conf_t.cuda() pos = conf_t > 0 num_pos = pos.sum(dim=1, keepdim=True) # Localization Loss (Smooth L1) # Shape: [batch,num_priors,4] pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data) loc_p = loc_data[pos_idx].view(-1, 4) loc_t = loc_t[pos_idx].view(-1, 4) loss_l = nn.functional.smooth_l1_loss(loc_p, loc_t, reduction='sum') # Compute max conf across batch for hard negative mining batch_conf = conf_data.view(-1, self.num_classes) loss_c = self.focal_loss(batch_conf, conf_t.view(-1, 1)) # Hard Negative Mining loss_c[pos] = 0 # filter out pos boxes for now loss_c = loss_c.view(num, -1) _, loss_idx = loss_c.sort(1, descending=True) _, idx_rank = loss_idx.sort(1) num_pos = pos.long().sum(1, keepdim=True) num_neg = torch.clamp(self.negpos_ratio * num_pos, max=pos.size(1) - 1) neg = idx_rank < num_neg.expand_as(idx_rank) # Confidence Loss Including Positive and Negative Examples pos_idx = pos.unsqueeze(2).expand_as(conf_data) neg_idx = neg.unsqueeze(2).expand_as(conf_data) conf_p = conf_data[(pos_idx + neg_idx).gt(0)].view(-1, self.num_classes) targets_weighted = conf_t[(pos + neg).gt(0)] loss_c = self.focal_loss(conf_p, targets_weighted) # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N N = num_pos.sum().float() loss_l /= N loss_c /= N return loss_l, loss_c ``` 在MultiBoxLoss中,我们用focal_loss替换了原来的交叉熵损失函数。在FocalLoss中,我们计算每个样本的二元交叉熵损失,然后再乘以一个类别权重系数(1 - pt)^gamma,其中pt是预测概率的指数形式,gamma是一个可调参数,用于控制易分类样本的权重。最后,我们返回一个平均的损失值。在MultiBoxLoss中,我们计算了定位损失分类损失,并将它们相加,再除以正样本的数量求取平均值。同时,我们采用了硬负样本挖掘策略,过滤掉难以分类的样本,提高模型的准确率。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值