Picking up where the previous post left off.
---------
Second round of debugging:
----------
After the first round just now, we can adjust the breakpoint locations and place them directly inside the DETR3D code this time. We set two breakpoints: one at DETR3D's class initializer and one at its predict (inference) method.
from typing import Dict, List, Optional
import torch
from torch import Tensor
from mmdet3d.models.detectors.mvx_two_stage import MVXTwoStageDetector
from mmdet3d.registry import MODELS
from mmdet3d.structures import Det3DDataSample
from mmdet3d.structures.bbox_3d.utils import get_lidar2img
from .grid_mask import GridMask
@MODELS.register_module()
class DETR3D(MVXTwoStageDetector):
"""DETR3D: 3D Object Detection from Multi-view Images via 3D-to-2D Queries
Args:
data_preprocessor (dict or ConfigDict, optional): The pre-process
config of :class:`Det3DDataPreprocessor`. Defaults to None.
use_grid_mask (bool) : Data augmentation. Whether to mask out some
grids during extract_img_feat. Defaults to False.
img_backbone (dict, optional): Backbone of extracting
images feature. Defaults to None.
img_neck (dict, optional): Neck of extracting
image features. Defaults to None.
pts_bbox_head (dict, optional): Bboxes head of
detr3d. Defaults to None.
train_cfg (dict, optional): Train config of model.
Defaults to None.
        test_cfg (dict, optional): Test config of model.
            Defaults to None.
init_cfg (dict, optional): Initialize config of
model. Defaults to None.
"""
def __init__(self,
data_preprocessor=None,
use_grid_mask=False,
img_backbone=None,
img_neck=None,
pts_bbox_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None):
breakpoint()
super(DETR3D, self).__init__(
img_backbone=img_backbone,
img_neck=img_neck,
pts_bbox_head=pts_bbox_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
data_preprocessor=data_preprocessor)
self.grid_mask = GridMask(
True, True, rotate=1, offset=False, ratio=0.5, mode=1, prob=0.7)
self.use_grid_mask = use_grid_mask
def extract_img_feat(self, img: Tensor,
batch_input_metas: List[dict]) -> List[Tensor]:
"""Extract features from images.
Args:
img (tensor): Batched multi-view image tensor with
shape (B, N, C, H, W).
batch_input_metas (list[dict]): Meta information of multiple inputs
in a batch.
Returns:
list[tensor]: multi-level image features.
"""
B = img.size(0)
if img is not None:
input_shape = img.shape[-2:] # bs nchw
# update real input shape of each single img
for img_meta in batch_input_metas:
img_meta.update(input_shape=input_shape)
if img.dim() == 5 and img.size(0) == 1:
img.squeeze_()
elif img.dim() == 5 and img.size(0) > 1:
B, N, C, H, W = img.size()
img = img.view(B * N, C, H, W)
if self.use_grid_mask:
img = self.grid_mask(img) # mask out some grids
img_feats = self.img_backbone(img)
if isinstance(img_feats, dict):
img_feats = list(img_feats.values())
else:
return None
if self.with_img_neck:
img_feats = self.img_neck(img_feats)
img_feats_reshaped = []
for img_feat in img_feats:
BN, C, H, W = img_feat.size()
img_feats_reshaped.append(img_feat.view(B, int(BN / B), C, H, W))
return img_feats_reshaped
def extract_feat(self, batch_inputs_dict: Dict,
batch_input_metas: List[dict]) -> List[Tensor]:
"""Extract features from images.
Refer to self.extract_img_feat()
"""
imgs = batch_inputs_dict.get('imgs', None)
img_feats = self.extract_img_feat(imgs, batch_input_metas)
return img_feats
def _forward(self):
raise NotImplementedError('tensor mode is yet to add')
# original forward_train
def loss(self, batch_inputs_dict: Dict[List, Tensor],
batch_data_samples: List[Det3DDataSample],
**kwargs) -> List[Det3DDataSample]:
"""
Args:
batch_inputs_dict (dict): The model input dict which include
`imgs` keys.
- imgs (torch.Tensor): Tensor of batched multi-view images.
It has shape (B, N, C, H ,W)
batch_data_samples (List[obj:`Det3DDataSample`]): The Data Samples
It usually includes information such as `gt_instance_3d`.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
batch_input_metas = [item.metainfo for item in batch_data_samples]
batch_input_metas = self.add_lidar2img(batch_input_metas)
img_feats = self.extract_feat(batch_inputs_dict, batch_input_metas)
outs = self.pts_bbox_head(img_feats, batch_input_metas, **kwargs)
batch_gt_instances_3d = [
item.gt_instances_3d for item in batch_data_samples
]
loss_inputs = [batch_gt_instances_3d, outs]
losses_pts = self.pts_bbox_head.loss_by_feat(*loss_inputs)
return losses_pts
# original simple_test
def predict(self, batch_inputs_dict: Dict[str, Optional[Tensor]],
batch_data_samples: List[Det3DDataSample],
**kwargs) -> List[Det3DDataSample]:
"""Forward of testing.
Args:
batch_inputs_dict (dict): The model input dict which include
`imgs` keys.
- imgs (torch.Tensor): Tensor of batched multi-view images.
It has shape (B, N, C, H ,W)
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance_3d`.
Returns:
list[:obj:`Det3DDataSample`]: Detection results of the
input sample. Each Det3DDataSample usually contain
'pred_instances_3d'. And the ``pred_instances_3d`` usually
contains following keys.
- scores_3d (Tensor): Classification scores, has a shape
(num_instances, )
- labels_3d (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bbox_3d (:obj:`BaseInstance3DBoxes`): Prediction of bboxes,
contains a tensor with shape (num_instances, 9).
"""
breakpoint()
batch_input_metas = [item.metainfo for item in batch_data_samples]
batch_input_metas = self.add_lidar2img(batch_input_metas)
img_feats = self.extract_feat(batch_inputs_dict, batch_input_metas)
outs = self.pts_bbox_head(img_feats, batch_input_metas)
results_list_3d = self.pts_bbox_head.predict_by_feat(
outs, batch_input_metas, **kwargs)
# change the bboxes' format
detsamples = self.add_pred_to_datasample(batch_data_samples,
results_list_3d)
return detsamples
# may need speed-up
def add_lidar2img(self, batch_input_metas: List[Dict]) -> List[Dict]:
"""add 'lidar2img' transformation matrix into batch_input_metas.
Args:
batch_input_metas (list[dict]): Meta information of multiple inputs
in a batch.
Returns:
batch_input_metas (list[dict]): Meta info with lidar2img added
"""
for meta in batch_input_metas:
l2i = list()
for i in range(len(meta['cam2img'])):
c2i = torch.tensor(meta['cam2img'][i]).double()
l2c = torch.tensor(meta['lidar2cam'][i]).double()
l2i.append(get_lidar2img(c2i, l2c).float().numpy())
meta['lidar2img'] = l2i
return batch_input_metas
Let's bring in a bit of extra tooling to help.
Next we'll focus our exploration on the following questions:
- how the DETR3D code is actually implemented
- how the DETR3D model is trained
- how DETR3D's loss is designed
- how DETR3D arrives at its final inference results
- whether DETR3D uses NMS at all
With these questions in mind, let's start debugging:
From this we can already get a rough picture: the meta information is read from batch_data_samples, image features are extracted via extract_feat, those features are processed by the 3D bbox head, and finally the head outputs are decoded by predict_by_feat into the 3D box results, i.e. the predictions.
Now let's go to a more specific spot and step to where the image features are extracted:
We can see that the main stages are a backbone and a neck. We'll set GridMask aside for now and come back to its role later; first let's look at what the backbone and neck actually do.
Our breakpoint lands in ResNet, which means DETR3D uses ResNet as its image backbone.
Continuing, we can see that the neck is an FPN:
In short, a ResNet+FPN stack turns the input images into image features.
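To make the reshaping in extract_img_feat concrete, here is a minimal shape sketch with toy stand-ins for the backbone and neck; the camera count, input resolution and channel width are illustrative assumptions (roughly the nuScenes setting), not values read from the config:

import torch
import torch.nn as nn

B, N = 1, 6                                   # one sample, six surround-view cameras (assumed)
imgs = torch.randn(B, N, 3, 928, 1600)        # (B, N, C, H, W) multi-view input (assumed size)

backbone = nn.Conv2d(3, 256, 3, stride=32, padding=1)   # toy stand-in for ResNet
neck = nn.Identity()                                     # toy stand-in for FPN

x = imgs.view(B * N, 3, 928, 1600)            # fold cameras into the batch dimension
feat = neck(backbone(x))                      # (B*N, 256, 29, 50)
BN, C, H, W = feat.shape
feat = feat.view(B, BN // B, C, H, W)         # unfold back to (B, N, 256, 29, 50)
print(feat.shape)                             # torch.Size([1, 6, 256, 29, 50])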
With the image features and the other meta information in hand, we move into the bbox head; the breakpoint takes us into detr3d_head::forward:
def forward(self, mlvl_feats: List[Tensor], img_metas: List[Dict],
**kwargs) -> Dict[str, Tensor]:
"""Forward function.
Args:
mlvl_feats (List[Tensor]): Features from the upstream
network, each is a 5D-tensor with shape
(B, N, C, H, W).
Returns:
all_cls_scores (Tensor): Outputs from the classification head,
shape [nb_dec, bs, num_query, cls_out_channels]. Note
cls_out_channels should includes background.
all_bbox_preds (Tensor): Sigmoid outputs from the regression
head with normalized coordinate format
(cx, cy, l, w, cz, h, sin(φ), cos(φ), vx, vy).
Shape [nb_dec, bs, num_query, 10].
"""
query_embeds = self.query_embedding.weight
hs, init_reference, inter_references = self.transformer(
mlvl_feats,
query_embeds,
reg_branches=self.reg_branches if self.with_box_refine else None,
img_metas=img_metas,
**kwargs)
hs = hs.permute(0, 2, 1, 3)
outputs_classes = []
outputs_coords = []
for lvl in range(hs.shape[0]):
if lvl == 0:
reference = init_reference
else:
reference = inter_references[lvl - 1]
reference = inverse_sigmoid(reference)
outputs_class = self.cls_branches[lvl](hs[lvl])
tmp = self.reg_branches[lvl](hs[lvl]) # shape: ([B, num_q, 10])
# TODO: check the shape of reference
assert reference.shape[-1] == 3
tmp[..., 0:2] += reference[..., 0:2]
tmp[..., 0:2] = tmp[..., 0:2].sigmoid()
tmp[..., 4:5] += reference[..., 2:3]
tmp[..., 4:5] = tmp[..., 4:5].sigmoid()
tmp[..., 0:1] = \
tmp[..., 0:1] * (self.pc_range[3] - self.pc_range[0]) \
+ self.pc_range[0]
tmp[..., 1:2] = \
tmp[..., 1:2] * (self.pc_range[4] - self.pc_range[1]) \
+ self.pc_range[1]
tmp[..., 4:5] = \
tmp[..., 4:5] * (self.pc_range[5] - self.pc_range[2]) \
+ self.pc_range[2]
# TODO: check if using sigmoid
outputs_coord = tmp
outputs_classes.append(outputs_class)
outputs_coords.append(outputs_coord)
outputs_classes = torch.stack(outputs_classes)
outputs_coords = torch.stack(outputs_coords)
outs = {
'all_cls_scores': outputs_classes,
'all_bbox_preds': outputs_coords,
'enc_cls_scores': None,
'enc_bbox_preds': None,
}
return outs
Don't forget that this forward was called as outs = self.pts_bbox_head(img_feats, batch_input_metas).
This function returns the classification scores and the predicted bboxes.
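Before digging further, note how the head's forward above maps the sigmoid-normalized center back to metric coordinates with pc_range. A minimal sketch of that denormalization, using the common nuScenes point-cloud range as an assumed value:

import torch

pc_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]   # assumed (common nuScenes setting)
tmp = torch.rand(1, 900, 10)        # (bs, num_query, 10); cx/cy/cz already sigmoided to [0, 1]
cx = tmp[..., 0:1] * (pc_range[3] - pc_range[0]) + pc_range[0]   # x in [-51.2, 51.2] m
cy = tmp[..., 1:2] * (pc_range[4] - pc_range[1]) + pc_range[1]   # y in [-51.2, 51.2] m
cz = tmp[..., 4:5] * (pc_range[5] - pc_range[2]) + pc_range[2]   # z in [-5.0, 3.0] m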
Let's start debugging this part:
First we see that it pulls a learnable query embedding weight, query_embeds, from query_embedding, and then runs the transformer,
passing in the multi-level features mlvl_feats, the query_embeds weight, and a few other arguments.
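As a rough mental model (a sketch with assumed numbers, not the head's exact construction): query_embedding is a learnable table whose rows concatenate a positional half and a semantic half, and its weight is what gets handed to the transformer:

import torch.nn as nn

num_query, embed_dims = 900, 256                      # assumed values
query_embedding = nn.Embedding(num_query, embed_dims * 2)
query_embeds = query_embedding.weight                 # (900, 512): [query_pos | query] per row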
Given this, we put a breakpoint directly inside the transformer:
def forward(self, mlvl_feats, query_embed, reg_branches=None, **kwargs):
"""Forward function for `Detr3DTransformer`.
Args:
mlvl_feats (list(Tensor)): Input queries from
different level. Each element has shape
(B, N, C, H_lvl, W_lvl).
query_embed (Tensor): The query positional and semantic embedding
for decoder, with shape [num_query, c+c].
mlvl_pos_embeds (list(Tensor)): The positional encoding
of feats from different level, has the shape
[bs, N, embed_dims, h, w]. It is unused here.
reg_branches (obj:`nn.ModuleList`): Regression heads for
feature maps from each decoder layer. Only would
be passed when `with_box_refine` is True. Default to None.
Returns:
tuple[Tensor]: results of decoder containing the following tensor.
- inter_states: Outputs from decoder. If
return_intermediate_dec is True output has shape
(num_dec_layers, bs, num_query, embed_dims), else has
shape (1, bs, num_query, embed_dims).
- init_reference_out: The initial value of reference
points, has shape (bs, num_queries, 4).
- inter_references_out: The internal value of reference
points in decoder, has shape
(num_dec_layers, bs, num_query, embed_dims)
"""
assert query_embed is not None
bs = mlvl_feats[0].size(0)
query_pos, query = torch.split(query_embed, self.embed_dims, dim=1)
query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1) # [bs,num_q,c]
query = query.unsqueeze(0).expand(bs, -1, -1) # [bs,num_q,c]
reference_points = self.reference_points(query_pos)
reference_points = reference_points.sigmoid()
init_reference_out = reference_points
# decoder
query = query.permute(1, 0, 2)
query_pos = query_pos.permute(1, 0, 2)
inter_states, inter_references = self.decoder(
query=query,
key=None,
value=mlvl_feats,
query_pos=query_pos,
reference_points=reference_points,
reg_branches=reg_branches,
**kwargs)
inter_references_out = inter_references
return inter_states, init_reference_out, inter_references_out
First, query_embed is split into the query positional embedding and the query itself (query_pos and query); the reference points are then regressed from query_pos and squashed with a sigmoid.
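A minimal sketch of that step, assuming self.reference_points is a small linear layer mapping each query's positional embedding to a normalized 3D point in [0, 1]^3:

import torch
import torch.nn as nn

embed_dims, num_query, bs = 256, 900, 1               # assumed values
query_pos = torch.randn(bs, num_query, embed_dims)
reference_points_layer = nn.Linear(embed_dims, 3)     # regresses (x, y, z)
reference_points = reference_points_layer(query_pos).sigmoid()   # (bs, 900, 3) in [0, 1]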
mlvl_feats, reference_points and reg_branches are all passed to the decoder, and the results come back as inter_states and inter_references.
The decoder is the core. Stepping into it, we find that its forward belongs to Detr3DTransformerDecoder.
This method takes a query tensor, the reference points and the reg branches.
A for loop applies each layer in turn to produce the output, and if reg_branches is provided, the reference points are refined after each layer.
There are 6 layers:
It's worth checking what a layer actually is. Stepping in with a breakpoint, it turns out to be mmcv's standard transformer decoder layer, so nothing DETR3D-specific there.
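Putting the loop just described into a hedged sketch (names, shapes and the refinement slots are assumptions modeled on what the decoder does, not the verbatim repo code): each layer updates the queries, and when reg_branches is given, the (cx, cy, cz) of the reference points are nudged in inverse-sigmoid space and detached:

import torch

def decoder_loop_sketch(query, reference_points, layers, reg_branches=None, **kwargs):
    """query: (bs, num_q, c), kept batch-first here for simplicity (the real code permutes
    layouts); reference_points: (bs, num_q, 3) in [0, 1]; layers: the 6 decoder layers."""
    intermediate, intermediate_refs = [], []
    for lid, layer in enumerate(layers):
        query = layer(query, reference_points=reference_points, **kwargs)
        if reg_branches is not None:
            tmp = reg_branches[lid](query)            # (bs, num_q, 10) box deltas
            new_ref = torch.zeros_like(reference_points)
            # nudge x/y (slots 0:2) and z (slot 4) in inverse-sigmoid space, then re-squash
            new_ref[..., :2] = tmp[..., :2] + torch.logit(reference_points[..., :2], eps=1e-5)
            new_ref[..., 2:3] = tmp[..., 4:5] + torch.logit(reference_points[..., 2:3], eps=1e-5)
            reference_points = new_ref.sigmoid().detach()
        intermediate.append(query)
        intermediate_refs.append(reference_points)
    # one entry per decoder layer, matching the 6 layers observed above
    return torch.stack(intermediate), torch.stack(intermediate_refs)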
Paper deep-dive: "DETR3D: 3D Object Detection from Multi-view Images via 3D-to-2D Queries" (CSDN blog)
So far we know there are 4 FPN levels of multi-scale image features and 6 camera views, with ResNet+FPN extracting perspective-view (PV) features from the images. We also know that the reference points are 3D position anchors regressed from the object query embeddings. But we haven't seen the whole picture yet.
Going back to the detr3d head code and tracing backwards, we find cls_branch and output_coord, which carry the class of each box and its detection attributes (position, size, heading, etc.) respectively, i.e. step 6 in the figure.
Printing the cls and reg branches shows that cls is a stack of Linear-LayerNorm-ReLU blocks and reg is a stack of Linear-ReLU blocks.
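A hedged reconstruction of what those printed modules amount to (layer counts and dimensions are assumptions; the final Linear widths stand for the number of classes and the 10 box parameters respectively):

import torch.nn as nn

embed_dims, num_layers = 256, 2          # assumed
cls_out_channels, code_size = 10, 10     # assumed: number of classes and box parameters

cls_layers, reg_layers = [], []
for _ in range(num_layers):
    cls_layers += [nn.Linear(embed_dims, embed_dims), nn.LayerNorm(embed_dims), nn.ReLU(inplace=True)]
    reg_layers += [nn.Linear(embed_dims, embed_dims), nn.ReLU()]
cls_branch = nn.Sequential(*cls_layers, nn.Linear(embed_dims, cls_out_channels))
reg_branch = nn.Sequential(*reg_layers, nn.Linear(embed_dims, code_size))
print(cls_branch)   # Linear-LayerNorm-ReLU stack, as seen when printing the real branch
print(reg_branch)   # Linear-ReLU stack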
In other words, combining this with our reading of the paper, the steps in between, i.e. the parts of the implementation we still need to confirm, are the following (a sketch of the sampling step follows the list):
- initialize the queries, which encode the center coordinates of the predicted boxes
- use the camera calibration to obtain, for each camera view, the sampling points on the image features
- sample the image features, aggregate the features from the different camera views, and refine the predicted box positions (the different pyramid scales of the image features are kept consistent via interpolation)
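Here is a minimal sketch of that sampling step, assuming per-camera lidar2img matrices and a bilinear grid_sample on every FPN level; the shapes, masking and normalization convention are illustrative assumptions rather than the repo's exact feature_sampling code:

import torch
import torch.nn.functional as F

def sample_feats_sketch(mlvl_feats, reference_points, lidar2img, img_hw):
    """mlvl_feats: list of (B, N, C, H_l, W_l); reference_points: (B, num_q, 3) in metric
    coords; lidar2img: (B, N, 4, 4); img_hw: (H_img, W_img) of the padded input images."""
    B, num_q, _ = reference_points.shape
    N = lidar2img.shape[1]
    eps = 1e-5
    # project the 3D reference points into every camera (homogeneous coordinates)
    pts = torch.cat([reference_points, reference_points.new_ones(B, num_q, 1)], dim=-1)
    cam_pts = torch.einsum('bnij,bqj->bnqi', lidar2img, pts)          # (B, N, num_q, 4)
    mask = cam_pts[..., 2:3] > eps                                    # keep points in front of the camera
    uv = cam_pts[..., :2] / cam_pts[..., 2:3].clamp(min=eps)          # pixel coordinates
    uv = uv / uv.new_tensor([img_hw[1], img_hw[0]]) * 2 - 1           # normalize to [-1, 1] for grid_sample
    sampled = []
    for feat in mlvl_feats:                                           # bilinearly sample every FPN level
        Bf, Nf, C, H, W = feat.shape
        grid = uv.view(B * N, 1, num_q, 2)
        level = F.grid_sample(feat.view(Bf * Nf, C, H, W), grid, align_corners=False)
        sampled.append(level.view(B, N, C, num_q))                    # per-camera features at each query
    return sampled, mask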
One more observation: up to this point we still haven't seen lidar play any role. Multi-modal fusion is a separate topic, which we won't discuss here.