MindSpore:YOLOv3人体目标检测模型实现(三)

本文详细介绍了如何在MindSpore中实现YOLOv3人体目标检测模型的训练,包括模型训练的各个步骤,如导包、设置Context、初始化配置和网络、加载数据集、生成Momentum优化器、设置模型保存,以及训练过程中的注意事项。重点讲解了TrainYOLOv3Conf配置对象和训练过程中的关键函数,如AverageMeter、get_lr、load_yolov3_backbone等。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

4.模型训练

这一部分可以参见 附件\train_yolov3.py,官方平台有更通用的训练代码(参见:https://gitee.com/mindspore/models/tree/r1.5/official/cv/yolov3_darknet53)。

4.1 导包

首先是导入包:

import time

from mindspore import context
from mindspore.common import set_seed

from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore import Tensor
from mindspore.nn.optim.momentum import Momentum
from mindspore import amp
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.callback import ModelCheckpoint, RunContext
from mindspore.train.callback import CheckpointConfig

# 以下包请见附件代码 train_yolov3.py
from utils.utils_yolov3 import AverageMeter, get_param_groups, keep_loss_fp32, load_yolov3_backbone
from model.yolo import YOLOV3DarkNet53, YoloWithLossCell
from utils.kaiming_init import default_recurisive_init
from dataset.voc2012_dataset import create_voc2012_dataset
from utils.lr_generator import get_lr
from yolov3_train_conf import TrainYOLOv3Conf

set_seed(1)

上面的包可能有些没有用但我忘删了。

附件中引入的代码:

AverageMeter:主要是用于管理训练过程中的损失值,每个迭代一定次数给出这一段训练的损失平均值。

get_param_groups:主要是在配置优化器阶段使用。(keep_loss_fp32:我也不知道它是干什么的)

load_yolov3_backbone:主要是将Darknet53主干网络的预训练参数加载进来。

default_recurisive_init:使用He初始化,因为网络中大量使用Relu。

get_lr:学习率使用余弦退火函数生成,这个函数会根据参数生成每一代训练的学习率。

TrainYOLOv3Conf:训练配置。

4.2 设置context

然后设置Context:

# set context
context.set_context(mode=context.GRAPH_MODE,
                    device_target="GPU", save_graphs=False)
# set_graph_kernel_context
context.set_context(enable_graph_kernel=True)
context.set_context(graph_kernel_flags="--enable_parallel_fusion "
                                       "--enable_trans_op_optimize "
                                       "--disable_cluster_ops=ReduceMax,Reshape "
                                       "--enable_expand_ops=Conv2D")
# Set mempool block size for improving memory utilization, which will not take effect in GRAPH_MODE
if context.get_context("mode") == 
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值