PyTorch实现RNN：
一、torch.nn.RNN()
1. 导入所需包
import torch
from torch.autograd import Variable
from torch.utils import data
import torchvision
from torchvision.datasets import mnist
2. 获取数据
2.1 定义数据集预处理的方法
# Preprocessing pipeline applied to every MNIST sample:
#   1. ToTensor() converts each image into a tensor.
#   2. Normalize(mean=[0.5], std=[0.5]) normalizes the single channel,
#      presumably mapping pixel values from [0, 1] to [-1, 1]
#      (assumes ToTensor's usual [0, 1] output range — confirm).
data_tf = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.5], std=[0.5]),
])
2.2 获取数据，同时对其进行预处理
# Local directory holding the pre-downloaded MNIST files.
# download=False below means the files must already exist here;
# nothing is fetched from the network.
data_path = r'C:\Users\liev\Desktop\myproject\yin_test\MNIST_DATA_PyTorch'

# Build the training split first, then the test split, both with the same
# preprocessing pipeline (data_tf) applied to every sample.
train_data, test_data = (
    mnist.MNIST(data_path, train=is_train, transform=data_tf, download=False)
    for is_train in (True, False)
)
3. 网络结构
class RNNnet(torch.nn.Module):
def __init__(self):
super(RNNnet, self).__init__()
self.rnn1 = torch.nn.RNN(784,100,3,nonlinearity='relu')
self.rnn2 = torch.nn.RNN(100,10,1,nonlinearity='relu')
def forward(self, x):
x = self.rnn1(x)
x = torch.Tensor(x