Retinaface训练超参数调优

训练20遍数据集跑出的效果

from __future__ import print_function

import argparse
import math
import os

import optuna
import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data as data

from data import WiderFaceDetection, detection_collate, preproc, cfg_mnet, cfg_re50
from layers.functions.prior_box import PriorBox
from layers.modules import MultiBoxLoss
from models.retinaface import RetinaFace

# 解析命令行参数
parser = argparse.ArgumentParser(description='Retinaface Training')
parser.add_argument('--training_dataset', default='./data/lst/train/label.txt', help='训练数据集目录')
parser.add_argument('--network', default='resnet50', help='Backbone 网络选择: mobile0.25 或 resnet50')
parser.add_argument('--num_workers', default=4, type=int, help='数据加载时的工作线程数')
parser.add_argument('--resume_net', default=None, help='重新训练时的已保存模型路径')
parser.add_argument('--resume_epoch', default=0, type=int, help='重新训练时的迭代轮数')
parser.add_argument('--save_folder', default='./weights/', help='保存检查点模型的目录')

# 解析参数
args = parser.parse_args()

# 如果 save_folder 目录不存在,则创建它
if not os.path.exists(args.save_folder):
    os.mkdir(args.save_folder)

# 根据选择的网络初始化配置
cfg = None
if args.network == "mobile0.25":
    cfg = cfg_mnet
elif args.network == "resnet50":
    cfg = cfg_re50

# 设置 RGB 平均值、类别数、图像维度等
rgb_mean = (104, 117, 123)  # BGR 顺序
num_classes = 2
img_dim = cfg['image_size']
num_gpu = cfg['ngpu']
batch_size = cfg['batch_size']
max_epoch = cfg['epoch']
gpu_train = cfg['gpu_train']

num_workers = args.num_workers
training_dataset = args.training_dataset
save_folder = args.save_folder


# 超参数优化目标函数
def objective(trial):
    # 超参数搜索空间
    initial_lr = trial.suggest_float('lr', 1e-5, 1e-2, log=True)
    momentum = trial.suggest_float('momentum', 0.7, 0.99)
    weight_decay = trial.suggest_float('weight_decay', 1e-5, 1e-2, log=True)
    gamma = trial.suggest_float('gamma', 0.1, 0.5, log=True)

    # 初始化 RetinaFace 模型
    net = RetinaFace(cfg=cfg)

    # 如果指定了 resume_net,加载预训练权重
    if args.resume_net is not None:
        state_dict = torch.load(args.resume_net)
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict.
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值