文中各字母的含义如下:
b: batch_size
t: time_step
n: num_feature
h: hidden_size
假设输入数据 input 的维度为 (b, t, n)
所设计的LSTM模型如下:
class MYLSTM(nn.Module):
def __init__(self, input_size, hidden_size, out_size):
    """Build a single-layer, batch-first LSTM followed by a linear head.

    Args:
        input_size: Value stored as ``self.input_size``; it contributes to
            the LSTM's input width.
        hidden_size: Width of the LSTM hidden state; also the input width
            of the linear head.
        out_size: Output width of the linear head.
    """
    super(MYLSTM, self).__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size

    # The LSTM is fed input_size + hidden_size features per step, not
    # input_size alone.
    # NOTE(review): presumably forward() concatenates each time step with
    # the previous hidden state -- confirm against forward().
    lstm_in_width = input_size + hidden_size
    self.lstm = nn.LSTM(
        lstm_in_width,
        hidden_size,
        num_layers=1,
        batch_first=True,
    )

    # Linear head mapping a hidden-state-sized vector to out_size values.
    self.out = nn.Linear(in_features=hidden_size, out_features=out_size)
def forward(self, x):
hidden, cell = Variable(torch.zeros(1, x.size(0), self.hidden_size)),\
Variable(torch.zeros(1, x.size(0), self.hidden_size))
for i in range(x.size(1)):
curx = x[: