# Python:使用 RegNetY 模型提取图片特征

import torch
import timm
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F
from scipy.spatial.distance import cosine
import os
# Work around the "libiomp5md.dll already initialized" crash on Windows when
# two OpenMP runtimes get loaded (e.g. by torch and scipy/MKL). This permits
# duplicate runtimes instead of aborting — presumably safe here, but it is a
# documented Intel workaround, not a fix; TODO confirm it is still needed.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# 1. Build a pretrained RegNetY backbone for feature extraction.
#    Other timm variants work too, e.g. 'regnety_800mf' or 'regnety_1_6gf'.
model = timm.create_model('regnety_1280', pretrained=True)

# Replace the classification head with an identity so the forward pass
# returns backbone features instead of class logits.
model.head = torch.nn.Identity()

# Evaluation mode: disables dropout and freezes batch-norm statistics.
model.eval()

# 2. 图像预处理函数
# 2. Preprocessing pipeline: resize -> tensor -> ImageNet normalization.
_preprocess_steps = [
    transforms.Resize((224, 224)),          # fixed input resolution
    transforms.ToTensor(),                  # HWC uint8 -> CHW float in [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),  # ImageNet statistics
]
transform = transforms.Compose(_preprocess_steps)

# 3. 加载和预处理图像
def preprocess_image(image_path):
    """Load an image from disk and return a normalized (1, 3, 224, 224) tensor.

    Args:
        image_path: Path to an image file readable by PIL.

    Returns:
        A float tensor with a leading batch dimension, ready for the model.
    """
    # Use a context manager: Image.open keeps the file handle open lazily,
    # which leaks descriptors when called in a loop. convert('RGB') forces
    # the pixel data to load, so closing the file afterwards is safe.
    with Image.open(image_path) as img:
        rgb = img.convert('RGB')
    return transform(rgb).unsqueeze(0)  # add batch dimension

# 4. 提取特征
def extract_features(image_path):
    """Run the backbone on one image and return its feature tensor.

    The batch dimension added during preprocessing is squeezed away again,
    so the caller receives the bare feature tensor for a single image.
    """
    batch = preprocess_image(image_path)
    with torch.no_grad():  # inference only — no autograd bookkeeping
        output = model(batch)
    return output.squeeze()

# 5. 计算余弦相似度
def cosine_similarity(features1, features2):
    # 将特征展平为 1D 向量
    features1 = features1.flatten()
    features2 = features2.flatten()
    # 计算余弦相似度
    return 1 - cosine(features1.numpy(), features2.numpy())  # 余弦相似度越接近1表示越相似


# 6. 计算欧氏距离
def euclidean_distance(features1, features2):
    """Return the L2 (Euclidean) distance between two feature tensors."""
    diff = features1 - features2
    return diff.norm().item()

# 7. Load the two images to compare and extract their backbone features.
#    The paths are relative to the current working directory.
image1_path = '1.jpg'
image2_path = '2.jpg'

features1 = extract_features(image1_path)
features2 = extract_features(image2_path)

# 8. Compare the two feature tensors: cosine similarity (closer to 1 means
#    more similar) and Euclidean distance (smaller means more similar).
similarity_cosine = cosine_similarity(features1, features2)
similarity_euclidean = euclidean_distance(features1, features2)

# 9. Report both metrics.
print(f"Cosine Similarity: {similarity_cosine:.4f}")
print(f"Euclidean Distance: {similarity_euclidean:.4f}")

# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
"""DeiT/Vim training and evaluation script (reconstructed from a mangled paste).

Key fixes relative to the pasted original:
  * de-duplicated imports; repaired the truncated ``--aa`` help string;
  * ``args.input_size`` is forced to 32 for CIFAR *before* building the
    datasets, so data pipeline and model agree;
  * the model is created exactly once (the paste created it twice for Vim
    models, discarding the adjusted classification head);
  * Vim-only kwargs (ACMix/SSM) are passed only to ``vim*`` models;
  * learning-rate scaling happens *before* ``create_optimizer`` (scaling it
    afterwards had no effect);
  * ``LabelSmoothingCrossEntropy`` no longer receives a ``num_classes``
    kwarg that timm's implementation does not accept;
  * MLflow starts one run per process (the paste re-invoked ``start_run``
    every epoch, which raises once a run is active), guarded by
    ``utils.is_main_process()`` instead of the never-true
    ``args.gpu == 0`` check (``args.gpu`` was ``None``);
  * the per-epoch cosub assert referencing an undefined ``outputs`` name
    was removed; the mid-training ``Mixup`` rebuild keeps all required
    parameters (the paste dropped ``num_classes``);
  * best-checkpoint saving writes a real state dict (the paste saved a
    literal ``{...}`` Ellipsis); early stopping honours ``--early-stop``.
"""
import argparse
import datetime
import json
import time
from contextlib import suppress
from pathlib import Path

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn

import matplotlib
matplotlib.use('Agg')  # headless backend; must be set before importing pyplot
import matplotlib.pyplot as plt

from timm.data import Mixup
from timm.models import create_model
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.scheduler import create_scheduler  # kept: used if the timm scheduler path is re-enabled
from timm.optim import create_optimizer
from timm.utils import NativeScaler, get_state_dict, ModelEma
from torch.optim.lr_scheduler import CosineAnnealingLR

from datasets import build_dataset
from engine import train_one_epoch, evaluate
from losses import DistillationLoss
from samplers import RASampler
from augment import new_data_aug_generator
import models_mamba  # noqa: F401  registers the Vim models with timm
import utils

import mlflow


def get_args_parser():
    """Build the CLI argument parser for the DeiT/Vim training script."""
    parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False)
    parser.add_argument('--batch-size', default=128, type=int,
                        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus')
    parser.add_argument('--epochs', default=50, type=int, help='Total training epochs')
    parser.add_argument('--bce-loss', action='store_true')
    parser.add_argument('--unscale-lr', action='store_true')

    # Model parameters
    parser.add_argument('--model', default='deit_base_patch16_224', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument('--input-size', default=224, type=int, help='images input size')
    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                        help='Dropout rate (default: 0.)')
    parser.add_argument('--drop-path', type=float, default=0.2, metavar='PCT',
                        help='Drop path rate (small datasets overfit easily)')
    parser.add_argument('--model-ema', action='store_true')
    parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
    parser.set_defaults(model_ema=True)
    parser.add_argument('--model-ema-decay', type=float, default=0.99999, help='')
    parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='')

    # Optimizer parameters
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw"')
    parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                        help='Optimizer Epsilon (default: 1e-8)')
    parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use opt default)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight-decay', type=float, default=0.15,
                        help='weight decay (default: 0.05)')

    # Learning rate schedule parameters
    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                        help='LR scheduler (default: "cosine"')
    parser.add_argument('--lr', type=float, default=0.0001, metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                        help='learning rate noise on/off epoch percentages')
    parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                        help='learning rate noise limit percent (default: 0.67)')
    parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                        help='learning rate noise std-dev (default: 1.0)')
    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
                        help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
    parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
                        help='epoch interval to decay LR')
    parser.add_argument('--warmup-epochs', type=int, default=20, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
    parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                        help='patience epochs for Plateau LR scheduler (default: 10')
    parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                        help='LR decay rate (default: 0.1)')

    # Augmentation parameters
    parser.add_argument('--color-jitter', type=float, default=0.3, metavar='PCT',
                        help='Color jitter factor (default: 0.3)')
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original". (default: rand-m9-mstd0.5-inc1)')
    parser.add_argument('--smoothing', type=float, default=0.2,
                        help='Label smoothing (default: 0.1)')
    parser.add_argument('--train-interpolation', type=str, default='bicubic',
                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
    parser.add_argument('--repeated-aug', action='store_true')
    parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
    parser.set_defaults(repeated_aug=True)
    parser.add_argument('--train-mode', action='store_true')
    parser.add_argument('--no-train-mode', action='store_false', dest='train_mode')
    parser.set_defaults(train_mode=True)
    parser.add_argument('--ThreeAugment', action='store_true')  # 3augment
    parser.add_argument('--src', action='store_true')  # simple random crop

    # Random-erase params
    parser.add_argument('--reprob', type=float, default=0.0, metavar='PCT',
                        help='Random erase prob (recommended off for small datasets)')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')

    # Mixup params
    parser.add_argument('--mixup', type=float, default=0.3,
                        help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
    parser.add_argument('--cutmix', type=float, default=1.0,
                        help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
    parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
    parser.add_argument('--mixup-prob', type=float, default=1.0,
                        help='Probability of performing mixup or cutmix when either/both is enabled')
    parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
    parser.add_argument('--mixup-mode', type=str, default='batch',
                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')

    # Distillation parameters
    parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL',
                        help='Name of teacher model to train (default: "regnety_160"')
    parser.add_argument('--teacher-path', type=str, default='')
    parser.add_argument('--distillation-type', default='none',
                        choices=['none', 'soft', 'hard'], type=str, help="")
    parser.add_argument('--distillation-alpha', default=0.7, type=float, help="")
    parser.add_argument('--distillation-tau', default=1.0, type=float, help="")

    # Cosub / finetuning params
    parser.add_argument('--cosub', action='store_true')
    parser.add_argument('--finetune', default='', help='finetune from checkpoint')
    parser.add_argument('--attn-only', action='store_true')

    # Dataset parameters
    parser.add_argument('--data-path', default=r'D:/ru_file/ruru_file/Desktop/cifar-10-python',
                        type=str, help='D:/ru_file/ruru_file/Desktop/cifar-10-python')
    parser.add_argument('--data-set', default='CIFAR10',
                        choices=['CIFAR10', 'CIFAR100', 'IMNET', 'INAT', 'INAT19'],
                        type=str, help='Dataset type (CIFAR10, CIFAR100, IMNET, etc)')
    parser.add_argument('--inat-category', default='name',
                        choices=['kingdom', 'phylum', 'class', 'order',
                                 'supercategory', 'family', 'genus', 'name'],
                        type=str, help='semantic granularity')
    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='start epoch')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--eval-crop-ratio', default=0.875, type=float,
                        help='Crop ratio for evaluation')
    parser.add_argument('--dist-eval', action='store_true', default=False,
                        help='Enabling distributed evaluation')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin-mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem', help='')
    parser.set_defaults(pin_mem=True)

    # Distributed training parameters
    parser.add_argument('--distributed', action='store_true', default=False,
                        help='Enabling distributed training')
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')

    # AMP
    parser.add_argument('--if_amp', action='store_true', help='Enable NVIDIA Apex AMP training')
    parser.add_argument('--no_amp', action='store_false', dest='if_amp')
    parser.set_defaults(if_amp=True)

    # Continue-on-inf / nan-to-num switches
    parser.add_argument('--if_continue_inf', action='store_true')
    parser.add_argument('--no_continue_inf', action='store_false', dest='if_continue_inf')
    parser.set_defaults(if_continue_inf=False)
    parser.add_argument('--if_nan2num', action='store_true')
    parser.add_argument('--no_nan2num', action='store_false', dest='if_nan2num')
    parser.set_defaults(if_nan2num=False)

    parser.add_argument('--local-rank', default=0, type=int)
    parser.add_argument('--use-acmix', action='store_true', help='enable the ACMix module')
    parser.add_argument('--clip-grad', type=float, default=1.0, help='gradient clipping threshold')
    parser.add_argument('--acmix-kernel', type=int, default=5, help='ACMix convolution kernel size')
    parser.add_argument('--d-state', type=int, default=16, help='SSM state dimension')
    parser.add_argument('--expand', type=int, default=4, help='SSM expansion factor')
    parser.add_argument('--early-stop', type=int, default=10, help='Early stopping patience')
    return parser


def plot_loss_curves(train_loss_history, val_loss_history, output_dir):
    """Plot training/validation loss curves, save the figure, log to MLflow."""
    plt.figure(figsize=(10, 6))
    plt.plot(train_loss_history, label='Training Loss', marker='o')
    plt.plot(val_loss_history, label='Validation Loss', marker='x')
    plt.title('Training and Validation Loss Curves')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    loss_path = Path(output_dir) / 'loss_curves.png'
    plt.savefig(loss_path)
    plt.close()
    if mlflow.active_run():
        mlflow.log_artifact(str(loss_path))


def main(args):
    utils.init_distributed_mode(args)
    print(args)

    if args.distillation_type != 'none' and args.finetune and not args.eval:
        raise NotImplementedError("Finetuning with distillation not yet supported")

    device = torch.device(args.device)

    # Fix the seed for reproducibility (offset by rank so workers differ).
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    # MLflow: one run per training job, main process only.
    if utils.is_main_process():
        run_name = args.output_dir.split("/")[-1]
        mlflow.start_run(run_name=run_name)
        for key, value in vars(args).items():
            mlflow.log_param(key, value)

    # CIFAR: disable repeated augmentation and force the 32x32 input size
    # BEFORE building the datasets so data pipeline and model agree.
    if args.data_set in ['CIFAR10', 'CIFAR100']:
        args.repeated_aug = False
        args.input_size = 32

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, nb_classes_val = build_dataset(is_train=False, args=args)
    assert args.nb_classes == nb_classes_val, \
        "Class-count mismatch: train {} vs val {}".format(args.nb_classes, nb_classes_val)

    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        if args.repeated_aug:
            sampler_train = RASampler(
                dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True)
        else:
            sampler_train = torch.utils.data.DistributedSampler(
                dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True)
        if args.dist_eval:
            if len(dataset_val) % num_tasks != 0:
                print('Warning: Enabling distributed evaluation with an eval dataset not '
                      'divisible by process number. This will slightly alter validation '
                      'results as extra duplicate entries are added to achieve equal '
                      'num of samples per-process.')
            sampler_val = torch.utils.data.DistributedSampler(
                dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
        else:
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )
    if args.ThreeAugment:
        data_loader_train.dataset.transform = new_data_aug_generator(args)

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, sampler=sampler_val,
        batch_size=int(args.batch_size * 2),
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False,
    )

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_fn = Mixup(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix,
            cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob,
            mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.nb_classes)

    print(f"Creating model: {args.model}")
    model_kwargs = dict(
        img_size=args.input_size,
        num_classes=args.nb_classes,
    )
    # ACMix/SSM options are only understood by the Vim models; passing them
    # to plain DeiT models would raise a TypeError inside create_model.
    if args.model.startswith('vim'):
        model_kwargs.update(
            use_acmix=args.use_acmix,
            acmix_kernel=args.acmix_kernel,
            d_state=args.d_state,
            expand=args.expand,
            rms_norm=False,  # keep RMSNorm disabled
        )

    # Create the model exactly once with the merged kwargs.
    model = create_model(
        args.model,
        pretrained=False,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=None,
        **model_kwargs,
    )

    # Make sure the classification head matches the dataset's class count.
    if hasattr(model, 'head') and isinstance(model.head, nn.Sequential) \
            and len(model.head) > 0 and hasattr(model.head[-1], 'out_features'):
        if model.head[-1].out_features != args.nb_classes:
            print(f"Adjusting model head from {model.head[-1].out_features} "
                  f"to {args.nb_classes} classes")
            model.head[-1] = nn.Linear(model.head[-1].in_features, args.nb_classes)
    elif hasattr(model, 'num_features') and getattr(model, 'head', None) is not None \
            and getattr(model.head, 'out_features', args.nb_classes) != args.nb_classes:
        print(f"Adjusting model output dimension from {model.head.out_features} "
              f"to {args.nb_classes}")
        model.head = nn.Linear(model.num_features, args.nb_classes)

    model.to(device)
    print(f"Model device: {next(model.parameters()).device}")

    if args.finetune:
        if args.finetune.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.finetune, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.finetune, map_location='cpu')
        checkpoint_model = checkpoint['model']
        state_dict = model.state_dict()
        # Drop classifier weights whose shape no longer matches.
        for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias']:
            if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]

        # Interpolate the position embedding to the new patch grid.
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged.
        extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
        # Only the position tokens are interpolated.
        pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
        pos_tokens = pos_tokens.reshape(
            -1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
        pos_tokens = torch.nn.functional.interpolate(
            pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
        checkpoint_model['pos_embed'] = torch.cat((extra_tokens, pos_tokens), dim=1)

        model.load_state_dict(checkpoint_model, strict=False)

    if args.attn_only:
        # Train only attention (and head / pos_embed); freeze everything else.
        for name_p, p in model.named_parameters():
            p.requires_grad = '.attn.' in name_p
        try:
            model.head.weight.requires_grad = True
            model.head.bias.requires_grad = True
        except AttributeError:
            model.fc.weight.requires_grad = True
            model.fc.bias.requires_grad = True
        try:
            model.pos_embed.requires_grad = True
        except AttributeError:
            print('no position encoding')
        try:
            for p in model.patch_embed.parameters():
                p.requires_grad = False
        except AttributeError:
            print('no patch embed')

    model.to(device)

    model_ema = None
    if args.model_ema:
        # Create the EMA copy after cuda() but before SyncBN / DDP wrapping.
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            resume='')

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    else:
        model = torch.nn.DataParallel(model)
        model = model.to(device)
        model_without_ddp = model.module

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)
    if utils.is_main_process():
        mlflow.log_param("n_parameters", n_parameters)

    # Linear LR scaling must happen BEFORE the optimizer is created, otherwise
    # the optimizer keeps the unscaled base LR.
    if not args.unscale_lr:
        args.lr = args.lr * args.batch_size * utils.get_world_size() / 512.0

    optimizer = create_optimizer(args, model_without_ddp)
    if args.use_acmix:
        # ACMix runs train more stably with a slightly lower base LR.
        for param_group in optimizer.param_groups:
            param_group['lr'] *= 0.8

    # AMP setup: autocast context + gradient scaler, or no-op fallbacks.
    if args.if_amp:
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
    else:
        amp_autocast = suppress
        loss_scaler = "none"

    if args.data_set in ['CIFAR10', 'CIFAR100']:
        args.warmup_epochs = max(3, args.warmup_epochs)  # at least 3 warmup epochs

    # Per-iteration cosine schedule (replaces the timm per-epoch scheduler).
    lr_scheduler = CosineAnnealingLR(
        optimizer,
        T_max=args.epochs * len(data_loader_train),
        eta_min=1e-6,
    )

    if mixup_active:
        # Smoothing is handled by the mixup label transform.
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing:
        # NOTE: timm's LabelSmoothingCrossEntropy takes only `smoothing`.
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()
    if args.bce_loss:
        criterion = torch.nn.BCEWithLogitsLoss()

    teacher_model = None
    if args.distillation_type != 'none':
        assert args.teacher_path, 'need to specify teacher-path when using distillation'
        print(f"Creating teacher model: {args.teacher_model}")
        teacher_model = create_model(
            args.teacher_model,
            pretrained=False,
            num_classes=args.nb_classes,
            global_pool='avg',
        )
        if args.teacher_path.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.teacher_path, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.teacher_path, map_location='cpu')
        teacher_model.load_state_dict(checkpoint['model'])
        teacher_model.to(device)
        teacher_model.eval()

    # DistillationLoss dispatches to the base criterion when type == 'none'.
    criterion = DistillationLoss(
        criterion, teacher_model, args.distillation_type,
        args.distillation_alpha, args.distillation_tau
    )

    output_dir = Path(args.output_dir)
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if args.model_ema:
            utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
        if not args.eval and 'optimizer' in checkpoint \
                and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if 'scaler' in checkpoint and args.if_amp:
                loss_scaler.load_state_dict(checkpoint['scaler'])
            elif 'scaler' in checkpoint and not args.if_amp:
                loss_scaler = 'none'

    if args.eval:
        test_stats = evaluate(data_loader_val, model, device, amp_autocast)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: "
              f"{test_stats['acc1']:.1f}%")
        if model_ema is not None:
            test_stats = evaluate(data_loader_val, model_ema.ema, device, amp_autocast)
            print(f"Accuracy of the ema network on the {len(dataset_val)} test images: "
                  f"{test_stats['acc1']:.1f}%")
        return

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    best_val_acc = 0.0
    no_improve_epochs = 0
    train_loss_history = []
    val_loss_history = []

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)

        train_stats = train_one_epoch(
            model, criterion, data_loader_train,
            optimizer, device, epoch,
            loss_scaler=loss_scaler,
            amp_autocast=amp_autocast,      # mixed-precision context
            max_norm=args.clip_grad,        # gradient clipping threshold
            model_ema=model_ema,
            mixup_fn=mixup_fn,
            set_training_mode=True,
            args=args,
            lr_scheduler=lr_scheduler,      # stepped per iteration in engine
        )

        test_stats = evaluate(data_loader_val, model, device, amp_autocast)

        # Reduce mixup strength for the second half of training; rebuild with
        # all parameters (num_classes is required for one-hot targets).
        if mixup_fn is not None and epoch == args.epochs // 2:
            mixup_fn = Mixup(
                mixup_alpha=0.5, cutmix_alpha=args.cutmix,
                cutmix_minmax=args.cutmix_minmax,
                prob=args.mixup_prob, switch_prob=args.mixup_switch_prob,
                mode=args.mixup_mode,
                label_smoothing=args.smoothing, num_classes=args.nb_classes)

        train_loss = train_stats['loss']
        val_loss = test_stats['loss']
        train_loss_history.append(train_loss)
        val_loss_history.append(val_loss)

        # Overfitting diagnostic: train/val loss ratio well below 1 means the
        # model fits train far better than val.
        if val_loss != 0:
            overfit_index = train_loss / val_loss
            if overfit_index < 0.7:
                print(f"Warning: high overfitting risk! index={overfit_index:.2f}")

        acc = test_stats["acc1"]
        max_accuracy = max(max_accuracy, acc)
        improved = acc > best_val_acc
        if improved:
            best_val_acc = acc
            no_improve_epochs = 0
        else:
            no_improve_epochs += 1

        print(f"Accuracy of the network on the {len(dataset_val)} test images: "
              f"{acc:.1f}% (max {max_accuracy:.2f}%)")

        if args.output_dir:
            checkpoint = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'model_ema': get_state_dict(model_ema) if model_ema is not None else None,
                'scaler': loss_scaler.state_dict() if loss_scaler != 'none' else loss_scaler,
                'args': args,
            }
            utils.save_on_master(checkpoint, output_dir / 'checkpoint.pth')
            if improved:
                utils.save_on_master(checkpoint, output_dir / 'best_checkpoint.pth')

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     **{f'test_{k}': v for k, v in test_stats.items()},
                     'epoch': epoch,
                     'n_parameters': n_parameters}
        if utils.is_main_process():
            mlflow.log_metric("train_loss", train_loss, step=epoch)
            mlflow.log_metric("val_loss", val_loss, step=epoch)
            for key, value in log_stats.items():
                if isinstance(value, (int, float)):
                    mlflow.log_metric(key, value, step=epoch)
        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

        if no_improve_epochs >= args.early_stop:
            print(f"Early stopping at epoch {epoch}, best accuracy: {best_val_acc:.2f}%")
            break

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
    print(f'Max accuracy: {max_accuracy:.2f}%')

    if utils.is_main_process() and args.output_dir:
        plot_loss_curves(train_loss_history, val_loss_history, args.output_dir)


if __name__ == '__main__':
    parser = argparse.ArgumentParser('DeiT training and evaluation script',
                                     parents=[get_args_parser()])
    args = parser.parse_args()
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)
06-10
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额 3.43 元,前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝。规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值