项目介绍:图像分类项目的最小可用骨架--代码细节讲解

讲解代码如下:

'''创建数据集的类'''
import torch
from torch.utils.data import Dataset,DataLoader     #用于处理数据集的
import numpy as np
from PIL import Image       #
from torchvision import transforms      #对数据进行处理工具  转换

data_transforms = {     #字典
    'train':
        transforms.Compose([        #  对图片做预处理的。组合,
        transforms.Resize([256,256]),   #数据进行改变大小[256,256]
        transforms.ToTensor(),          #数据转换为tensor,默认把通道维度放在前面
    ]),
    'valid':
        transforms.Compose([
        transforms.Resize([256,256]),
        transforms.ToTensor(),

    ]),
}#数组增强,
#Dataset是用来处理数据的。   包含1张图片的数据(tensor),包含标签结果,并且能通过索引得到数据
class food_dataset(Dataset):       #   food_dataset是自己创建的类名称,可以改为你需要的名称
    def __init__(self, file_path,transform=None): #类的初始化,解析数据文件txt
        self.file_path = file_path
        self.imgs = []#存储图片的路径
        self.labels = []#存储图片的标签结果
        self.transform = transform
        with open(self.file_path) as f:#是把train.txt文件中图片的路径保存在 self.imgs,train.txt文件中标签保存在 self.labels
            samples = [x.strip().split(' ') for x in f.readlines()]
            for img_path, label in samples:
                self.imgs.append(img_path)  #图像的路径
                self.labels.append(label)   #标签,还不是tensor
#初始化:把图片目录加载到self,
    def __len__(self):  #类实例化对象后,可以使用len函数测量对象的个数   ls=[12,3,4,4]  len(training_data)
        return len(self.imgs)
    #training_data[1]
    def __getitem__(self, idx): #关键,可通过索引的形式获取每一个图片数据及标签
        image = Image.open(self.imgs[idx])   #读取到图片数据,还不是tensor,BGR
        if self.transform:                   #将pil图像数据转换为tensor
            image = self.transform(image)    #图像处理为256*256,转换为tenor

        label = self.labels[idx]        #label还不是tensor
        label = torch.from_numpy(np.array(label,dtype = np.int64))  #label也转换为tensor,
        return image, label
#training_data包含了本次需要训练的全部数据集?
training_data = food_dataset(file_path = 'train.txt',transform = data_transforms['train'])  
test_data = food_dataset(file_path = 'test.txt',transform = data_transforms['valid'])

#training_data需要具备索引的功能,还要确保数据是tensor
train_dataloader = DataLoader(training_data, batch_size=64,shuffle=True)#64张图片为一个包,
test_dataloader = DataLoader(test_data, batch_size=64,shuffle=True)
  • 数据索引功能的必要性

    • 在机器学习训练中,data loader需通过索引批量获取数据(如每次取64个样本)。若数据对象不具备索引功能,则无法被data loader有效管理。因此,需为数据对象添加

      __getitem__方法,使其支持索引操作。

          def __getitem__(self, idx): #关键,可通过索引的形式获取每一个图片数据及标签
              image = Image.open(self.imgs[idx])   #读取到图片数据,还不是tensor,BGR
              if self.transform:                   #将pil图像数据转换为tensor
                  image = self.transform(image)    #图像处理为256*256,转换为tenor
      
              label = self.labels[idx]        #label还不是tensor
              label = torch.from_numpy(np.array(label,dtype = np.int64))  #label也转换为tensor,
              return image, label
  • 图像预处理流程

    • 裁剪图片:统一将图片尺寸调整为256×256,避免因输入尺寸不一致导致全连接层维度冲突。
    • 数据转换:将图像数据转换为PyTorch的tensor格式,便于神经网络处理。
          def __getitem__(self, idx): #关键,可通过索引的形式获取每一个图片数据及标签
              image = Image.open(self.imgs[idx])   #读取到图片数据,还不是tensor,BGR
              if self.transform:                   #将pil图像数据转换为tensor
                  image = self.transform(image)    #图像处理为256*256,转换为tenor
      
              label = self.labels[idx]        #label还不是tensor
              label = torch.from_numpy(np.array(label,dtype = np.int64))  #label也转换为tensor,
              return image, label
      
  • Pillow库的使用场景

    • Pillow库与OpenCV类似,但功能更简单,适合基础图像操作(如裁剪、格式转换)。因其轻量高效,常用于无需复杂算法的预处理任务。
          def __init__(self, file_path,transform=None): #类的初始化,解析数据文件txt
              self.file_path = file_path
              self.imgs = []#存储图片的路径
              self.labels = []#存储图片的标签结果
              self.transform = transform
              with open(self.file_path) as f:#是把train.txt文件中图片的路径保存在 self.imgs,train.txt文件中标签保存在 self.labels
                  samples = [x.strip().split(' ') for x in f.readlines()]
                  for img_path, label in samples:
                      self.imgs.append(img_path)  #图像的路径
                      self.labels.append(label)   #标签,还不是tensor
  • 自定义数据集类实现

    • 继承自父类DataSet,核心方法__getitem__根据索引返回预处理后的数据(tensor格式的图片及标签)。
    • 初始化阶段读取txt文件,存储图片路径和标签,为后续索引提供依据。
      #training_data包含了本次需要训练的全部数据集?
      training_data = food_dataset(file_path = 'train.txt',transform = data_transforms['train'])  #
      test_data = food_dataset(file_path = 'test.txt',transform = data_transforms['valid'])
  • 标签处理逻辑

    • 原始标签为字符串类型,需先转为NumPy数组,再转为整数类型的tensor,确保与模型输入匹配。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

张子夜 iiii

感谢支持

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值