prtorch.数据的导入与导出

本文介绍如何使用PyTorch加载Fashion-MNIST数据集,并展示如何自定义数据集类以便更好地处理图像数据。此外,文章还讲解了如何使用DataLoader批量加载数据,并对数据进行必要的转换。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

我们在整体代码中展示了怎么导入Fashion-MNIST这个dataset,这份数据包含六万个训练数据和一万个测试数据,每份数据都有一个28*28的灰度图,你可以把他们分成十个类。

我们导入FashionMNIST Dataset需要使用到以下的参数

root = 'data' 
# 数据存储的路径

train=True 
# 是训练数据还是测试数据

download=True 
# 从互联网上下载数据 否则数据在根目录不存在

transform = ToTensor 
# 和target_transform一样,指定训练数据的转换器,将数据格式转换为tensor
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt


training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

从training_data中拿到一组数据

labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}

'''
使用matplotlib来可视化训练的一些样本
'''
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    # 产生一个随机数
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    # 拿到一个随机的照片和照片的标签
    img, label = training_data[sample_idx]
    # 添加照片
    figure.add_subplot(rows, cols, i)
    # 添加照片的名称
    plt.title(labels_map[label])
    # 关闭坐标轴
    plt.axis("off")
    # 在相应位置绘图
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

创建我们自己的数据导入器,一个DataSet的类必须包含三个函数 __init__``__len__``__getitem__

'''
为文件创建自定义数据集 DataSet
'''

import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    # 这个函数在类实例化时初始化一次
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        # 使用pd读取csv文件
        self.img_labels = pd.read_csv(annotations_file)
        # img的文件目录
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    # 返回img的数量,这个函数会被python自带的len()调用
    def __len__(self):
        return len(self.img_labels)

    # 返回第idx个数据
    def __getitem__(self, idx):
        # 获取img的路径
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        # 使用read_img将img转换为tensor
        image = read_image(img_path)
        # 这个是数据的标签 也就是分类的实际结果
        label = self.img_labels.iloc[idx, 1]
        # 如果我们定义了transform 则使用transform来处理图像
        if self.transform:
            image = self.transform(image)
        # 如果我们定义了target_transform 则使用target_transform来处理图像的标签
        if self.target_transform:
            label = self.target_transform(label)
        # 返回图像和标签
        return image, label

使用我们自定义的DataSet

'''
DataSet每次返回一组数据(一个图像和一个对应的标签),但是在训练模型时我们希望一次导入多个数据,
这里我们解释一下为什么一次我们经常看到导入数据量为64,32,100,而不是全部的数据呢,其实只要们的计算机内存和显卡内存足够大,是可以的,但是往往我们的计算机内存和显卡内存都是有限的,我们一次导入过多的数据就是out of memery,这样还好,但是你想想你要是训练了好几天out of memery 那不直接背过气去,所以这个值我们要设置的合理,既不要太大,也不要太小,太小的话会导致数据导入太慢
并且每次导入数据都进行重新洗牌,来防止过拟合现象,并且使用多线程来加速导入的过程
'''
from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

我们还发现,现在有两个参数我们是没有自己设置的,就是transform和target_transform

torchvision.Transforms模块提供了几个开箱即用的常用转换

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    # 将整数转换为一个one-hot编码的tensor
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)

for i in range(10):
    print(torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(i), value=1))
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

LyaJpunov

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值