Pytorch叨叨叨系列

本文深入探讨PyTorch中的模块构建与实例化,包括线性回归模型的实现,以及如何通过DataLoader进行数据预处理。文章还介绍了自定义数据集类,并展示了如何使用PyTorch进行数据加载。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

pytorch中的模块介绍

1. 模块类的构建

class Linear(nn.Module):
	def __init__(self,param1,param2,...):
		super(Model,self).__init__()
		# 由param1,param2等等,来定义子模块
	def forward(self,x):
		# 根据子模块计算返回张量

例如:实现线性回归

import torch.nn as nn
import torch
class Linear(nn.Module):
    def __init__(self,ndim):
        super(Linear,self).__init__()
        self.ndim = ndim # 特征维数
        self.weight = nn.Parameter(torch.randn(ndim,1))
        self.bias = nn.Parameter(torch.randn(1))
    def forward(self,x):
        # y = wx+b
        return x.matmul(self.weight) + self.bias
data = torch.randn(6,64)
print(data.shape)
model = Linear(64)
y = model(data)
print(y.shape) 

2. 类的实例化和方法调用

  • named_parameters()和parameters()方法获取模型参数
for label,data in model.named_parameters():
    print(label,data)
print(model.named_parameters())

在这里插入图片描述

  • train和eval方法进行模型训练和测试的转换
model.train()
  • 使用named_children方法和children方法获取模型的子模块
  • named_modules和modules得到所有模块

3.python中的数据输入和预处理

1. 数据载入类

DataLoader(dataset,batch_szie=1,shuffle=False,sample=None,batch_sample=None,num_workers=0,collate_fn=None)
dataset:torch.utils.data.Dataset类的实例
batch_size:迷你批次的大小
shuffle:数据会不会被随机打乱
sample:自定义的采样器,python的迭代器,每次迭代的时候会返回一个数据的下标索引
batch_sampler:类似于sample,minibatch的下标索引
num_workers:数据载入器使用的进程数目,默认为0

2.映射类型的数据集

为了能够使用DataLoader类,首先需要构造关于单个数据的torch.untils.data.Dataset类。这个类有两种,一种是映射类型(map−style)(map-style)mapstyle

class Dataset(object):
	def __getitem__(self,index):
		# index:数据索引
		# 返回数据张量
	def __len__(self):
		# 返回数据的数目

3. torchvision的DatasetFolder类

数据集存储在一个目录下,每个目录有很多子目录,子目录的个数是图片类的数目,每个子目录下都存储有很多图片,且这些图片都属于一类。

4. 实际操作来啦!

import torch
import numpy as np
from torch.utils.data import DataLoader

class GetLoader(torch.utils.data.Dataset):
    def __init__(self,data_root,data_label):
        self.data = data_root
        self.label = data_label
    def __getitem__(self, index):
        data = self.data[index]
        label = self.label[index]
        return data,label
    def __len__(self):
        return len(self.data)
source_data = np.random.rand(100,20)
source_label = np.random.randint(0,2,(100,1))
torch_data = GetLoader(source_label,source_label)
data = DataLoader(dataset=torch_data,batch_size=8,shuffle=False)
for batchidx,(data,target) in enumerate(data):
    print(batchidx)
    print(data)
    print(target)
    break


Reference
[1] https://www.freesion.com/article/3728236956/
[2] 深入浅出Pytorch从模型到源码

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值