平时用了很多Keras,训练的时候非常方便,直接model.fit就可以了。但是PyTorch的训练得自己写,这里小结下PyTorch怎么训练模型。
PyTorch训练的大体步骤
一个标准的PyTorch模型必须得有一个固定结构的类,结构如下
class TwoLayerNet(torch.nn.Module):
def __init__(self, D_in, H, D_out):
"""
In the constructor we instantiate two nn.Linear modules and assign them as
member variables.
"""
super(TwoLayerNet, self).__init__()
self.</