PyTorch的基础知识及使用

本文围绕PyTorch展开,介绍了其常用函数。详细讲解了数据加载,包括Dataset获取数据及label、Dataloader为网络提供数据形式。还介绍了TensorBoard绘图、Transforms对图片变换、torchvision中数据集和模型的使用,以及在Pycharm中的操作方法。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1 前言

我们首先要知道PyTorch有两个常用函数:

(1)dir() //打开
(2)help() //帮助怎么使用,说明书

例子:dir(torch) ;
help(torch.cuda.is_available)

2 PyTorch加载数据

2.1 Dataset

Dataset: 提供一种方式去获取数据及其label(只是告诉程序数据集在一个什么位置;包括说数据集当中的第一个数据、第二个数据是什么,比如我们给它一个索引0,它对应的是哪一个数据;以及它可以告诉我们这个数据集有多少数据)。
问题(主要实现的功能):

  1. 如何获取每一个数据及其label?
  2. 告诉我们总共有多少个数据(这个的作用是:神经网络要对一个数据迭代多次,只有当我们知道我们总共有多少个数据要去训练的时候,它才知道我们要训练多少次才能把整个数据集给迭代完,然后进行下一次的迭代)。

使用方法:
read_data.py

from torch.utils.data import Dataset
#import cv2
from PIL import Image
import os

#首先创建一个类,取名为MyData,这个类需要继承Dataset这个类
class MyData(Dataset):
  #首先我们需要进行初始化
  #初始化就是我们需要根据这一个类去创建一个特例实例的时候,它就要运行的一个函数。
  #这个函数一般是为这个整个class提供一个全局变量,主要是为后面的一些函数提供一些他们所需要的量,所以这个方法可以留在后面写,你可以先尝试写一后面的函数。
     def __init__(self,root_dir,label_dir):
         #这里为什么用self?
         #我们知道一个函数当中的变量是不能传递给另外一个函数当中的变量,而这个self能把self指定的一个变量给后面的函数使用,它就相当于指定了一个类当中的全局变量。
         self.root_dir = root_dir
         self.label_dir = label_dir
         #os.path.join表示括号里面的两个路径进行相加,然后得到了所需(即对应)图片的地址
         self.path = os.path.join(self.root_dir,self.label_dir)
         #然后我们想要获取所有图片地址的一个列表List
         self.img_path = os.listdir(self.path)
         #接下来我们要获取其中的每一个图片,将在getitem函数中进行书写

  #下面这个函数的作用是如何获取(读取)到我们数据集里面的数据。比如我们的数据集是图片,那我们应该如何获取(读取)到里面的任意一张图片呢?
  #我们一般使用getitem这个函数。
  #这个函数里面的第二个变量一般默认是item,我们一般修改为idx,因为我们一般用到的是idx。
  #idx作为一个编号吧。
     #def __getitem__(self,item):
     def __getitem__(self,idx):
  #一般如何去获取(读取)里面的任意一张图片呢?
  #比如说:我们使用opencv的话,需要首先提供这个图片的地址
  #这时我们需要在前面导入opencv:import cv2
  #但是如果你不想使用或安装opencv,我们可以使用from PIL import Image这种方法来获取(读取)里面的任意一张图片。同样需要你在前面导入这句话:from PIL import Image。

#####以下就是我们使用from PIL import Image这种方法在Pycharm控制台输代码的步骤:
  #比如我们想获取的照片为图一这个项目中的数据集all中的cloudy1.jpg这张图片。
  #此时这张图片的路径我们可以通过快捷键“Ctrl+Shift+c"来直接进行获取并复制。
  #然后粘贴到”img_path=“的后面就可以了。
  #但是需要注意:这个路径在Windows系统里写的时候要把单斜杠改成双斜杠,来表示转义。
  #即我们直接进行获取并复制的路径格式为:D:\project2023\weather-recognition-main\all\cloudy\cloudy1.jpg。但正确应该修改为下面这种方式:
        #img_path = "D:\\project2023\\weather-recognition-main\\all\\cloudy\\cloudy1.jpg"
  #然后我们进行图片的获取(读取)
        #img = Image.open(img_path)
  #最后我们将这个图片进行显示
        #img.show()
####以上我们就成功实现了所需对应位置的图片,具体在Pycharm控制台输入这部分代码的步骤如图二所示。

  #所以说我们直接在Pycharm平时写代码的地方实现这个方法时,首先需要获取到img_path,即这个图片的地址。
  #我们想通过idx这个索引来获取地址。
  #我们需要首先获取所有图片地址的一个List(列表),然后通过idx这个索引中每个索引对应的不同照片就可以获取到相应的地址。
  #所以我们就需要在初始化函数中对其进行一些量的设置。
  #如果我们想要获取所有图片的地址,我们就需要使用到os,就是Python中关于系统的一个库。
  #这时我们就需要在前面导入os库:import os

  #获取所有图片地址的一个List(列表),也就是需要获取数据集这个文件夹下面的所有图片。
  #这时我们就需要在初始化函数中定义一个root_dir变量来获取数据集这个文件夹的路径。
  #因为这个数据集的label就是图片所在上一级文件夹的名称,所以我们还需要在初始化函数中定义一个label_dir变量来获取这个图片所对应的标签。
  #写完初始化函数中self.img_path = os.listdir(self.path)这句话以后我们获取其中的每一个图片就要进行下面的操作:
         img_name = self.img_path[idx]
         img_item_path = os.path.join(self.root_dir,self.label_dir,img_name)
  #获取(读取)这个图片
         img = Image.open(img_item_path)
  #对这个图片的label进行获取
         label = self.label_dir
  #最后返回这个图片和标签label
         return img,label


#下面这个函数的作用是获取这个数据集有多长。
     def __len__(self):
         return len(self.img_path)

#下面我们创建实例来看一下,比如说获取图一项目数据集中cloudy和rainy下的所有图片
root_dir = "all"
cloudy_label_dir = "cloudy"
rainy_label_dir = "rainy"
cloudy_dataset = MyData(root_dir,cloudy_label_dir)
rainy_dataset = MyData(root_dir,cloudy_label_dir)
##比如我们要获取cloudy和rainy下的第一个图片,我们可以这样写:
##img,label = cloudy_dataset[0]
##img,label = rainy_dataset[0]
##最后我们可以将获取到的cloudy下的第一个图片进行一个显示:
##img.show()
        
#我们可以注意到上面是获取数据集中单个标签里的所有图片,下面是我们获取整一个数据集的方法,也就是对于项目一中整个数据all的获取。
#假设项目一all这个数据集只有cloudy和rainy这两个标签,也就是这两个文件夹,我们可以进行如下操作:
all_dataset = cloudy_dataset + rainy_dataset
##如果我们想要看一下cloudy数据集的长度,我们可以这样写:
##len(cloudy_dataset)
##如果我们想要看一下整个数据集的长度,我们可以这样写:
##len(all_dataset)
##比如我们要获取cloudy下的第一个图片,我们可以这样写:
##img,label = all_dataset[0]
##最后我们可以将获取到的cloudy下的第一个图片进行一个显示:
##img.show()

图一:
在这里插入图片描述
图二:
在这里插入图片描述

2.2 Dataloader

Dataloader: 为后面的网络提供不同的数据形式(loader就相当于一个加载器,就是说把我们的数据加载到神经网络当中去),它所做的事情就是每次从dataset中去取数据,每次取多少?怎么取?这个过程是由Dataloader当中的参数进行设置的。
使用方法:
dataloader.py

#假如想要在项目中使用torchvision下的CIFAR(CIFAR10)这个数据集,则必须导入torchvision
import torchvision
from torch.utils.data import DataLoader
#准备的测试数据集
test_data = torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=dataset_transforms,download=True)
#dataset这个参数代表:我们前面自定义的dataset(dataset告诉程序数据集在一个什么位置;包括说数据集当中的第一个数据、第二个数据是什么,比如我们给它一个索引0,它对应的是哪一个数据;以及它可以告诉我们这个数据集有多少数据),这里指的就是前面准备的测试数据集test_data。
#batch_size:代表每次从数据集中抓取多少个数据,即batch_size=4的话代表每次从dataset中随机取4张图片进行一个打包。
#shuffle:代表抓取完一轮(注意这里不是一次,而是整整一轮)数据以后,数据顺序是否打乱的意思。值为True的话就要打乱,值为False的话则代表不要打乱,每次抓取的时候数据顺序都是一样的。我们一般喜欢设置为True。
#num_workers:代表我们加载数据时采用的是单进程还是多进程进行加载的。默认为0,0的话代表利用主进程进行了一个加载。但是有的时候num_workers在Windows底下会有一些问题,比如说:num_workers>0的时候在Windows底下会出现错误,这时如果报错,你可以考虑一下将num_workers设置为0,看会不会解决这些问题。
#drop_last:它主要是决定比如数据集里一共有100个数据,你每三个抓取一次,100/3会余下一个数据,对于这一个数据你要不要留下它?值为True的话就说明我们不要它了,只要前面的99张牌即可;值为False的话就说明我们要它,还是要取出这100个数据的。
#test_loader = Dataloader(dataset=test_data,batch_size=4,shuffle=True,num_workers=0,drop_last=False)
test_loader = Dataloader(dataset=test_data,batch_size=64,shuffle=True,num_workers=0,drop_last=False)

#测试数据集中第一张图片及target
img,target = test_data[0]
#查看图片的大小,如果运行结果为torch.Size([3,32,32])代表这个图片是RGB三通道,大小是32*32。
print(img.shape)
#查看图片的target
print(target)

writer = SummaryWrite("dataloader")
##对于shuffle的理解:就是说下面的这个(for data in test_loader)for循环其实就是实现了一轮抓取,当我们在下一轮进行抓取的时候,比如说再对这个数据进行同样一次的抓取时,第2次抓取出来的数据和第1次抓取出来的数据是否是一样的。值为True的话就要打乱,两次抓取结果顺序是不一样的;值为False的话则代表不要打乱,两次抓取结果顺序是一样的。
##epoch代表轮数
##for epoch in range(2):
    ##step = 0
    ##for data in test_loader
        ##imgs,targets = data
        #print(imgs.shape)
        #这里的运行结果如果为torch.Size([4,3,32,32])代表从dataset中随机取了4张图片进行了一个打包(因为此时batch_size=4),所以第一个数字为4,这四个图片是RGB三通道,大小是32*32。
        #print(targets)
        #这里的运行结果如果为tensor([2,3,6,8])代表同时将我们上面4张图片的target进行了一个打包。
        #writer.add_images("test_data",imgs,step)
        ##writer.add_images("Epoch:{}".format(epoch),imgs,step)
        #step = step + 1

#实际的训练过程中我们习惯采用下面这种形式将imgs输送到我们的神经网络,直接作为一个神经网络的输送
step = 0
for data in test_loader
        imgs,targets = data
        #print(imgs.shape)
        #这里的运行结果如果为torch.Size([4,3,32,32])代表从dataset中随机取了4张图片进行了一个打包(因为此时batch_size=4),所以第一个数字为4,这四个图片是RGB三通道,大小是32*32。
        #print(targets)
        #这里的运行结果如果为tensor([2,3,6,8])代表同时将我们上面4张图片的target进行了一个打包。
        writer.add_images("test_data",imgs,step)
        step = step + 1
writer.close()

3 TensorBoard

TensorBoard: 用来绘制图像
训练中可以看到losses是如何变化的,通过losses我们可以知道训练过程是否按照我们预想的变化或者训练过程当中是不是属于一个正常的状态。我们也可以从相应的losses当中去看一下我们应该选择一下什么样的模型,比如说训练到多少步的时候,这个模型比较符合我们的预期。
使用方法:
test_Tensorboard.py

from torch.utils.ensorboard import SummaryWriter
import numpy as np
from PIL import Image

#创建一个SummaryWriter类的实例
#我们需要将它对应的一些事件文件存储到一个叫logs的文件夹底下
writer = SummaryWrite("logs")
#在writer这个实例当中我们主要会使用到add_image()和add_scalar()这两个方法。

    # add_image这个函数一般对应的参数形式长下面这样:
    # def add_imager(self,tag,img_tensor,global_step=None,walltime=None,dataformats='CHW'):
    #tag:代表所绘制这个图表/图像的标题,为string类型,所以要加双引号
    #img_tensor:代表图像的数据类型,这里能用的为torch.Tensor、numpy.array、string、blobname
    #global_step:代表训练步数
    #walltime=None在这个一般不常用,这里就不做讲解了
    #dataformats='CHW':设置三通道格式,C通道,H高度,W宽度

#为了解决PIL读取的图片的数据类型不符合img_tensor的要求这个问题,我们利用Opencv读取图片,获得numpy型图片数据
#这时除了安装opencv,还需要在前面进行导入:import numpy as np、from PIL import Image
#这里以读取上面图一中cloudy下的第一个图片进行一个显示:
image_path = "all/cloudy/cloudy1.jpg"
img_PIL = Image.open(image_path)
img_array = np.array(img_PIL)
#输出图像的数据类型
print(type(img_array))
#查看此时三通道格式
print(img_array.shape)
#如果此时三通道格式不为CHW的顺序了,那么我们就需要将其修改为对应格式,如HWC。

writer.add_image("test",img_array,1,dataformats='HWC')
#如果此时我们读取上面图一中rainy下的第一个图片进行一个显示
#我们直接把路径修改为image_path = "all/rainy/rainy1.jpg"即可
#而且如果想和上一张图绘制在一起,我们可以把训练步数变成2,即writer.add_image("test",img_array,1,dataformats='HWC'),这样运行以后就可以拖动上面的滑块来进行显示了
#如果想单独进行显示,我们直接修改一下图像标题即可完成,如:writer.add_image("test1",img_array,1,dataformats='HWC')

#绘制一个y = 2x的图像
for i in range(100):
    # add_scalar这个函数一般对应的参数形式长下面这样:
    # def add_scalar(self,tag,scalar_value,global_step=None,walltime=None):
    #tag:代表所绘制这个图表/图像的标题,为string类型,所以要加双引号
    #scalar_value:代表一个要保存的数值,也就是y轴
    #global_step:代表训练步长,也就是X轴
    #walltime=None在这个一般不常用,这里就不做讲解了
    writer.add_scalar("y=2x",2*i,i)
#然后进行运行以后你会发现项目中多了一个logs文件夹,他就是用来记录我们的事件文件的
#然后在logs文件夹下面一般就会出现一个文件名类似于下面这个形式的文件:events.out.tfevents.1665706856.DESKTOP-54MP3OH.10652.0

writer.close()

1. 想要查询SummaryWriter类如何使用该怎么办?
答:在Pycharm中查询一个类如何使用的快捷键是按住Ctrl键,把鼠标移动到这个要查询的类,此时这个类会变为蓝色,我们进行点击一下,就会看到对它的相关介绍。
2. 如何打开我们的事件文件呢(也就是logs文件夹下面出现的类似于下面这个形式的文件名的文件:events.out.tfevents.1665706856.DESKTOP-54MP3OH.10652.0)?
答:在此项目的Pycharm终端中输入tensorboard --logdir=logs,然后就会出现一个类似于TenSorboard 2.1.0 at http://localhost:6006/的句子,点击一下http://localhost:6006/这个网址进去就好,也就是这个主机所对应的6006窗口。但是有的时候一台服务器上有好几个人在训练,它默认打开的是这个6006端口的话(也就是大家的都指向同一个端口的话)就很麻烦,所以我们有的时候可以去指定一下这个端口,为了防止和别人冲突,比如说就可以 在此项目的Pycharm终端中输入tensorboard --logdir=logs --port=6007,此时就会出现6007的窗口。
注意:
(1)logdir=事件文件所在文件夹名。
(2)你可以通过port这个参数对端口号进行设置,去避免和其他人的端口一样。

4 Transforms

Transforms: 主要对图片进行一些变换
Transforms的结构:
如下面的图三所示,Transforms的学习其实是对transforms.py这个文件来进行学习。transforms.py相当于一个工具箱,里面有totensor、resize这些类(也就是定义的class),也就是相当于工具箱里面的工具。我们对transforms的使用一般就是拿一些特定格式的图片,经过工具箱里面的工具进行处理,最后输出一个我们想要的一个图片的结果。
图三:
在这里插入图片描述
使用方法:
test_Transforms.py

from torchvision import transforms

#如果不引入这个库,后面的Image.open()将会报错
from PIL import Image
from torch.utils.ensorboard import SummaryWriter
#Transforms在python当中的用法 --》tensor数据类型
#通过transforms.ToTensor去解决两个问题
#第一个问题就是:transforms在python中该如何使用?
#第二个问题是:为什么我们需要tensor数据类型(也就是相较于普通的数据类型有什么区别)?
#ToTensor这个类的作用就是将"PIL Image"或者"numpy.ndarray"这种数据类型的图片转换成"tensor"数据类型。

#记着python里面自带的Image.open()这个函数的括号里写的是图片的路径
#下面是同一个照片的绝对路径和相对路径举例(以上面图一项目中的照片为例):
#绝对路径举例:D:\project2023\weather-recognition-main\all\cloudy\cloudy1.jpg
#相对路径举例:all/cloudy/cloudy1.jpg

#这里为什么选择相对路径呢?是因为绝对路径中的\会被当成一个转义符
img_path= "all/cloudy/cloudy1.jpg"
#可以看到如果写成下面这种格式,因为转义原因就不能当成一个完整的地址了
#img_path_abs = "D:\project2023\weather-recognition-main\all\cloudy\cloudy1.jpg"
img = Image.open(img_path)
#此时可以查看一下图片的相关信息
#print(img)

writer = SummaryWrite("logs")

#第一个问题就是:transforms在python中该如何使用?
tensor_trans = transforms.ToTensor()
#下面这里可以按住Ctrl+p来看一下括号里需要什么参数
tensor_img = tensor_trans(img)
#此时可以查看一下图片的相关信息
print(tensor_img)
#此时解决了第一个问题:transforms在python中该如何使用?
#我们从transforms中选择了一个class对其进行创建,然后根据创建的这个工具看我们具体需要什么东西,这里是img,然后就可以返回出我们的一个结果。这里我们就可以将图三完善为下面的图四了,也就是详细的对于transforms的使用。

#因为tensor数据类型包装了反向神经网络所需要的一些理论基础的一些参数,如:梯度、我们使用的设备是什么等,所以我们需要tensor数据类型。
#所以说我们为什么会经常使用到ToTensor这个类呢,是因为我们在神经网络中肯定要把它转换成tensor型对其进行一个训练,所以我们百分之百会用到transforms。

#同时我们将Tensorboard利用进去,就需要在前面导入一下了:from torch.utils.ensorboard import SummaryWriter

writer.add_image("Tensor_img",tensor_img)

writer.close()

图四:在这里插入图片描述
常见的transforms:
为了更好的使用它,我们需要关注三个点,也就是下面图五中函数的输入、输出和作用。因为我们的图片会有不同的格式,如:"PIL"就是用Image.open()打开的一种方式等。
图五:在这里插入图片描述
使用方法:
test_UsefulTransforms.py

from PIL import Image
from torchvision import transforms
from torch.utils.ensorboard import SummaryWriter

writer = SummaryWrite("logs")
img = Image.open("all/cloudy/cloudy1.jpg")
print(img)

#ToTensor的使用:
trans_totensor = transforms.ToTensor()
img_tensor = trans_totensor(img)
writer.add_image("ToTensor",img_tensor)

#Normalize的使用:
#Normalize这个类的作用就是将"tensor"这种数据类型的图片进行归一化。

#因为我们的图片是rgb的,rgb有三层(信道)。
#这里我们先对这个图片的第一个层(信道),第一行,第一列的像素进行一个输出
print(img_tensor[0][0][0])

#下面这里可以按住Ctrl+p来看一下括号里需要什么参数,我们看到需要输入均值和标准差
#因为我们的图片是rgb三层的,所以我们要加三个均值和三个标准差
#这里的0.5是人为设定的
trans_norm = transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
img_norm = trans_norm(img_tensor)

#上面Normalize进行归一化的公式是input[channel]=(input[channel]-mean[channel])/std[channel]
#mean[channel]代表这个信道设置的均值
#std[channel]代表这个信道设置的标准差
#以上面进行的例子代入此归一化公式中可以得到(input-0.5)/0.5=2*input-1
#如果input像素的值为[0,1],则归一化后result的值为[-1,1]

#然后我们对这个图片进行归一化后的第一个层(信道),第一行,第一列的像素进行一个输出(此步骤是为了和前面归一化前的输出结果做一个对比,加深对归一化公式的理解)
print(img_norm[0][0][0])

#将最后归一化的一个图像结果进行一个输出,这里可以看到图像的色调发生了变化
writer.add_image("Normalize",img_norm)

#Resize的使用:
#Resize这个类的作用就是将"PIL Image"这种数据类型的图片尺寸进行修改
#把img这个照片的尺寸进行一个打印输出
print(img.size)
#将图片尺寸设置为512*512
trans_resize = transforms.Resize((512,512))
#把PIL数据类型的img 经过 resize 变成了 PIL数据类型的img_resize
img_resize = trans_resize(img)
#PIL数据类型的img_resize 经过 totensor 变成了 tensor数据类型的img_resize
img_resize = trans_totensor(img_resize)
#将最后进行尺寸变换的图像结果进行一个输出,这里可以看到图像的尺寸发生了相应的变化
writer.add_image("Resize",img_resize,0)
#可以查看出Resize的输出类型,即返回值
#print(type(img_resize))
print(img_resize)

#Compose --resize的第二种使用方法
#Resize这个类下面Compose这个用法的作用就是:如果输出的尺寸是一个int类型,就是给它一个数的话,则会对它进行一个等比缩放。也就是不改变高和宽的一个比例,只单纯改变最小的那个边和最长边它们之间的一个大小关系。

#这种方法我们输入的是一个序列,即Compose()中的参数需要是一个列表
#Python中,列表的表示形式为[数据1,数据2,...]
#在Compose中,数据需要是transforms类型,所以得到Compose([transforms参数1,transforms参数2,...])
trans_resize_2 = transforms.Resize(512)
#这个Compose的用法一定要关注它输入输出的类型:
#把PIL数据类型的img 经过 resize 变成了 PIL数据类型的img
#再把PIL数据类型的img 经过 totensor 变成了 tensor数据类型的img
#我们要注意[trans_resize_2,trans_totensor]前面这个参数输出的数据类型要能提供给后面的这个参数使用,否则就会报错。
trans_compose = transforms.Compose([trans_resize_2,trans_totensor])
img_resize_2 = trans_compose(img)
writer.add_image("Resize",img_resize_2,1)

#RandomCrop的使用
#RandomCrop这个类的作用就是将"PIL Image"这种数据类型的图片进行随机裁剪。
#trans_random = transforms.RandomCrop((500,1000))
trans_random = transforms.RandomCrop(512)
#首先我们进行一个随机裁剪,随后转换成一个tensor数据类型
trans_compose_2 = transforms.Compose([trans_random,trans_totensor])
#随便裁剪10个
for i in range(10):
    #这里的参数用的是img,是因为RandomCrop的输入数据类型要求是"PIL Image"
    img_crop = trans_compose_2(img)
    writer.add_image("RandomCrop",img_crop,i)
    
writer.close()

使用方法的总结:
(1)我们需要关注这个类输入和输出的数据类型,多看官方文档。
(2)关注方法需要什么参数。
(3)不知道返回值的时候可以用 print(type()、print()、debug三种方法得到。
Python中_call_的用法:
CallTest.py

class Person:
    #这个__下划线表示的是一些内置函数
    def __call__(self,name):
        print("__call__"+" Hello " + name)
    
    def hello(self,name):
        print(" Hello " + name)

person = Person()
person("zhangsan")
person.hello("lisi")

5 torchvision

torchvision中的数据集使用:
如图六所示,在PyTorch的官网中我们可以看到Docs下面对PyTorch分了不同的块,比如:torchaudio是处理语音的、torchtext是处理文本的、torchvision是处理图像的等。
我们点开torchvision,里面torchvision datasets下面就是PyTorch给我们提供的一个相关数据集的API文档,只要写代码的时候指定这些相应的数据集,给他设置些参数,它就能自己去下载使用这些标准的数据集。
我们点开torchvision,里面torchvision.models下面会提供一些常见的神经网络,比如说它有用于一些分类的模型、一些分割的模型、目标检测的模型、视频分类的模型等,这些神经网络有的已经与训练好了。
图六:在这里插入图片描述
使用方法:
dataset_transforms.py

#假如想要在项目中使用torchvision下的CIFAR(CIFAR10)这个数据集
import torchvision
#root就是说:这个数据集应该存放在什么位置,我们在这里使用一个相对路径,就直接将它"./dataset"。它就会在这个项目下直接创建一个dataset的文件夹将数据集存放进去。
#train的值如果为True的话就说明创建的这个数据集为一个训练集,为False的话就说明创建的这个数据集是一个测试集。
#transform:如果我们想对数据集当中的所有数据进行一个什么样的变化呢,我们把transform写在这个括号里即可。
#target_transform:它是对target进行一个transform。
#download:如果设置为True的话就会给我们自动下载这个数据集,如果设置为False的话就不会给我们自动下载这个数据集
#下载CIFAR(CIFAR10)这个数据集当训练集
train_set = torchvision.datasets.CIFAR10(root="./dataset",train=True,download=True)
#下载CIFAR(CIFAR10)这个数据集当测试集
test_set = torchvision.datasets.CIFAR10(root="./dataset",train=False,download=True

#查看测试集中的第一个图片
print(test_set[0])
#classes对应的是数据集里面的类别,也就是标签
print(test_set.classes)

img,target = test_set[0]
print(img)
#target就是classes标签里面的第几个,比如target为3,对应的那个标签是cat
print(target)
print(test_set.classes[target])
#将这个图片进行一个展示
img.show()

前面仅仅讲解了torchvision中dataset的部分使用,接下来需要与transforms进行联动:
因为我们这里原始的图片是"PIL Image"数据类型,如果说我们需要对PyTorch进行一个使用时,我们需要把图片的数据类型转换成"tensor",这时就需要使用到transforms了。
对上面的dataset_transforms.py进行以下补充:

#假如想要在项目中使用torchvision下的CIFAR(CIFAR10)这个数据集
import torchvision
from torch.utils.ensorboard import SummaryWriter
#与transforms的联动
dataset_transforms = torchvision.transforms.Compse([
    #将图片的数据类型转换成了tensor 
    torchvision.transforms.ToTensor()
    #torchvision.transforms.ToTensor(),
    #接下来对数据进行裁剪啥的也行,我们这里暂时只做了将数据类型转换成tensor 这个操作
    ])
#root就是说:这个数据集应该存放在什么位置,我们在这里使用一个相对路径,就直接将它"./dataset"。它就会在这个项目下直接创建一个dataset的文件夹将数据集存放进去。
#train的值如果为True的话就说明创建的这个数据集为一个训练集,为False的话就说明创建的这个数据集是一个测试集。
#transform:如果我们想对数据集当中的所有数据进行一个什么样的变化呢,我们把transform写在这个括号里即可。
#target_transform:它是对target进行一个transform。
#download:如果设置为True的话就会给我们自动下载这个数据集,如果设置为False的话就不会给我们自动下载这个数据集
#下载CIFAR(CIFAR10)这个数据集当训练集
train_set = torchvision.datasets.CIFAR10(root="./dataset",train=True,transform=dataset_transforms,download=True)
#下载CIFAR(CIFAR10)这个数据集当测试集
test_set = torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=dataset_transforms,download=True)

#查看测试集中的第一个图片
#print(test_set[0])
#classes对应的是数据集里面的类别,也就是标签
#print(test_set.classes)

#img,target = test_set[0]
#print(img)
#target就是classes标签里面的第几个,比如target为3,对应的那个标签是cat
#print(target)
#print(test_set.classes[target])
#将这个图片进行一个展示
#img.show()

#查看测试集中的第一个图片
#print(test_set[0])

writer = SummaryWrite("logs")
#显示这个测试数据集的前10张图片
for i in range(10):
    img,target = test_set[i]
    writer.add_image("test_set",img,i)

writer.close()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值