# 训练20遍数据集跑出的效果 (results obtained after training for 20 passes over the dataset)
from __future__ import print_function
import argparse
import math
import os
import optuna
import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data as data
from data import WiderFaceDetection, detection_collate, preproc, cfg_mnet, cfg_re50
from layers.functions.prior_box import PriorBox
from layers.modules import MultiBoxLoss
from models.retinaface import RetinaFace
# Parse command-line arguments.
parser = argparse.ArgumentParser(description='Retinaface Training')
parser.add_argument('--training_dataset', default='./data/lst/train/label.txt', help='训练数据集目录')
# Only these two backbones have a matching config (cfg_mnet / cfg_re50);
# restricting the choices makes argparse reject anything else up front with a
# usage message instead of failing later on a missing configuration.
parser.add_argument('--network', default='resnet50', choices=['mobile0.25', 'resnet50'],
                    help='Backbone 网络选择: mobile0.25 或 resnet50')
parser.add_argument('--num_workers', default=4, type=int, help='数据加载时的工作线程数')
parser.add_argument('--resume_net', default=None, help='重新训练时的已保存模型路径')
parser.add_argument('--resume_epoch', default=0, type=int, help='重新训练时的迭代轮数')
parser.add_argument('--save_folder', default='./weights/', help='保存检查点模型的目录')
# Parse the process arguments (sys.argv) into ``args``.
args = parser.parse_args()
# Create the checkpoint directory, including any missing parent directories.
# exist_ok=True makes this a no-op when the directory already exists and avoids
# the race between a separate existence check and the mkdir call. os.makedirs
# raises FileExistsError if the path exists but is not a directory.
os.makedirs(args.save_folder, exist_ok=True)
# Select the model configuration for the chosen backbone. Fail fast with a clear
# message for an unknown network instead of leaving ``cfg`` as None, which would
# only surface later as a TypeError on cfg['image_size'].
if args.network == "mobile0.25":
    cfg = cfg_mnet
elif args.network == "resnet50":
    cfg = cfg_re50
else:
    raise ValueError(
        "Unsupported --network {!r}; expected 'mobile0.25' or 'resnet50'".format(args.network)
    )

# Training settings: per-channel mean, class count, and values read from cfg.
rgb_mean = (104, 117, 123)  # BGR order
# NOTE(review): presumably background + face — confirm against the loss setup.
num_classes = 2
img_dim = cfg['image_size']
num_gpu = cfg['ngpu']
batch_size = cfg['batch_size']
max_epoch = cfg['epoch']
gpu_train = cfg['gpu_train']
# Mirror CLI arguments as module-level names used by the rest of the script.
num_workers = args.num_workers
training_dataset = args.training_dataset
save_folder = args.save_folder
# 超参数优化目标函数
def objective(trial):
# 超参数搜索空间
initial_lr = trial.suggest_float('lr', 1e-5, 1e-2, log=True)
momentum = trial.suggest_float('momentum', 0.7, 0.99)
weight_decay = trial.suggest_float('weight_decay', 1e-5, 1e-2, log=True)
gamma = trial.suggest_float('gamma', 0.1, 0.5, log=True)
# 初始化 RetinaFace 模型
net = RetinaFace(cfg=cfg)
# 如果指定了 resume_net,加载预训练权重
if args.resume_net is not None:
state_dict = torch.load(args.resume_net)
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.