pytorch中的模块介绍
1. 模块类的构建
class Linear(nn.Module):
def __init__(self,param1,param2,...):
super(Model,self).__init__()
# 由param1,param2等等,来定义子模块
def forward(self,x):
# 根据子模块计算返回张量
例如:实现线性回归
import torch.nn as nn
import torch
class Linear(nn.Module):
def __init__(self,ndim):
super(Linear,self).__init__()
self.ndim = ndim # 特征维数
self.weight = nn.Parameter(torch.randn(ndim,1))
self.bias = nn.Parameter(torch.randn(1))
def forward(self,x):
# y = wx+b
return x.matmul(self.weight) + self.bias
data = torch.randn(6,64)
print(data.shape)
model = Linear(64)
y = model(data)
print(y.shape)
2. 类的实例化和方法调用
- named_parameters()和parameters()方法获取模型参数
for label,data in model.named_parameters():
print(label,data)
print(model.named_parameters())
- train和eval方法进行模型训练和测试的转换
model.train()
- 使用named_children方法和children方法获取模型的子模块
- named_modules和modules得到所有模块
3.python中的数据输入和预处理
1. 数据载入类
DataLoader(dataset,batch_szie=1,shuffle=False,sample=None,batch_sample=None,num_workers=0,collate_fn=None)
dataset:torch.utils.data.Dataset类的实例
batch_size:迷你批次的大小
shuffle:数据会不会被随机打乱
sample:自定义的采样器,python的迭代器,每次迭代的时候会返回一个数据的下标索引
batch_sampler:类似于sample,minibatch的下标索引
num_workers:数据载入器使用的进程数目,默认为0
2.映射类型的数据集
为了能够使用DataLoader类,首先需要构造关于单个数据的torch.untils.data.Dataset类。这个类有两种,一种是映射类型(map−style)(map-style)(map−style) 。
class Dataset(object):
def __getitem__(self,index):
# index:数据索引
# 返回数据张量
def __len__(self):
# 返回数据的数目
3. torchvision的DatasetFolder类
数据集存储在一个目录下,每个目录有很多子目录,子目录的个数是图片类的数目,每个子目录下都存储有很多图片,且这些图片都属于一类。
4. 实际操作来啦!
import torch
import numpy as np
from torch.utils.data import DataLoader
class GetLoader(torch.utils.data.Dataset):
def __init__(self,data_root,data_label):
self.data = data_root
self.label = data_label
def __getitem__(self, index):
data = self.data[index]
label = self.label[index]
return data,label
def __len__(self):
return len(self.data)
source_data = np.random.rand(100,20)
source_label = np.random.randint(0,2,(100,1))
torch_data = GetLoader(source_label,source_label)
data = DataLoader(dataset=torch_data,batch_size=8,shuffle=False)
for batchidx,(data,target) in enumerate(data):
print(batchidx)
print(data)
print(target)
break
Reference
[1] https://www.freesion.com/article/3728236956/
[2] 深入浅出Pytorch从模型到源码