MMDet逐行代码解读之正负样本采样Sampler

前言

  本篇是MMdet逐行解读第四篇,代码地址:mmdet/core/bbox/samplers/random_sampler.py。随机采样正负样本主要针对在训练过程中,经过MAXIOUAssigner后,确定出每个anchor和哪个gt匹配后,从这些正负样本中采样来进行loss计算。本文以RPN的config进行讲解,因为该部分用到了随机采样来克服正负样本不平衡;而在RetinaNet中则使用focal loss来克服正负样本不平衡问题,即没有随机采样的过程。
历史文章如下:
AnchorGenerator解读
MaxIOUAssigner解读
DeltaXYWHBBoxCoder解读

1、构造一个简单的sampler

    from mmdet.core.bbox import build_sampler
    # 构造一个sampler
    sampler = dict(
        type='RandomSampler',# 构造一个随机采样器
        num=256,             # 正负样本总数量
        pos_fraction=0.5,    # 正样本比例
        neg_pos_ub=-1,       # 负样本上限
        add_gt_as_proposals=False) # 是否添加gt作为正样本,默认不添加。
    sp = build_sampler(sampler)       
    # 这块不必详细理解,知道意思即可。
    # 就是随机生成一个assigner、bboxes和gt_bboxes,
    from mmdet.core.bbox import AssignResult
    from mmdet.core.bbox.demodata import ensure_rng, random_boxes
    rng = ensure_rng(None)
    assign_result = AssignResult.random(rng=rng)
    bboxes = random_boxes(assign_result.num_preds, rng=rng)
    gt_bboxes = random_boxes(assign_result.num_gts, rng=rng)
    gt_labels = None
    # 调用sample方法进行正负样本采样
    self = sp.sample(assign_result, bboxes, gt_bboxes, gt_labels)

2、BaseSampler类

class BaseSampler(metaclass=ABCMeta):
    """Base class of samplers"""

    def __init__(self,
                 num,
                 pos_fraction,
                 neg_pos_ub=-1,
                 add_gt_as_proposals=True,
                 **kwargs):
        self.num = num
        self.pos_fraction = pos_fraction
        self.neg_pos_ub = neg_pos_ub
        self.add_gt_as_proposals = add_gt_as_proposals
        self.pos_sampler = self
        self.neg_sampler = self

    @abstractmethod
    def _sample_pos(self, assign_result, num_expected, **kwargs):
        """Sample positive samples"""
        pass

    @abstractmethod
    def _sample_neg(self, assign_result, num_expected, **kwargs):
        """Sample negative samples"""
        pass

    def sample(self,
               assign_result,
               bboxes,
               gt_bboxes,
               gt_labels=None,
               **kwargs):
        pass

  基类比较容易理解,核心是sample方法,内部调用_sample_pos方法和_sample_neg方法。后续继承该类的子类只需实现_sample_pos方法和_sample_neg方法即可。

3、RandomSampler类

3.1 sample方法

 以RandomSampler类来讲解代码。首先看下sample方法:

 # 确定正样本个数: 256*0.5 = 128
 num_expected_pos = int(self.num * self.pos_fraction)
 # 调用_sample_pos方法返回采样后正样本的id。
 pos_inds = self.pos_sampler._sample_pos(
     assign_result, num_expected_pos, bboxes=bboxes, **kwargs)
 pos_inds = pos_inds.unique()    # 挑选出tensor独立不重复元素
 num_sampled_pos = pos_inds.numel()  # 确定出正样本个数
 num_expected_neg = self.num - num_sampled_pos #确定负样本个数
 # 由于该参数为-1,故不执行if语句,即实打实的采样254个负样本
 if self.neg_pos_ub >= 0: 
     _pos = max(1, num_sampled_pos)
     #确定负样本的上限是正样本个数的neg_pos_ub倍
     neg_upper_bound = int(self.neg_pos_ub * _pos) 
      # 负样本的个数不能超过上限      
     if num_expected_neg > neg_upper_bound:             
         num_expected_neg = neg_upper_bound              
 # 调用_sample_neg方法返回采样后负样本的id
 neg_inds = self.neg_sampler._sample_neg(
     assign_result, num_expected_neg, bboxes=bboxes, **kwargs)
 neg_inds = neg_inds.unique() # 同理,将id取集合操作。
 # 用SamplingResult进行封装
 sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,  
                                  assign_result, gt_flags)

  代码还是比较容易理解:首先确定正样本个数,然后完成采样;采样后确定负样本的个数,若指定了负样本上限:neg_upper_bound,则负样本个数最多采样不能超过正样本个数的neg_upper_bound倍;若无指定,则负样本个数就是总的数量-正样本个数。

3.2 _sample_pos方法

  再来看下采样正样本的方法:

def _sample_pos(self, assign_result, num_expected, **kwargs):
    """Randomly sample some positive samples."""
    # 找出非0的正样本的id
    pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False) 
    if pos_inds.numel() != 0:   
        pos_inds = pos_inds.squeeze(1)
    # 若正样本个数<期望的128个,则直接返回。
    if pos_inds.numel() <= num_expected: 
        return pos_inds
    # 否则就从pos_inds里面随机采够128个。
    else:
        return self.random_choice(pos_inds, num_expected)

3.2 _sample_neg方法

  和采样正样本方法大同小异,这里看下即可。

def _sample_neg(self, assign_result, num_expected, **kwargs):
    """Randomly sample some negative samples."""
    neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
    if neg_inds.numel() != 0:               
        neg_inds = neg_inds.squeeze(1)
    if len(neg_inds) <= num_expected: # 若负样本数量比期望的还小则直接返回
        return neg_inds
    else:
        return self.random_choice(neg_inds, num_expected)

总结

  下篇将开启model模块介绍,敬请期待。

可以使用`torch.utils.data.DataLoader`中的`WeightedRandomSampler`类对数据进行重采样。该类可以根据每个样本的权重进行采样,从而实现重采样的目的。具体步骤如下: 1.首先,需要计算每个样本的权重。可以根据样本的类别数量来计算每个样本的权重,使得每个类别的样本被采样的概率相等。例如,对于二分类问题,可以将正负样本的权重分别设置为1和2,这样就可以保证正负样本采样的概率相等。 2.然后,可以使用`WeightedRandomSampler`类对数据进行重采样。该类需要传入一个权重列表,用于指定每个样本的权重。可以将该类作为`DataLoader`的参数之一,从而实现对数据的重采样。 下面是一个示例代码: ```python import torch from torch.utils.data import DataLoader, WeightedRandomSampler # 假设有一个数据集 dataset,其中包含 n 个样本,每个样本的标签为 label # 首先,计算每个样本的权重 class_count = [0, 0] # 假设有两个类别,分别为 0 和 1 for _, label in dataset: class_count[label] += 1 weights = [1.0 / class_count[label] for _, label in dataset] # 然后,使用 WeightedRandomSampler 对数据进行重采样 sampler = WeightedRandomSampler(weights, num_samples=len(dataset), replacement=True) dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler) ``` 在上面的代码中,`class_count`用于统计每个类别的样本数量,`weights`用于计算每个样本的权重。`WeightedRandomSampler`的第一个参数是权重列表,第二个参数是采样的样本数量,第三个参数是是否使用重复采样。最后,将`WeightedRandomSampler`作为`DataLoader`的`sampler`参数传入即可。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值