Pytorch入门实战第四周:猴痘病识别

本文记录了使用PyTorch构建CNN网络进行猴痘病图片分类的过程。包括前期准备,如设置GPU、导入数据;构建CNN网络并推导结构;训练模型,设置超参数、编写训练和测试函数;结果可视化,绘制Loss和Accuracy图、指定图片预测;保存并加载模型。还尝试多种方法提高精度,如修改优化器、增加dropout层、调整学习率。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

目录

前言

一、前期准备

1.1 设置GPU

1.2 导入数据

二、构建CNN网络

2.1 网络结构推导

2.2 代码实现

三.训练模型

3.1 设置超参数:损失函数、学习率、优化器

3.2 编写训练函数

3.3 编写测试函数

3.4 正式训练

四、结果可视化

4.1 Loss和Accuracy图

4.2 指定图片进行预测

五、保存并加载模型

六、提高精度尝试

总结


前言

  • 🍨 本文为[🔗365天深度学习训练营](https://mp.weixin.qq.com/s/0dvHCaOoFnW8SCp3JpzKxg) 中的学习记录博客
  • 🍖 原作者:[K同学啊](https://mtyjkh.blog.csdn.net/)

说在前面:

  • 本周学习目标:在训练过程中保存效果最好的模型参数,加载最佳模型参数并识别本地的一张图片,调整网络结构使得测试集accuracy达到88%;在此基础上试着调整模型参数并观察测试集的准确率变化,尝试设置动态学习率,测试集accuracy到达90%
  • 学习重点:指定图片预测、保存并加载模型
  • 我的环境:Python3.8、Pycharm2020、torch1.12.1+cu113(数据来源——[K同学啊](https://mtyjkh.blog.csdn.net/))

一、前期准备

1.1 设置GPU

#一、前期准备
'''
1.1设置GPU
'''
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from torchvision import transforms, datasets
import os, PIL, pathlib, random
import torch.nn.functional as F
import matplotlib.pyplot as plt
from PIL import Image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

结果输出:cuda

1.2 导入数据

1.数据下载并设置相应文件目录以便图片读取,将数据下载放在对应代码下目录下新建的4-data目录

  • 数据集介绍:包含了两类图片,为猴痘病和不是猴痘病的图片,分为两个文件夹

导入数据的步骤

 1)使用函数将字符串类型的文件夹路径转换为pathlib.Path对象

 2)使用glob方法获取data_dir路径下的所有文件路径,并以列表的形式存储在data_paths中

 3)利用split()函数对data_paths中的每个文件路径执行分割操作,获取各个文件所属的类别名称并储存在classNames中

数据文件导入代码如下:

'''
1.2 导入数据
'''
data_dir = './4-data/'
data_dir = pathlib.Path(data_dir)           #使用pathlib.Path()函数将字符串类型的文件夹路径转换为pathlib.Path对象
data_paths = list(data_dir.glob('*'))       #使用glob()方法获取data_dir路径下的所有文件路径,并以列表形式存储在data_paths中
classNames = [str(path).split("\\")[1] for path in data_paths]
print(classNames)

打印结果为:['Monkeypox', 'Others']

2.图片数据处理

torchvision.transforms是pytorch中的图像预处理包。一般用Compose把多个步骤整合到一起 
1)Resize:将输入图片resize成统一尺寸
2)ToTensor:将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
3)Normalize:标准化处理-->转换为标准正态分布,使模型更容易收敛;其中mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的
ImageFolder类来创建一个数据集对象,total_data将是一个包含所有图像数据的数据集对象,可以用于训练神经网络模型

代码如下:

total_datadir = './4-data/'
train_transforms = transforms.Compose([transforms.Resize([224, 224]),   #将输入图片resize成统一尺寸
                                       transforms.ToTensor(),           #将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
                                       transforms.Normalize             #标准化处理-->转换为标准正态分布,使模型更容易收敛
                                       (mean=[0.485, 0.456, 0.406],   # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的
                                       std=[0.229, 0.224, 0.225])
                                       ])
total_data = datasets.ImageFolder(total_datadir, transform = train_transforms)
print(total_data)

print(total_data.class_to_idx)     #total_data.class_to_idx是一个存储了数据集类别和对于索引的字典

打印结果如下:

Dataset ImageFolder
    Number of datapoints: 2142
    Root location: ./4-data/
    StandardTransform
Transform: Compose(
               Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=warn)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )
 

{'Monkeypox': 0, 'Others': 1}

3.划分数据集

train_size:表示训练集大小,为总体数据长度的80%

test_size:表示测试集大小,是总体数据长度减去训练集大小

torch.utils.data.random_split():将总体数据total_data按照指定的比例大小随机划分为训练集和测试集,并将划分的结果分别赋值给train_dataset和test_dataset两个变量

代码如下:

'''
1.3 划分数据集
'''
train_size = int(0.8 * len(total_data))
test_size = len(total_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])
print(train_dataset, test_dataset)
print(train_size, test_size)

打印输出结

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值