SSD: Single Shot MultiBox Detector 训练KITTI数据集(2)

前言

博主在上篇中花了很大篇幅讲解如何一步步把KITTI原始数据做成了SSD可以训练的格式,接下来就可以使用相关caffe代码实现SSD的训练了。

下载VGG预训练模型

将 SSD 用于自己的检测任务,是需要 Fine-tuning a pretrained network,看过论文的朋友可能都知道,论文中的SSD框架是是由VGG网络为基底(base)的。除此之外,作者也提供了另外两种结构的网络:ZF-SSD和Resnet-SSD,可以在caffe/examples/ssd 文件夹中查看相应的python训练代码。

初次训练,还是用VGG网络吧,下载该预训练模型,将其放到/home/mx/caffe/models/VGGNet文件夹之下。

修改训练代码

这里先说下电脑硬件配置,这决定了一些训练参数的设定。博主使用的主机CPU为Intel i7 6700,GPU为TITAN X,运行内存16GB,搭载Ubuntu16.04系统。

一般的caffe训练都是使用train.prototxt和solver.prototxt文件,有同学可能想问了,为什么SSD项目下找不到这些文件呢?原因是SSD的模型很大,train.prototxt就有1000多行,直接修改参数的工作量太大,而且train.prototxt一旦改动,test.prototxt和solver.prototxt也要跟着改动。因此,作者使用了一个很有效的方法,利用python脚本,自动生成这些文件。初次训练,博主选择了ssd_pascal.py 脚本来训练SSD_300x300,那么将ssd_pascal.py复制一份,重命名为ssd_pascal_kitti.py,然后修改这个ython文件。

ssd_pascal.py脚本也有500多行,博主也不可能全部贴出来,这里就贴出可以修改的部分,作为一个参照。

PS.根据博友反馈,还是提供修改过的的训练脚本 ssd_pascal_kitti.py,以供参考。

自定义路径和常用参数

train_data = "examples/VOC0712/VOC0712_trainval_lmdb" # 训练数据路径,修改前
train_data = "examples/KITTI/KITTI_trainval_lmdb" # 修改后
------
test_data = "examples/VOC0712/VOC0712_test_lmdb" # 测试数据路径
te
### Kitti 数据集目标识别方法 Kitti 数据集中包含了丰富的标注信息,这些信息对于开发和评估计算机视觉算法至关重要[^2]。具体来说,该数据集提供了多种类型的标注: - **物体检测框**:包括3D和2D边界框,用于检测和识别不同类型的物体(如车辆、行人、骑车人等)。 - **道路和车道标注**:有助于实现自动驾驶场景下的路径规划与环境理解。 #### 使用KITTI数据集进行目标检测的方法概述 为了利用KITTI数据集执行高效的目标检测任务,通常会采用深度学习框架配合预训练模型来加速收敛并提高精度。以下是几种常见的方式: 1. **基于卷积神经网络(CNN)** 的架构被广泛应用于此领域内,比如Faster R-CNN, SSD (Single Shot MultiBox Detector), 和YOLO (You Only Look Once)家族成员都取得了不错的效果。 2. 对于特定应用场景下(例如仅限于汽车),可以考虑微调已有的预训练权重文件以适应新的类别分布情况;而对于更复杂的情况,则可能需要重新训练整个网络直至获得满意的结果。 #### 实现示例代码展示 下面给出一段简单的Python脚本作为如何加载并可视化KITTI格式的数据样本的例子: ```python import os from PIL import Image import matplotlib.pyplot as plt def load_kitti_image(image_path): """Load an image from the specified path.""" img = Image.open(image_path) return img def plot_bounding_boxes(img, labels_file): """Plot bounding boxes on top of given image using label information.""" with open(labels_file, 'r') as f: lines = f.readlines() fig, ax = plt.subplots(1) ax.imshow(img) for line in lines: parts = line.strip().split(' ') class_name = parts[0] bbox = list(map(float, parts[4:8])) rect = plt.Rectangle((bbox[0], bbox[1]), bbox[2]-bbox[0], bbox[3]-bbox[1], fill=False, edgecolor='red', linewidth=2) ax.add_patch(rect) ax.text(bbox[0], bbox[1], class_name, style='italic', bbox={'facecolor': 'white', 'alpha': 0.7, 'pad': 1}) plt.show() if __name__ == "__main__": base_dir = '/path/to/kitti/dataset' image_id = '000001' # Example ID image_filename = os.path.join(base_dir, 'training/image_2/', '{}.png'.format(image_id)) label_filename = os.path.join(base_dir, 'training/label_2/', '{}.txt'.format(image_id)) loaded_img = load_kitti_image(image_filename) plot_bounding_boxes(loaded_img, label_filename) ``` 这段代码展示了怎样读取一张图片及其对应的标签文件,并绘制出所有的边界框位置。这只是一个基础版本,在实际项目中还需要加入更多的功能模块来进行完整的对象检测流程设计。
评论 157
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值