4.模型训练
这一部分可以参见 附件\train_yolov3.py,官方平台有更通用的训练代码(参见:https://gitee.com/mindspore/models/tree/r1.5/official/cv/yolov3_darknet53)。
4.1 导包
首先是导入包:
import time
from mindspore import context
from mindspore.common import set_seed
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore import Tensor
from mindspore.nn.optim.momentum import Momentum
from mindspore import amp
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.callback import ModelCheckpoint, RunContext
from mindspore.train.callback import CheckpointConfig
# 以下包请见附件代码 train_yolov3.py
from utils.utils_yolov3 import AverageMeter, get_param_groups, keep_loss_fp32, load_yolov3_backbone
from model.yolo import YOLOV3DarkNet53, YoloWithLossCell
from utils.kaiming_init import default_recurisive_init
from dataset.voc2012_dataset import create_voc2012_dataset
from utils.lr_generator import get_lr
from yolov3_train_conf import TrainYOLOv3Conf
set_seed(1)
上面的包可能有些没有用但我忘删了。
附件中引入的代码:
AverageMeter:主要是用于管理训练过程中的损失值,每个迭代一定次数给出这一段训练的损失平均值。
get_param_groups:主要是在配置优化器阶段使用。(keep_loss_fp32:我也不知道它是干什么的)
load_yolov3_backbone:主要是将Darknet53主干网络的预训练参数加载进来。
default_recurisive_init:使用He初始化,因为网络中大量使用Relu。
get_lr:学习率使用余弦退火函数生成,这个函数会根据参数生成每一代训练的学习率。
TrainYOLOv3Conf:训练配置。
4.2 设置context
然后设置Context:
# set context
context.set_context(mode=context.GRAPH_MODE,
device_target="GPU", save_graphs=False)
# set_graph_kernel_context
context.set_context(enable_graph_kernel=True)
context.set_context(graph_kernel_flags="--enable_parallel_fusion "
"--enable_trans_op_optimize "
"--disable_cluster_ops=ReduceMax,Reshape "
"--enable_expand_ops=Conv2D")
# Set mempool block size for improving memory utilization, which will not take effect in GRAPH_MODE
if context.get_context("mode") ==