# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
import argparse
import datetime
import numpy as np
import time
import torch
import torch.backends.cudnn as cudnn
import json
import torch.nn as nn
# added at the top of main.py for plotting
import matplotlib
matplotlib.use('Agg')  # headless backend (must be set before importing pyplot)
import matplotlib.pyplot as plt
from pathlib import Path
from timm.data import Mixup
from timm.models import create_model
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.scheduler import create_scheduler
from timm.optim import create_optimizer
from timm.utils import NativeScaler, get_state_dict, ModelEma
from datasets import build_dataset
from engine import train_one_epoch, evaluate
from losses import DistillationLoss
from samplers import RASampler
from augment import new_data_aug_generator
from torch.optim.lr_scheduler import CosineAnnealingLR
from contextlib import suppress
import models_mamba
import utils
# log about
import mlflow
def get_args_parser():
parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False)
    parser.add_argument('--batch-size', default=128, type=int,
                        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
    parser.add_argument('--epochs', default=50, type=int, help='Total training epochs')  # lengthened training schedule
parser.add_argument('--bce-loss', action='store_true')
parser.add_argument('--unscale-lr', action='store_true')
# Model parameters
parser.add_argument('--model', default='deit_base_patch16_224', type=str, metavar='MODEL',
help='Name of model to train')
parser.add_argument('--input-size', default=224, type=int, help='images input size') #224
parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
help='Dropout rate (default: 0.)')
    parser.add_argument('--drop-path', type=float, default=0.2, metavar='PCT',
                        help='Drop path rate (default: 0.2)')  # raised to 0.2 because small datasets overfit easily
parser.add_argument('--model-ema', action='store_true')
parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
parser.set_defaults(model_ema=True)
    parser.add_argument('--model-ema-decay', type=float, default=0.99999, help='')  # EMA model smoothing (0.99996 also tried)
parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='')
# Optimizer parameters
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw")')
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
help='Optimizer Epsilon (default: 1e-8)')
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
help='Optimizer Betas (default: None, use opt default)')
#parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
# help='Clip gradient norm (default: None, no clipping)')
#parser.add_argument('--clip-grad', type=float, default=1.0, metavar='NORM',
# help='Max gradient norm')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight-decay', type=float, default=0.15,
                        help='weight decay (default: 0.15)')  # increased from 0.05 to curb overfitting
# Learning rate schedule parameters
    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                        help='LR scheduler (default: "cosine")')
    parser.add_argument('--lr', type=float, default=0.0001, metavar='LR',
                        help='learning rate (default: 0.0001)')  # lowered from 0.001
parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
help='learning rate noise on/off epoch percentages')
parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
help='learning rate noise limit percent (default: 0.67)')
parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
help='learning rate noise std-dev (default: 1.0)')
parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
help='warmup learning rate (default: 1e-6)')
parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
help='epoch interval to decay LR')
    parser.add_argument('--warmup-epochs', type=int, default=20, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')  # warmup lengthened (5, 10, 15 also tried)
parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
    parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                        help='patience epochs for Plateau LR scheduler (default: 10)')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
help='LR decay rate (default: 0.1)')
# Augmentation parameters
parser.add_argument('--color-jitter', type=float, default=0.3, metavar='PCT',
help='Color jitter factor (default: 0.3)')
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original" (default: rand-m9-mstd0.5-inc1)')
    parser.add_argument('--smoothing', type=float, default=0.2, help='Label smoothing (default: 0.2)')  # stronger smoothing than the usual 0.1
parser.add_argument('--train-interpolation', type=str, default='bicubic',
help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
parser.add_argument('--repeated-aug', action='store_true')
parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
parser.set_defaults(repeated_aug=True)
parser.add_argument('--train-mode', action='store_true')
parser.add_argument('--no-train-mode', action='store_false', dest='train_mode')
parser.set_defaults(train_mode=True)
parser.add_argument('--ThreeAugment', action='store_true') #3augment
parser.add_argument('--src', action='store_true') #simple random crop
# * Random Erase params
    parser.add_argument('--reprob', type=float, default=0.0, metavar='PCT',
                        help='Random erase prob (default: 0.0; best left off for small datasets)')
parser.add_argument('--remode', type=str, default='pixel',
help='Random erase mode (default: "pixel")')
parser.add_argument('--recount', type=int, default=1,
help='Random erase count (default: 1)')
parser.add_argument('--resplit', action='store_true', default=False,
help='Do not random erase first (clean) augmentation split')
# * Mixup params
    parser.add_argument('--mixup', type=float, default=0.3,
                        help='mixup alpha, mixup enabled if > 0. (default: 0.3)')  # 0.5 and 0.8 also tried
parser.add_argument('--cutmix', type=float, default=1.0,
help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
parser.add_argument('--mixup-prob', type=float, default=1.0,
help='Probability of performing mixup or cutmix when either/both is enabled')
parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
help='Probability of switching to cutmix when both mixup and cutmix enabled')
parser.add_argument('--mixup-mode', type=str, default='batch',
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
# Distillation parameters
    parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL',
                        help='Name of teacher model to train (default: "regnety_160")')  # teacher model for distillation; resnet50 also tried
parser.add_argument('--teacher-path', type=str, default='')
parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="")
    parser.add_argument('--distillation-alpha', default=0.7, type=float, help="")  # distillation weight (0.5 also tried)
parser.add_argument('--distillation-tau', default=1.0, type=float, help="")
# * Cosub params
parser.add_argument('--cosub', action='store_true')
# * Finetuning params
parser.add_argument('--finetune', default='', help='finetune from checkpoint')
parser.add_argument('--attn-only', action='store_true')
# Dataset parameters
    parser.add_argument('--data-path', default=r'D:/ru_file/ruru_file/Desktop/cifar-10-python', type=str,
                        help='dataset path')
#parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'],
# type=str, help='Image Net dataset path')
parser.add_argument('--data-set', default='CIFAR10', choices=['CIFAR10', 'CIFAR100', 'IMNET', 'INAT', 'INAT19'],
type=str, help='Dataset type (CIFAR10, CIFAR100, IMNET, etc)')
parser.add_argument('--inat-category', default='name',
choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
type=str, help='semantic granularity')
parser.add_argument('--output_dir', default='',
help='path where to save, empty for no saving')
parser.add_argument('--device', default='cuda',
help='device to use for training / testing')
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--resume', default='', help='resume from checkpoint')
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
help='start epoch')
parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
parser.add_argument('--eval-crop-ratio', default=0.875, type=float, help="Crop ratio for evaluation")
parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')
parser.add_argument('--num_workers', default=10, type=int)
parser.add_argument('--pin-mem', action='store_true',
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
help='')
parser.set_defaults(pin_mem=True)
# distributed training parameters
parser.add_argument('--distributed', action='store_true', default=False, help='Enabling distributed training')
parser.add_argument('--world_size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
# amp about
parser.add_argument('--if_amp', action='store_true', help='Enable NVIDIA Apex AMP training')
parser.add_argument('--no_amp', action='store_false', dest='if_amp')
parser.set_defaults(if_amp=True)
# if continue with inf
parser.add_argument('--if_continue_inf', action='store_true')
parser.add_argument('--no_continue_inf', action='store_false', dest='if_continue_inf')
parser.set_defaults(if_continue_inf=False)
# if use nan to num
parser.add_argument('--if_nan2num', action='store_true')
parser.add_argument('--no_nan2num', action='store_false', dest='if_nan2num')
parser.set_defaults(if_nan2num=False)
# if use random token position
#parser.add_argument('--if_random_cls_token_position', action='store_true')
#parser.add_argument('--no_random_cls_token_position', action='store_false', dest='if_random_cls_token_position')
#parser.set_defaults(if_random_cls_token_position=False)
# if use random token rank
#parser.add_argument('--if_random_token_rank', action='store_true')
#parser.add_argument('--no_random_token_rank', action='store_false', dest='if_random_token_rank')
#parser.set_defaults(if_random_token_rank=False)
parser.add_argument('--local-rank', default=0, type=int)
    parser.add_argument('--use-acmix', action='store_true', help='enable the ACMix module')  # new argument
    parser.add_argument('--clip-grad', type=float, default=1.0, help='gradient clipping threshold')  # new argument
    parser.add_argument('--acmix-kernel', type=int, default=5, help='ACMix convolution kernel size')  # new argument
    parser.add_argument('--d-state', type=int, default=16, help='SSM state dimension')  # new argument
    parser.add_argument('--expand', type=int, default=4, help='SSM expansion factor')  # new argument (2 and 1 also tried)
    #parser.add_argument('--input-size', type=int, default=32, help='the standard CIFAR-10 input size is 32x32')
    parser.add_argument('--early-stop', type=int, default=10, help='Early stopping patience')  # 15 also tried
return parser
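# Example invocation (a sketch only: the model name and paths are placeholders and
# must be replaced with a model registered in models_mamba and real local paths;
# every flag used here is defined in get_args_parser above):
#   python main.py --model <vim_model_name> --data-set CIFAR10 \
#       --data-path /path/to/cifar-10-python --output_dir ./out/vim_cifar10 \
#       --batch-size 128 --epochs 50 --use-acmix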
def plot_loss_curves(train_loss_history, val_loss_history, output_dir):
"""绘制训练和验证损失曲线并保存图像"""
plt.figure(figsize=(10, 6))
plt.plot(train_loss_history, label='Training Loss', marker='o')
plt.plot(val_loss_history, label='Validation Loss', marker='x')
plt.title('Training and Validation Loss Curves')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
    # save the figure to the output directory
loss_path = Path(output_dir) / 'loss_curves.png'
plt.savefig(loss_path)
plt.close()
    # log the figure to MLflow (if a run is active)
if mlflow.active_run():
mlflow.log_artifact(str(loss_path))
def main(args):
utils.init_distributed_mode(args)
print(args)
if args.distillation_type != 'none' and args.finetune and not args.eval:
raise NotImplementedError("Finetuning with distillation not yet supported")
device = torch.device(args.device)
# fix the seed for reproducibility
seed = args.seed + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
# random.seed(seed)
cudnn.benchmark = True
# log about
run_name = args.output_dir.split("/")[-1]
if args.local_rank == 0 and args.gpu == 0:
mlflow.start_run(run_name=run_name)
for key, value in vars(args).items():
mlflow.log_param(key, value)
    if args.data_set in ['CIFAR10', 'CIFAR100']:
        args.repeated_aug = False  # force repeated augmentation off for CIFAR
        args.input_size = 32       # use the native 32x32 CIFAR resolution before datasets and model are built
dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
dataset_val, _ = build_dataset(is_train=False, args=args)
    # check that the validation set reports the same number of classes as the training set
    assert args.nb_classes == dataset_val.nb_classes, \
        "Class count mismatch: train {} vs val {}".format(args.nb_classes, dataset_val.nb_classes)
if args.distributed:
num_tasks = utils.get_world_size()
global_rank = utils.get_rank()
if args.repeated_aug:
sampler_train = RASampler(
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
)
else:
sampler_train = torch.utils.data.DistributedSampler(
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
)
if args.dist_eval:
if len(dataset_val) % num_tasks != 0:
print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
'This will slightly alter validation results as extra duplicate entries are added to achieve '
'equal num of samples per-process.')
sampler_val = torch.utils.data.DistributedSampler(
dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
else:
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
else:
sampler_train = torch.utils.data.RandomSampler(dataset_train)
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
data_loader_train = torch.utils.data.DataLoader(
dataset_train, sampler=sampler_train,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_mem,
drop_last=True,
)
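    # Note: drop_last=True discards the final incomplete batch, so every training step
    # sees exactly args.batch_size samples per process, which keeps the batch-level
    # mixup/cutmix pairing below well defined.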
if args.ThreeAugment:
data_loader_train.dataset.transform = new_data_aug_generator(args)
data_loader_val = torch.utils.data.DataLoader(
dataset_val, sampler=sampler_val,
#batch_size=int(1.5 * args.batch_size * 20),
batch_size=int(args.batch_size * 2),
num_workers=args.num_workers,
pin_memory=args.pin_mem,
drop_last=False
)
mixup_fn = None
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
if mixup_active:
mixup_fn = Mixup(
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
label_smoothing=args.smoothing, num_classes=args.nb_classes)
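        # With mixup/cutmix active, the integer labels are turned into soft probability
        # vectors over args.nb_classes (with label smoothing folded in), which is why
        # SoftTargetCrossEntropy is selected as the training criterion further below.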
    if args.data_set in ['CIFAR10', 'CIFAR100']:
        args.input_size = 32  # force the 32x32 input size (already set above, kept as a safeguard)
print(f"Creating model: {args.model}")
    # collect the extra keyword arguments before calling create_model()
    model_kwargs = dict(
        img_size=args.input_size,   # make sure img_size is defined here
        use_acmix=args.use_acmix,
        acmix_kernel=args.acmix_kernel,
        d_state=args.d_state,
        expand=args.expand,
        rms_norm=False,             # keep RMSNorm disabled
        num_classes=args.nb_classes,
    )
    model = create_model(
        args.model,
        pretrained=False,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=None,
        **model_kwargs,  # pass all extra keyword arguments (img_size, ACMix and SSM settings, num_classes)
    )
    # check the classifier head and replace it if its output dimension does not match the dataset
    if isinstance(getattr(model, 'head', None), nn.Sequential) and len(model.head) > 0 \
            and hasattr(model.head[-1], 'out_features'):
        if model.head[-1].out_features != args.nb_classes:
            print(f"Adjusting model head from {model.head[-1].out_features} to {args.nb_classes} classes")
            model.head[-1] = nn.Linear(model.head[-1].in_features, args.nb_classes)
    elif isinstance(getattr(model, 'head', None), nn.Linear):
        if model.head.out_features != args.nb_classes:
            print(f"Adjusting model head from {model.head.out_features} to {args.nb_classes} classes")
            model.head = nn.Linear(model.head.in_features, args.nb_classes)
    else:
        # no standard classifier head: fall back to checking the feature dimension
        if model.num_features != args.nb_classes:
            print(f"Adjusting model output dimension from {model.num_features} to {args.nb_classes}")
            model.head = nn.Linear(model.num_features, args.nb_classes)
    model.to(device)  # move the model to the device first
    print("model_kwargs passed to create_model:", model_kwargs)
    print(f"Model device: {next(model.parameters()).device}")
    # The ACMix options (use_acmix / acmix_kernel) are already included in model_kwargs
    # above, so the model does not need to be re-created for Vim variants here.
model_without_ddp = model
    optimizer = create_optimizer(args, model_without_ddp)  # model_without_ddp must already be assigned here
    # report and verify the classifier output dimension right after model creation
    head_out = model.head[-1].out_features if isinstance(model.head, nn.Sequential) else model.head.out_features
    print(f"Model output dimension: {head_out}")
    assert head_out == args.nb_classes, \
        f"Output dimension mismatch! Model: {head_out}, Dataset: {args.nb_classes}"
    # optional: run a forward pass on dummy input once the model is on the device
#with torch.no_grad():
# test_input = torch.randn(2, 3, args.input_size, args.input_size).to(device)
# print(f"Test input device: {test_input.device}")
# output = model(test_input)
# print(f"Test output shape: {output.shape}, should be (2, {args.nb_classes})")
    # replace the original timm scheduler with a per-iteration cosine annealing schedule
    lr_scheduler = CosineAnnealingLR(
        optimizer,
        T_max=args.epochs * len(data_loader_train),
        eta_min=1e-6
    )
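    # T_max above counts optimizer steps (epochs * batches per epoch), so this schedule
    # is meant to advance once per iteration; it is handed to train_one_epoch below via
    # its lr_scheduler argument, which is assumed to step it inside the batch loop.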
if args.finetune:
if args.finetune.startswith('https'):
checkpoint = torch.hub.load_state_dict_from_url(
args.finetune, map_location='cpu', check_hash=True)
else:
checkpoint = torch.load(args.finetune, map_location='cpu')
checkpoint_model = checkpoint['model']
        model = model.to(device)  # move the model to the GPU again
state_dict = model.state_dict()
for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias']:
if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
print(f"Removing key {k} from pretrained checkpoint")
del checkpoint_model[k]
        # optionally remove every classifier-head related weight instead:
        #keys_to_remove = [k for k in checkpoint_model.keys()
        #                  if 'head' in k or 'cls_token' in k]
        #for k in keys_to_remove:
        #    print(f"Removing key {k} from pretrained checkpoint")
        #    del checkpoint_model[k]
        # or, if a 'head' key exists, delete just the classifier weights:
        #if 'head' in checkpoint_model:
        #    del checkpoint_model['head']  # delete the classifier head weights
# interpolate position embedding
pos_embed_checkpoint = checkpoint_model['pos_embed']
embedding_size = pos_embed_checkpoint.shape[-1]
num_patches = model.patch_embed.num_patches
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
# height (== width) for the checkpoint position embedding
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
# height (== width) for the new position embedding
new_size = int(num_patches ** 0.5)
# class_token and dist_token are kept unchanged
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
checkpoint_model['pos_embed'] = new_pos_embed
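        # Worked example of the interpolation above: a checkpoint trained at 224x224 with
        # 16-pixel patches stores a 14x14 grid of position tokens (196 plus the extra
        # class/dist tokens); for 32x32 inputs with the same patch size the target grid is
        # 2x2, so the 14x14 grid is resampled bicubically to 2x2 while the extra tokens
        # are copied through unchanged.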
model.load_state_dict(checkpoint_model, strict=False)
if args.attn_only:
for name_p,p in model.named_parameters():
if '.attn.' in name_p:
p.requires_grad = True
else:
p.requires_grad = False
try:
model.head.weight.requires_grad = True
model.head.bias.requires_grad = True
except:
model.fc.weight.requires_grad = True
model.fc.bias.requires_grad = True
try:
model.pos_embed.requires_grad = True
except:
print('no position encoding')
try:
for p in model.patch_embed.parameters():
p.requires_grad = False
except:
print('no patch embed')
model.to(device)
model_ema = None
if args.model_ema:
# Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
model_ema = ModelEma(
model,
decay=args.model_ema_decay,
device='cpu' if args.model_ema_force_cpu else '',
resume='')
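        # ModelEma keeps a shadow copy of the weights updated as an exponential moving
        # average with --model-ema-decay; train_one_epoch is expected to call
        # model_ema.update(model) after each optimizer step, and the EMA weights are
        # evaluated separately in the --eval branch below.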
model_without_ddp = model
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
else:
model = torch.nn.DataParallel(model)
model = model.to(device)
model_without_ddp = model.module
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('number of params:', n_parameters)
if not args.unscale_lr:
linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
args.lr = linear_scaled_lr
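        # Linear scaling rule: effective lr = args.lr * (batch size per GPU * world size) / 512,
        # e.g. 1e-4 * 128 / 512 = 2.5e-5 on a single GPU. Note that create_optimizer above
        # was already called with the unscaled args.lr.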
    # extra adjustment when ACMix is enabled
    if args.use_acmix:
        for param_group in optimizer.param_groups:
            param_group['lr'] *= 0.8  # lower the base learning rate by 20%
# amp about
#amp_autocast = suppress
#loss_scaler = "none"
if args.if_amp:
amp_autocast = torch.cuda.amp.autocast
loss_scaler = NativeScaler()
else:
amp_autocast = suppress
loss_scaler = "none"
    if args.data_set in ['CIFAR10', 'CIFAR100']:
        args.warmup_epochs = max(3, args.warmup_epochs)  # warm up for at least 3 epochs on CIFAR
#lr_scheduler, _ = create_scheduler(args, optimizer)
criterion = LabelSmoothingCrossEntropy()
if mixup_active:
# smoothing is handled with mixup label transform
criterion = SoftTargetCrossEntropy()
    elif args.smoothing:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)  # timm's implementation only takes the smoothing value
else:
criterion = torch.nn.CrossEntropyLoss()
if args.bce_loss:
criterion = torch.nn.BCEWithLogitsLoss()
teacher_model = None
if args.distillation_type != 'none':
assert args.teacher_path, 'need to specify teacher-path when using distillation'
print(f"Creating teacher model: {args.teacher_model}")
teacher_model = create_model(
args.teacher_model,
pretrained=False,
num_classes=args.nb_classes,
global_pool='avg',
)
if args.teacher_path.startswith('https'):
checkpoint = torch.hub.load_state_dict_from_url(
args.teacher_path, map_location='cpu', check_hash=True)
else:
checkpoint = torch.load(args.teacher_path, map_location='cpu')
teacher_model.load_state_dict(checkpoint['model'])
teacher_model.to(device)
teacher_model.eval()
# wrap the criterion in our custom DistillationLoss, which
# just dispatches to the original criterion if args.distillation_type is 'none'
criterion = DistillationLoss(
criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau
)
output_dir = Path(args.output_dir)
if args.resume:
if args.resume.startswith('https'):
checkpoint = torch.hub.load_state_dict_from_url(
args.resume, map_location='cpu', check_hash=True)
else:
checkpoint = torch.load(args.resume, map_location='cpu')
model_without_ddp.load_state_dict(checkpoint['model'])
# add ema load
if args.model_ema:
utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
args.start_epoch = checkpoint['epoch'] + 1
if args.model_ema:
utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
if 'scaler' in checkpoint and args.if_amp: # change loss_scaler if not amp
loss_scaler.load_state_dict(checkpoint['scaler'])
elif 'scaler' in checkpoint and not args.if_amp:
loss_scaler = 'none'
#lr_scheduler.step(args.start_epoch)
if args.eval:
test_stats = evaluate(data_loader_val, model, device, amp_autocast)
print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
test_stats = evaluate(data_loader_val, model_ema.ema, device, amp_autocast)
print(f"Accuracy of the ema network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
return
# log about
if args.local_rank == 0 and args.gpu == 0:
mlflow.log_param("n_parameters", n_parameters)
print(f"Start training for {args.epochs} epochs")
start_time = time.time()
max_accuracy = 0.0
    # loss history and early-stopping bookkeeping, initialised before the epoch loop
    train_loss_history = []
    val_loss_history = []
    best_val_acc = 0.0
    no_improve_epochs = 0
    patience = 10  # early stopping: halt after this many epochs without improvement (15 also tried)
for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)
        train_stats = train_one_epoch(
            model, criterion, data_loader_train,
            optimizer, device, epoch,
            loss_scaler=loss_scaler,
            amp_autocast=amp_autocast,   # mixed-precision context
            max_norm=args.clip_grad,     # gradient clipping threshold
            model_ema=model_ema,
            mixup_fn=mixup_fn,
            set_training_mode=True,      # force training mode
            args=args,
            lr_scheduler=lr_scheduler
        )
        # gradient-norm monitoring
        if any(p.grad is not None for p in model.parameters()):
            grad_norm = torch.norm(torch.stack([
                torch.norm(p.grad)
                for p in model.parameters()
                if p.grad is not None  # only parameters that actually received gradients
            ]))
            print(f"Gradient Norm: {grad_norm.item():.4f}")
        else:
            print("Warning: All gradients are None!")
test_stats = evaluate(data_loader_val, model, device, amp_autocast)
        # sanity check: the classifier head must still match the dataset's class count
        head_ref = model_without_ddp.head[-1] if isinstance(model_without_ddp.head, nn.Sequential) else model_without_ddp.head
        assert head_ref.out_features == args.nb_classes, \
            f"Output dim mismatch! Model: {head_ref.out_features}, Dataset: {args.nb_classes}"
        # report the model output dimension every 5 epochs
        if epoch % 5 == 0:
            print(f"Epoch {epoch} - Model output dim: {head_ref.out_features}")
        # reduce the mixup strength for the second half of training
        if epoch > args.epochs // 2:
            mixup_fn = Mixup(
                mixup_alpha=0.5, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
                prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
                label_smoothing=args.smoothing, num_classes=args.nb_classes)  # keep num_classes so one-hot targets match the dataset
        # record the loss values for this epoch
        train_loss = train_stats['loss']
        val_loss = test_stats['loss']
        train_loss_history.append(train_loss)
        val_loss_history.append(val_loss)
        # MLflow logging (the run was already started and the parameters logged above)
        if args.local_rank == 0 and args.gpu == 0:
            mlflow.log_metric("train_loss", train_loss, step=epoch)
            mlflow.log_metric("val_loss", val_loss, step=epoch)
            # artifacts such as the loss-curve image are logged in plot_loss_curves()
        # update the best validation accuracy and the early-stopping counter
        current_val_acc = test_stats["acc1"]
        if current_val_acc > best_val_acc:
            best_val_acc = current_val_acc
            no_improve_epochs = 0
            # the best checkpoint itself is written further below when max_accuracy improves
        else:
            no_improve_epochs += 1
        if no_improve_epochs >= patience:
            print(f"Early stopping at epoch {epoch}, best accuracy: {best_val_acc:.2f}%")
            break
        # overfitting diagnostic
        if val_loss != 0:  # avoid division by zero
            overfit_index = train_loss / val_loss
            if overfit_index < 0.7:  # a healthy range is roughly 0.8-1.2
                print(f"Warning: high overfitting risk! index={overfit_index:.2f}")
        # alternative early-stopping criterion:
        #if no_improve_epochs >= 15:  # allow more time to converge
        #    print("Early stopping!")
        #    break
        #lr_scheduler.step()
if args.output_dir:
checkpoint_paths = [output_dir / 'checkpoint.pth']
for checkpoint_path in checkpoint_paths:
utils.save_on_master({
'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),  # save the scheduler state
'epoch': epoch,
'model_ema': get_state_dict(model_ema),
'scaler': loss_scaler.state_dict() if loss_scaler != 'none' else loss_scaler,
'args': args,
}, checkpoint_path)
test_stats = evaluate(data_loader_val, model, device, amp_autocast)
print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
if max_accuracy < test_stats["acc1"]:
max_accuracy = test_stats["acc1"]
if args.output_dir:
checkpoint_paths = [output_dir / 'best_checkpoint.pth']
for checkpoint_path in checkpoint_paths:
utils.save_on_master({
'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict(),
'epoch': epoch,
'model_ema': get_state_dict(model_ema),
'scaler': loss_scaler.state_dict() if loss_scaler != 'none' else loss_scaler,
'args': args,
}, checkpoint_path)
print(f'Max accuracy: {max_accuracy:.2f}%')
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
**{f'test_{k}': v for k, v in test_stats.items()},
'epoch': epoch,
'n_parameters': n_parameters}
# log about
if args.local_rank == 0 and args.gpu == 0:
for key, value in log_stats.items():
mlflow.log_metric(key, value, log_stats['epoch'])
if args.output_dir and utils.is_main_process():
with (output_dir / "log.txt").open("a") as f:
f.write(json.dumps(log_stats) + "\n")
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
    # plot and save the loss curves (main process only)
    if utils.is_main_process() and args.output_dir:
        plot_loss_curves(train_loss_history, val_loss_history, args.output_dir)
if __name__ == '__main__':
parser = argparse.ArgumentParser('DeiT training and evaluation script', parents=[get_args_parser()])
args = parser.parse_args()
args.gpu = None
if args.output_dir:
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
main(args)
How can the issues above be fixed?