PyTorch学习(9):实战

本文通过ShuffleNet和EfficientNet两个模型实例,详细介绍了如何使用PyTorch框架进行深度学习模型的搭建、数据处理、模型设计、训练配置及模型保存等流程。采用flowers17数据集进行了实验,并对比了不同训练条件下的效果。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

PyTorch学习(9):实战

  • Pytorch官方文档: https://pytorch-cn.readthedocs.io/zh/latest/
  • Pytorch学习文档: https://github.com/tensor-yu/PyTorch_Tutorial
  • Pytorch模型库: https://github.com/pytorch/vision/tree/master/torchvision/models


前言

Pytorch模型库已经提供了相当多可用的模型了。
本文从两个比较好的开源项目入手:ShuffleNetEfficientNet
数据集采用flowers17:
  链接: https://pan.baidu.com/s/1UxcW2AFE2OjWfgsLIHtCYA
  提取码: brxm


1.ShuffleNet系列

Github链接: https://github.com/megvii-model/ShuffleNet-Series
megvii开源的ShuffleNet系列模型
在这里插入图片描述
以ShuffleNetV2为例: 学习深度学习模型搭建的整体流程。

(1)数据处理

    train_dataset = datasets.ImageFolder(
        args.train_dir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            transforms.RandomHorizontalFlip(0.5),
            ToBGRTensor(),
        ])
    )
    train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=1, pin_memory=use_gpu) #锁页内存
    train_dataprovider = DataIterator(train_loader)

(2)模型设计

    architecture = [0, 0, 3, 1, 1, 1, 0, 0, 2, 0, 2, 1, 1, 0, 2, 0, 2, 1, 3, 2]
    model = ShuffleNetV2(model_size=args.model_size)

(3)训练配置

   optimizer = torch.optim.SGD(get_parameters(model),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
criterion_smooth = CrossEntropyLabelSmooth(1000, 0.1)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                    lambda step : (1.0-step/args.total_iters) if step <= args.total_iters else 0, last_epoch=-1)

(4)训练过程

    output = model(data)
    loss = loss_function(output, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step() # 对参数更新
    prec1, prec5 = accuracy(output, target, topk=(1, 5))

(5)模型保存

def save_checkpoint(state, iters, tag=''):
	if not os.path.exists("./models"):
		os.makedirs("./models")
	filename = os.path.join("./models/{}checkpoint-{:06}.pth.tar".format(tag, iters))
	torch.save(state, filename)

至此,基于PyTorch框架的ShuffleNetV2模型的训练框架搭建完成。具体细节参考官方code


2.EfficientNet系列

Github链接: https://github.com/lukemelas/EfficientNet-PyTorch
非官方开源的EfficientNet系列模型,目前仅有V1版本,坐等开源V2版本。
在这里插入图片描述
作者的训练代码在examples文件夹内,整理流程同ShuffleNet系列一致。
以Efficientnet-b0模型为基础,训练flowers17数据集:

(1)数据集划分

Flowers17数据由17种不同的花构成,本文以每种花的0.8作为训练集,其余0.2作为测试集。

(2)模型训练

训练结果:

无数据增强+数据增强+预训练模型
59.56%84.92%96.32%

大致训了一下,采用预训练模型进行迁移训练的效果最好

总结

  至此,基于PyTorch框架的入门已基本完成。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Peanut_范

您的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值