Walking Through the DETR3D Code with pdb (Part 2)

Walking Through the DETR3D Code with pdb (Part 1) - CSDN Blog

Continuing from the previous post.

---------

Second debugging session:

----------

After the first round, let's adjust the breakpoint locations: this time we can put the breakpoints directly inside the DETR3D code. We set two of them, one in the DETR3D class's __init__ and one in its predict (inference) method.

from typing import Dict, List, Optional
 
import torch
from torch import Tensor
 
from mmdet3d.models.detectors.mvx_two_stage import MVXTwoStageDetector
from mmdet3d.registry import MODELS
from mmdet3d.structures import Det3DDataSample
from mmdet3d.structures.bbox_3d.utils import get_lidar2img
from .grid_mask import GridMask
 
 
@MODELS.register_module()
class DETR3D(MVXTwoStageDetector):
    """DETR3D: 3D Object Detection from Multi-view Images via 3D-to-2D Queries
    Args:
        data_preprocessor (dict or ConfigDict, optional): The pre-process
            config of :class:`Det3DDataPreprocessor`. Defaults to None.
        use_grid_mask (bool) : Data augmentation. Whether to mask out some
            grids during extract_img_feat. Defaults to False.
        img_backbone (dict, optional): Backbone of extracting
            images feature. Defaults to None.
        img_neck (dict, optional): Neck of extracting
            image features. Defaults to None.
        pts_bbox_head (dict, optional): Bboxes head of
            detr3d. Defaults to None.
        train_cfg (dict, optional): Train config of model.
            Defaults to None.
        test_cfg (dict, optional): Test config of model.
            Defaults to None.
        init_cfg (dict, optional): Initialize config of
            model. Defaults to None.
    """
 
    def __init__(self,
                 data_preprocessor=None,
                 use_grid_mask=False,
                 img_backbone=None,
                 img_neck=None,
                 pts_bbox_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        breakpoint()
        super(DETR3D, self).__init__(
            img_backbone=img_backbone,
            img_neck=img_neck,
            pts_bbox_head=pts_bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            data_preprocessor=data_preprocessor)
        self.grid_mask = GridMask(
            True, True, rotate=1, offset=False, ratio=0.5, mode=1, prob=0.7)
        self.use_grid_mask = use_grid_mask
 
    def extract_img_feat(self, img: Tensor,
                         batch_input_metas: List[dict]) -> List[Tensor]:
        """Extract features from images.
        Args:
            img (tensor): Batched multi-view image tensor with
                shape (B, N, C, H, W).
            batch_input_metas (list[dict]): Meta information of multiple inputs
                in a batch.
        Returns:
             list[tensor]: multi-level image features.
        """
 
        B = img.size(0)
        if img is not None:
            input_shape = img.shape[-2:]  # bs nchw
            # update real input shape of each single img
            for img_meta in batch_input_metas:
                img_meta.update(input_shape=input_shape)
 
            if img.dim() == 5 and img.size(0) == 1:
                img.squeeze_()
            elif img.dim() == 5 and img.size(0) > 1:
                B, N, C, H, W = img.size()
                img = img.view(B * N, C, H, W)
            if self.use_grid_mask:
                img = self.grid_mask(img)  # mask out some grids
            img_feats = self.img_backbone(img)
            if isinstance(img_feats, dict):
                img_feats = list(img_feats.values())
        else:
            return None
        if self.with_img_neck:
            img_feats = self.img_neck(img_feats)
 
        img_feats_reshaped = []
        for img_feat in img_feats:
            BN, C, H, W = img_feat.size()
            img_feats_reshaped.append(img_feat.view(B, int(BN / B), C, H, W))
        return img_feats_reshaped
 
    def extract_feat(self, batch_inputs_dict: Dict,
                     batch_input_metas: List[dict]) -> List[Tensor]:
        """Extract features from images.
        Refer to self.extract_img_feat()
        """
        imgs = batch_inputs_dict.get('imgs', None)
        img_feats = self.extract_img_feat(imgs, batch_input_metas)
        return img_feats
 
    def _forward(self):
        raise NotImplementedError('tensor mode is yet to add')
 
    # original forward_train
    def loss(self, batch_inputs_dict: Dict[List, Tensor],
             batch_data_samples: List[Det3DDataSample],
             **kwargs) -> List[Det3DDataSample]:
        """
        Args:
            batch_inputs_dict (dict): The model input dict which include
                `imgs` keys.
                - imgs (torch.Tensor): Tensor of batched multi-view  images.
                    It has shape (B, N, C, H ,W)
            batch_data_samples (List[obj:`Det3DDataSample`]): The Data Samples
                It usually includes information such as `gt_instance_3d`.
        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        batch_input_metas = [item.metainfo for item in batch_data_samples]
        batch_input_metas = self.add_lidar2img(batch_input_metas)
        img_feats = self.extract_feat(batch_inputs_dict, batch_input_metas)
        outs = self.pts_bbox_head(img_feats, batch_input_metas, **kwargs)
 
        batch_gt_instances_3d = [
            item.gt_instances_3d for item in batch_data_samples
        ]
        loss_inputs = [batch_gt_instances_3d, outs]
        losses_pts = self.pts_bbox_head.loss_by_feat(*loss_inputs)
 
        return losses_pts
 
    # original simple_test
    def predict(self, batch_inputs_dict: Dict[str, Optional[Tensor]],
                batch_data_samples: List[Det3DDataSample],
                **kwargs) -> List[Det3DDataSample]:
        """Forward of testing.
        Args:
            batch_inputs_dict (dict): The model input dict which include
                `imgs` keys.
                - imgs (torch.Tensor): Tensor of batched multi-view images.
                    It has shape (B, N, C, H ,W)
            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance_3d`.
        Returns:
            list[:obj:`Det3DDataSample`]: Detection results of the
            input sample. Each Det3DDataSample usually contain
            'pred_instances_3d'. And the ``pred_instances_3d`` usually
            contains following keys.
            - scores_3d (Tensor): Classification scores, has a shape
                (num_instances, )
            - labels_3d (Tensor): Labels of bboxes, has a shape
                (num_instances, ).
            - bbox_3d (:obj:`BaseInstance3DBoxes`): Prediction of bboxes,
                contains a tensor with shape (num_instances, 9).
        """
        breakpoint()
        batch_input_metas = [item.metainfo for item in batch_data_samples]
        batch_input_metas = self.add_lidar2img(batch_input_metas)
        img_feats = self.extract_feat(batch_inputs_dict, batch_input_metas)
        outs = self.pts_bbox_head(img_feats, batch_input_metas)
 
        results_list_3d = self.pts_bbox_head.predict_by_feat(
            outs, batch_input_metas, **kwargs)
 
        # change the bboxes' format
        detsamples = self.add_pred_to_datasample(batch_data_samples,
                                                 results_list_3d)
        return detsamples
 
    # may need speed-up
    def add_lidar2img(self, batch_input_metas: List[Dict]) -> List[Dict]:
        """add 'lidar2img' transformation matrix into batch_input_metas.
        Args:
            batch_input_metas (list[dict]): Meta information of multiple inputs
                in a batch.
        Returns:
            batch_input_metas (list[dict]): Meta info with lidar2img added
        """
        for meta in batch_input_metas:
            l2i = list()
            for i in range(len(meta['cam2img'])):
                c2i = torch.tensor(meta['cam2img'][i]).double()
                l2c = torch.tensor(meta['lidar2cam'][i]).double()
                l2i.append(get_lidar2img(c2i, l2c).float().numpy())
            meta['lidar2img'] = l2i
        return batch_input_metas

Let's get a little help from modern tools:

We'll focus our exploration on the following questions:

  1. How is the DETR3D code actually implemented?
  2. How is the DETR3D model actually trained?
  3. How is the DETR3D loss designed?
  4. How does DETR3D produce its final inference results?
  5. Does DETR3D use NMS at all?

With these questions in mind, let's start debugging:

At a high level, the predict path works like this: the meta information is read from batch_data_samples, image features are extracted with extract_feat, those features are processed by the 3D bbox head (pts_bbox_head), and finally the head's outputs are decoded by predict_by_feat to obtain the 3D boxes, i.e. the predictions.

Now let's get more specific. Step the debugger to where the image features are extracted:
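
For concreteness, here is an illustrative sketch (not captured from a real run) of what the batched multi-view input looks like at this breakpoint, assuming one nuScenes sample with 6 cameras and images padded to 928 x 1600:

import torch

# Illustrative only: the batched multi-view input at this breakpoint,
# assuming one nuScenes sample with 6 cameras padded to 928 x 1600.
img = torch.zeros(1, 6, 3, 928, 1600)    # (B, N, C, H, W)
print(img.shape[-2:])                     # stored as input_shape in every img_meta
img = img.squeeze(0)                      # with B == 1 the batch dim is collapsed
print(img.shape)                          # (6, 3, 928, 1600): the 6 views go to the backbone as a batch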

The main stages here are a backbone plus a neck. Let's set GridMask aside for now (we'll come back to its role later) and first look at what the backbone and the neck actually do.

Our breakpoint lands inside ResNet, which tells us DETR3D uses ResNet as its image backbone.

Continuing, we can see that the neck is an FPN:
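
The backbone/neck configuration looks roughly like the snippet below. This is a sketch reconstructed from memory of the public DETR3D config (ResNet-101 with deformable convolutions plus a 4-level FPN); verify the exact values against the config file in the repo:

# Sketch of the image backbone / neck config; values are illustrative.
img_backbone = dict(
    type='ResNet',
    depth=101,
    num_stages=4,
    out_indices=(1, 2, 3),          # take the C3, C4, C5 feature maps
    frozen_stages=1,
    norm_eval=True,
    style='caffe',
    dcn=dict(type='DCNv2', deform_groups=1, fallback_on_stride=False),
    stage_with_dcn=(False, False, True, True))
img_neck = dict(
    type='FPN',
    in_channels=[512, 1024, 2048],  # channel widths of C3 / C4 / C5
    out_channels=256,
    start_level=0,
    add_extra_convs='on_output',    # one extra downsampled level, 4 outputs total
    num_outs=4,
    relu_before_extra_convs=True)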

In short, a ResNet + FPN pipeline turns the multi-view images into multi-scale image features.
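
One detail worth noting in extract_img_feat: the backbone and neck see the 6 camera views as a plain batch of 6 images, and the camera dimension is only restored afterwards. A tiny illustrative sketch (the spatial size 116 x 200 is just an example):

import torch

# How extract_img_feat restores the camera dimension after the backbone/neck.
B, N, C = 1, 6, 256
fpn_level = torch.zeros(B * N, C, 116, 200)      # one FPN level seen as a batch of B*N images
BN, C, H, W = fpn_level.size()
reshaped = fpn_level.view(B, BN // B, C, H, W)   # back to (B, N, C, H, W)
print(reshaped.shape)                            # torch.Size([1, 6, 256, 116, 200])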

With the image features and the rest of the meta information, we move on to the bbox head; the next breakpoint drops us into detr3d_head::forward:

    def forward(self, mlvl_feats: List[Tensor], img_metas: List[Dict],
                **kwargs) -> Dict[str, Tensor]:
        """Forward function.

        Args:
            mlvl_feats (List[Tensor]): Features from the upstream
                network, each is a 5D-tensor with shape
                (B, N, C, H, W).
        Returns:
            all_cls_scores (Tensor): Outputs from the classification head,
                shape [nb_dec, bs, num_query, cls_out_channels]. Note
                cls_out_channels should includes background.
            all_bbox_preds (Tensor): Sigmoid outputs from the regression
                head with normalized coordinate format
                (cx, cy, l, w, cz, h, sin(φ), cos(φ), vx, vy).
                Shape [nb_dec, bs, num_query, 10].
        """
        query_embeds = self.query_embedding.weight
        hs, init_reference, inter_references = self.transformer(
            mlvl_feats,
            query_embeds,
            reg_branches=self.reg_branches if self.with_box_refine else None,
            img_metas=img_metas,
            **kwargs)
        hs = hs.permute(0, 2, 1, 3)
        outputs_classes = []
        outputs_coords = []

        for lvl in range(hs.shape[0]):
            if lvl == 0:
                reference = init_reference
            else:
                reference = inter_references[lvl - 1]
            reference = inverse_sigmoid(reference)
            outputs_class = self.cls_branches[lvl](hs[lvl])
            tmp = self.reg_branches[lvl](hs[lvl])  # shape: ([B, num_q, 10])
            # TODO: check the shape of reference
            assert reference.shape[-1] == 3
            tmp[..., 0:2] += reference[..., 0:2]
            tmp[..., 0:2] = tmp[..., 0:2].sigmoid()
            tmp[..., 4:5] += reference[..., 2:3]
            tmp[..., 4:5] = tmp[..., 4:5].sigmoid()

            tmp[..., 0:1] = \
                tmp[..., 0:1] * (self.pc_range[3] - self.pc_range[0]) \
                + self.pc_range[0]
            tmp[..., 1:2] = \
                tmp[..., 1:2] * (self.pc_range[4] - self.pc_range[1]) \
                + self.pc_range[1]
            tmp[..., 4:5] = \
                tmp[..., 4:5] * (self.pc_range[5] - self.pc_range[2]) \
                + self.pc_range[2]

            # TODO: check if using sigmoid
            outputs_coord = tmp
            outputs_classes.append(outputs_class)
            outputs_coords.append(outputs_coord)

        outputs_classes = torch.stack(outputs_classes)
        outputs_coords = torch.stack(outputs_coords)
        outs = {
            'all_cls_scores': outputs_classes,
            'all_bbox_preds': outputs_coords,
            'enc_cls_scores': None,
            'enc_bbox_preds': None,
        }
        return outs

Keep in mind that this forward is invoked as outs = self.pts_bbox_head(img_feats, batch_input_metas).

The function returns classification scores and predicted bboxes.
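
To make the output shapes concrete, here is an illustrative mock of outs, assuming the common DETR3D defaults of 6 decoder layers, batch size 1, 900 object queries, and 10 nuScenes classes (your config may differ):

import torch

# Illustrative mock of the head output under the assumed defaults.
outs = {
    'all_cls_scores': torch.zeros(6, 1, 900, 10),  # per-decoder-layer classification logits
    'all_bbox_preds': torch.zeros(6, 1, 900, 10),  # per-layer (cx, cy, l, w, cz, h, sin, cos, vx, vy)
    'enc_cls_scores': None,
    'enc_bbox_preds': None,
}
print({k: v.shape for k, v in outs.items() if v is not None})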

Let's step through this part:

First, it pulls the query embedding weights (query_embeds) from self.query_embedding and then calls the transformer:

It passes in the multi-level features mlvl_feats, the query_embeds weights, and a few other arguments.
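
For context, the query embedding is, as far as I can tell from the head's _init_layers (treat this as a sketch), a learnable nn.Embedding whose width is twice embed_dims, so that each row can later be split into a positional half and a content half:

import torch
from torch import nn

# Sketch (assumed): one learnable embedding row per object query, twice embed_dims wide.
num_query, embed_dims = 900, 256
query_embedding = nn.Embedding(num_query, embed_dims * 2)
query_embeds = query_embedding.weight                    # (900, 512)
query_pos, query = torch.split(query_embeds, embed_dims, dim=1)
print(query_pos.shape, query.shape)                      # (900, 256) (900, 256)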

Based on that, let's put a breakpoint directly inside the transformer:

    def forward(self, mlvl_feats, query_embed, reg_branches=None, **kwargs):
        """Forward function for `Detr3DTransformer`.
        Args:
            mlvl_feats (list(Tensor)): Input queries from
                different level. Each element has shape
                (B, N, C, H_lvl, W_lvl).
            query_embed (Tensor): The query positional and semantic embedding
                for decoder, with shape [num_query, c+c].
            mlvl_pos_embeds (list(Tensor)): The positional encoding
                of feats from different level, has the shape
                [bs, N, embed_dims, h, w]. It is unused here.
            reg_branches (obj:`nn.ModuleList`): Regression heads for
                feature maps from each decoder layer. Only would
                be passed when `with_box_refine` is True. Default to None.
        Returns:
            tuple[Tensor]: results of decoder containing the following tensor.
                - inter_states: Outputs from decoder. If
                    return_intermediate_dec is True output has shape
                      (num_dec_layers, bs, num_query, embed_dims), else has
                      shape (1, bs, num_query, embed_dims).
                - init_reference_out: The initial value of reference
                    points, has shape (bs, num_queries, 4).
                - inter_references_out: The internal value of reference
                    points in decoder, has shape
                    (num_dec_layers, bs, num_query, embed_dims)
        """
        assert query_embed is not None
        bs = mlvl_feats[0].size(0)
        query_pos, query = torch.split(query_embed, self.embed_dims, dim=1)
        query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)  # [bs,num_q,c]
        query = query.unsqueeze(0).expand(bs, -1, -1)  # [bs,num_q,c]
        reference_points = self.reference_points(query_pos)
        reference_points = reference_points.sigmoid()
        init_reference_out = reference_points

        # decoder
        query = query.permute(1, 0, 2)
        query_pos = query_pos.permute(1, 0, 2)
        inter_states, inter_references = self.decoder(
            query=query,
            key=None,
            value=mlvl_feats,
            query_pos=query_pos,
            reference_points=reference_points,
            reg_branches=reg_branches,
            **kwargs)

        inter_references_out = inter_references
        return inter_states, init_reference_out, inter_references_out

First, query_embed is split into the query positional part and the query content part (query_pos and query); then reference points are regressed from query_pos and passed through a sigmoid.
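
A minimal sketch of where those reference points come from, assuming (as in my reading of Detr3DTransformer.init_layers) that the reference-point head is a single Linear layer mapping each query's positional embedding to a normalized 3D point:

import torch
from torch import nn

# Sketch: reference points as a Linear head over query_pos (single Linear assumed).
bs, num_query, embed_dims = 1, 900, 256
reference_point_head = nn.Linear(embed_dims, 3)
query_pos = torch.randn(bs, num_query, embed_dims)
reference_points = reference_point_head(query_pos).sigmoid()  # (1, 900, 3), values in [0, 1]
# detr3d_head.forward later rescales these normalized points into the metric pc_range.
print(reference_points.shape)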

mlvl_feats, the reference_points, and the reg_branches are all passed to the decoder, and the results come back as inter_states and inter_references.

The decoder is the core. Stepping into it, we land in the forward of Detr3DTransformerDecoder.

This method takes the query tensor, the reference_points, and the reg_branches.

A for loop applies each layer in turn to produce the output, and if reg_branches is provided, the reference points are refined after each layer, as sketched below.
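
A simplified sketch of that loop, paraphrased from how I understand Detr3DTransformerDecoder.forward (permutes and other details are omitted, so names and shapes may differ from the real implementation):

import torch

def inverse_sigmoid(x, eps=1e-5):
    x = x.clamp(min=eps, max=1 - eps)
    return torch.log(x / (1 - x))

def decoder_loop(layers, query, reference_points, reg_branches=None, **kwargs):
    """Sketch of the per-layer refinement in the DETR3D decoder (simplified)."""
    intermediate, intermediate_refs = [], []
    output = query
    for lid, layer in enumerate(layers):
        # each layer cross-attends to the image features at the current reference points
        output = layer(output, reference_points=reference_points, **kwargs)
        if reg_branches is not None:
            # iterative box refinement: nudge the reference points with this layer's
            # regression output, working in inverse-sigmoid (logit) space
            tmp = reg_branches[lid](output)
            new_refs = torch.zeros_like(reference_points)
            new_refs[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points[..., :2])
            new_refs[..., 2:3] = tmp[..., 4:5] + inverse_sigmoid(reference_points[..., 2:3])
            reference_points = new_refs.sigmoid().detach()
        intermediate.append(output)
        intermediate_refs.append(reference_points)
    return torch.stack(intermediate), torch.stack(intermediate_refs)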

There are 6 such layers.

It's worth checking what each layer actually is. Setting a breakpoint inside shows it is just mmcv's standard transformer decoder layer, so there is nothing DETR3D-specific to dig into there.

In-depth paper reading: "DETR3D: 3D Object Detection from Multi-view Images via 3D-to-2D Queries" - CSDN Blog

At this point we know: there are 4 FPN levels of multi-scale image features and 6 camera views, and ResNet + FPN is used to obtain perspective-view (PV) features from the images. We also know that the reference points are 3D location reference points regressed from the object query embeddings. But we haven't seen the whole picture yet.

Back in the detr3d_head code, tracing backwards we can see cls_branches and outputs_coord, which carry the box class information and the detection attributes (position, size, orientation, and so on) respectively, i.e. step 6 in the paper's figure.

Printing the cls and reg branches shows that the cls branch is a stack of Linear-LayerNorm-ReLU blocks and the reg branch is a stack of Linear-ReLU blocks.
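
A sketch of how those two branches are typically built in the head's _init_layers (reconstructed from memory; embed_dims, num_reg_fcs, the class count, and code_size are the usual defaults and may differ in your config):

from torch import nn

# Sketch of the per-layer prediction branches under the assumed defaults.
embed_dims, num_reg_fcs, num_classes, code_size = 256, 2, 10, 10

cls_branch = []
for _ in range(num_reg_fcs):
    cls_branch += [nn.Linear(embed_dims, embed_dims),
                   nn.LayerNorm(embed_dims),
                   nn.ReLU(inplace=True)]
cls_branch.append(nn.Linear(embed_dims, num_classes))  # Linear-LayerNorm-ReLU stack
fc_cls = nn.Sequential(*cls_branch)

reg_branch = []
for _ in range(num_reg_fcs):
    reg_branch += [nn.Linear(embed_dims, embed_dims),
                   nn.ReLU()]
reg_branch.append(nn.Linear(embed_dims, code_size))    # Linear-ReLU stack, 10 box parameters
fc_reg = nn.Sequential(*reg_branch)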

Suppressing overfitting: the role of LayerNorm from a gradient perspective - CSDN Blog

In other words, combining this with our reading of the paper, the part of the pipeline we still need to confirm in the code boils down to the following steps (the sketch after the list illustrates steps 2 and 3):

  1. Initialize the queries; each query regresses the center coordinates of a predicted box
  2. Use the camera calibration to project those 3D points into each camera view and obtain image-feature sampling locations
  3. Sample the image features at those locations, aggregate them across the camera views, and refine the predicted box positions (the different scales of the feature pyramid are reconciled by interpolation)
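
Steps 2 and 3 live in the transformer's cross-attention (the feature-sampling step of DETR3D). The following is a minimal sketch of the idea only, not the actual implementation: it assumes lidar2img matrices of shape (B, N, 4, 4), a fixed illustrative image size, and it skips the masking and aggregation details:

import torch
import torch.nn.functional as F

def sample_image_features(mlvl_feats, ref_points, lidar2img, pc_range,
                          img_hw=(928, 1600)):
    """Minimal sketch of DETR3D's 3D-to-2D query feature sampling (simplified).

    mlvl_feats: list of (B, N, C, H_l, W_l) features from the FPN levels.
    ref_points: (B, num_query, 3) reference points normalized to [0, 1].
    lidar2img:  (B, N, 4, 4) projection matrices from the camera calibration.
    """
    B, num_query, _ = ref_points.shape
    N = lidar2img.shape[1]

    # 1) de-normalize the reference points into metric lidar coordinates
    pts = ref_points.clone()
    pts[..., 0] = pts[..., 0] * (pc_range[3] - pc_range[0]) + pc_range[0]
    pts[..., 1] = pts[..., 1] * (pc_range[4] - pc_range[1]) + pc_range[1]
    pts[..., 2] = pts[..., 2] * (pc_range[5] - pc_range[2]) + pc_range[2]
    pts = torch.cat([pts, torch.ones_like(pts[..., :1])], dim=-1)    # homogeneous

    # 2) project every reference point into every camera view
    pts = pts.unsqueeze(1).expand(B, N, num_query, 4).unsqueeze(-1)
    cam_pts = torch.matmul(lidar2img.unsqueeze(2), pts).squeeze(-1)  # (B, N, num_query, 4)
    eps = 1e-5
    valid = cam_pts[..., 2:3] > eps                                  # in front of the camera
    uv = cam_pts[..., :2] / cam_pts[..., 2:3].clamp(min=eps)         # pixel coordinates

    # 3) normalize to [-1, 1] and bilinearly sample every FPN level
    uv = uv / uv.new_tensor([img_hw[1], img_hw[0]]) * 2 - 1
    sampled = []
    for feat in mlvl_feats:
        _, _, C, H, W = feat.shape
        grid = uv.reshape(B * N, num_query, 1, 2)
        level = F.grid_sample(feat.reshape(B * N, C, H, W), grid, align_corners=False)
        sampled.append(level.view(B, N, C, num_query))
    # the real code then masks invalid projections and aggregates over cameras / levels
    return sampled, valid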

One more observation: up to this point we have not seen lidar play any role. Multi-modal fusion is a separate topic and won't be discussed here.
