我使用的是 PyTorch 1.5 版本。
这里主要涉及两个函数:torch.save 和 torch.load。
官网的解释:https://pytorch.org/docs/1.5.0/torch.html(torch.save / torch.load),以及 https://pytorch.org/docs/1.5.0/nn.html(Module.state_dict / load_state_dict)
save保存模型
torch.save(obj, f, pickle_module=pickle, pickle_protocol=2, _use_new_zipfile_serialization=False)
obj : 你要保存的对象
f : 保存的文件位置
后面的参数不管,默认就行
举例
x = torch.tensor([0, 1, 2, 3, 4])
# Serialize the tensor to the file tensor.pt in the current directory.
torch.save(x, 'tensor.pt')
load加载模型
torch.load(f, map_location=None, pickle_module=pickle, **pickle_load_args)
f : 文件对应的位置
map_location : 指定加载时张量放到哪个设备上,可以是 torch.device、字符串、字典或函数,用来把保存时所在的设备重新映射到当前可用的设备
举例:
# Load the file saved above onto the CPU, whatever device it was saved from.
torch.load('tensor.pt', map_location=torch.device('cpu'))
# Remap tensors that were saved on cuda:1 onto cuda:0 (needs a CUDA machine).
torch.load('tensor.pt', map_location={'cuda:1': 'cuda:0'})
保存的时候,推荐先用 state_dict() 取出模型参数(必要时连同优化器状态、epoch 等一起放进一个字典),
再用 torch.save 保存
如:
# Use plain ASCII quotes: typographic quotes are a SyntaxError. Backslashes in a
# Windows path must be escaped (or use a raw string r'...'); otherwise the
# '\t' in '\two.pth' is read as a TAB character.
state = {'net': model.state_dict(), 'optimizer': optimiser.state_dict(), 'epoch': epochs}
torch.save(state, 'D:\\AllPro\\PytorchPro\\two.pth')
加载的时候反过来:先用 torch.load 读出这个字典,再按键取出各部分
# Same escaping rule as when saving; keep the returned dict so its entries
# ('net', 'optimizer', 'epoch') can be restored.
checkpoint = torch.load('D:\\AllPro\\PytorchPro\\two.pth')
model.state_dict() 返回一个model的完整状态的字典
# Prints an OrderedDict mapping each parameter name to its tensor (output below).
print(model.state_dict())
打印结果
OrderedDict([('Lone.weight', tensor([[1.1060],
[1.8362],
[1.8172],
[1.3367],
[1.3266]])), ('Lone.bias', tensor([ 0.0952, -1.8990, -0.1769, -0.6908, -0.9089])), ('Ltwo.weight', tensor([[-0.5234, 0.0282, -0.0736, -0.1586, 0.0362],
[ 2.3027, 1.9011, 1.6350, 2.2779, 1.6443],
[ 2.0606, 1.4825, 1.5095, 1.4385, 1.8779],
[ 2.1344, 1.5597, 1.5797, 2.0561, 1.7084],
[ 1.9351, 1.7267, 1.4738, 1.5796, 1.4055]])), ('Ltwo.bias', tensor([-0.0395, -0.5296, -0.2373, -0.5782, -1.2260])), ('Lthr.weight', tensor([[-0.2432, 1.7701, 1.8104, 1.8243, 1.6312]])), ('Lthr.bias', tensor([-0.6451]))])
print(model):直接打印 model 只会显示网络结构信息,不包含参数值
MyLinnerClass(
(Lone): Linear(in_features=1, out_features=5, bias=True)
(Ltwo): Linear(in_features=5, out_features=5, bias=True)
(Lthr): Linear(in_features=5, out_features=1, bias=True)
)
其他的细节可以查看原文档
这里对上次的回归函数做一下实验
import torch.nn as nn
import torch
import os
import time
import torch.nn.functional as Fun
class MyLinnerClass(nn.Module):
    """Three-layer fully connected network used for the toy regression demo.

    The attribute names ``Lone``/``Ltwo``/``Lthr`` become the ``state_dict()``
    keys (``'Lone.weight'``, ``'Lone.bias'``, ...), so renaming them would
    break loading of previously saved checkpoints.

    Args:
        inputDim: number of input features.
        outputDim: number of output features.
    """

    def __init__(self, inputDim, outputDim):
        super(MyLinnerClass, self).__init__()
        hidden = 5 * inputDim  # hidden width is five times the input width
        self.Lone = torch.nn.Linear(inputDim, hidden)
        self.Ltwo = torch.nn.Linear(hidden, hidden)
        self.Lthr = torch.nn.Linear(hidden, outputDim)

    def forward(self, x):
        """Compute Linear -> ReLU -> Linear -> ReLU -> Linear on ``x``."""
        act = Fun.relu(self.Lone(x))
        act = Fun.relu(self.Ltwo(act))
        return self.Lthr(act)
def dataSetCreate(PointNumer=500, div=0.2):
    """Build a toy quadratic regression dataset y = 1.2 * x**2.

    Args:
        PointNumer: number of sample points to generate.
        div: spacing between consecutive x values (x = 0, div, 2*div, ...).

    Returns:
        (x, notCurY, lableY), each a float tensor of shape (PointNumer, 1):
        x       -- the inputs,
        notCurY -- 1.2 * x**2 plus a deliberate non-negative noise term
                   (an intentionally inaccurate label),
        lableY  -- the exact target 1.2 * x**2.
    """
    # Build x from integer indices: torch.arange with a float step can return
    # PointNumer + 1 elements because of rounding against the end value.
    x = torch.arange(PointNumer, dtype=torch.get_default_dtype()) * div
    print(x.shape)
    x = torch.unsqueeze(x, dim=1)
    print(x.shape)
    # The noise must have the same (PointNumer, 1) shape as x. A 1-D noise
    # vector would broadcast (N, 1) + (N,) into an (N, N) label matrix.
    # Tensor '%' takes the divisor's sign (torch.remainder), so the values
    # land in [0, 15) and the noise in [0, 1.5).
    noise = torch.randn_like(x) % 15 / 10
    notCurY = 1.2 * x * x + noise  # deliberately noisy labels
    lableY = 1.2 * x * x
    return x, notCurY, lableY
def main(epochs=2000, learningRate=0.002, savePath='D:\\AllPro\\PytorchPro\\two.pth'):
    """Train MyLinnerClass on the noisy toy dataset and save a checkpoint.

    The checkpoint is a dict with keys 'net' (model state_dict),
    'optimizer' (optimizer state_dict) and 'epoch' (number of epochs run).
    This is the layout evalModel() reads back.

    Args:
        epochs: number of full-batch training iterations.
        learningRate: Adam learning rate.
        savePath: file the checkpoint dict is written to.
    """
    x, notCurY, lableY = dataSetCreate(500, 0.2)
    model = MyLinnerClass(1, 1)
    optimiser = torch.optim.Adam(model.parameters(), lr=learningRate)  # optimizer
    certerion = nn.MSELoss()  # loss function
    # Count epochs from 1 so the progress line prints at epoch 50, 100, ...
    for oneEpoch in range(1, epochs + 1):
        optimiser.zero_grad()
        train_outY = model(x)
        # To train on the exact labels instead: certerion(train_outY, lableY)
        # Training on the deliberately noisy labels makes the loss fluctuate.
        loss = certerion(train_outY, notCurY)
        loss.backward()
        optimiser.step()
        if oneEpoch % 50 == 0:
            print("loss is :", loss.item())
    # Method 1: save the key training state as a dict.
    state = {'net': model.state_dict(), 'optimizer': optimiser.state_dict(), 'epoch': epochs}
    torch.save(state, savePath)
    print("train over")
    # Method 2: pickle the whole model object instead:
    # torch.save(model, 'D:\\AllPro\\PytorchPro\\thr.pth')
def evalModel(loadPath='D:\\AllPro\\PytorchPro\\two.pth'):
    """Rebuild MyLinnerClass from a checkpoint dict and run inference on the toy data.

    Args:
        loadPath: checkpoint dict with keys 'net', 'optimizer' and 'epoch'
            (the layout main() writes).
    """
    evModel = MyLinnerClass(1, 1)
    # Method 1: load the saved dict and restore each entry where it belongs.
    # map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
    # machine; this model lives on the CPU anyway.
    checkpoint = torch.load(loadPath, map_location='cpu')
    evModel.load_state_dict(checkpoint['net'])
    # The optimizer is only needed to resume training. load_state_dict below
    # replaces the lr=0.001 given here with the learning rate stored in the checkpoint.
    optimiser = torch.optim.Adam(evModel.parameters(), lr=0.001)
    optimiser.load_state_dict(checkpoint['optimizer'])
    epoch = checkpoint['epoch']
    print(epoch)
    # eval() switches layers such as Dropout/BatchNorm to inference behaviour;
    # it does not freeze the weights. torch.no_grad() below is what stops
    # autograd from recording the forward pass.
    evModel.eval()
    # Method 2: load a whole pickled model instead. This restores only the model,
    # not the optimizer state or the epoch count:
    # evModel = torch.load('D:\\AllPro\\PytorchPro\\thr.pth')
    # evModel.eval()
    x, notCurY, lableY = dataSetCreate(500, 0.2)  # input must have the training shape (N, 1)
    with torch.no_grad():
        print("evl Val is", evModel(x))
    print("over")
# Script entry point: train and save the checkpoint, then reload it and run inference.
if __name__ == '__main__':
    main()
    evalModel()
使用 DataParallel 或者 DistributedDataParallel 保存的模型
用 DP 或 DDP 包装后保存的 state_dict,键名都带有 'module.' 前缀,所以在 load 之前也要先用同样的 DP/DDP 把模型包装好(或者自行去掉键名里的 'module.' 前缀)
# Wrap the network the same way it was wrapped when the checkpoint was saved.
model = nn.DataParallel(net, device_ids=device_ids)
import cv2
import os
import argparse
import glob
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from models import DnCNN
from utils import *
from torch.nn.parallel import DistributedDataParallel as DDP
import apex
from apex import amp
def main(pathPth="/home/wangyuanming/DnCNN-PyTorch/netG.pth",
         imgPath="/home/wangyuanming/DnCNN-PyTorch/data/Set12/01.png",
         deviceType='cpu'):
    """Load a DataParallel-wrapped DnCNN checkpoint and print its PSNR on one noisy image.

    Args:
        pathPth: checkpoint saved as {'net': model.state_dict(), ...} from a
            model wrapped in nn.DataParallel.
        imgPath: test image; only its first channel is used.
        deviceType: device to run on, e.g. 'cpu' (or 'npu:0' on an NPU machine).

    Raises:
        FileNotFoundError: if the image cannot be read.
    """
    # Build the network and wrap it the same way it was wrapped when saved,
    # so the 'module.'-prefixed state_dict keys match.
    net = DnCNN(channels=1, num_of_layers=17)
    device_ids = [0]
    model = nn.DataParallel(net, device_ids=device_ids)
    # map_location lets a checkpoint saved on a GPU load onto deviceType.
    checkpoint = torch.load(pathPth, map_location=deviceType)
    # Indexing with ['net'] assumes the checkpoint was saved as
    # state={'net': model.state_dict(), 'optimizer': ..., 'epoch': ...}.
    model.load_state_dict(checkpoint['net'])
    print("model get")
    # Read the test image.
    Img = cv2.imread(imgPath)
    if Img is None:  # cv2.imread returns None instead of raising on a missing/unreadable file
        raise FileNotFoundError(imgPath)
    Img = np.float32(Img[:, :, 0]) / 255  # first channel only, scaled to [0, 1]
    Img = np.expand_dims(Img, 0)
    Img = np.expand_dims(Img, 1)  # -> (1, 1, H, W)
    ISource = torch.from_numpy(Img)
    noise = torch.randn_like(ISource) * (15 / 255.)  # Gaussian noise, sigma = 15/255
    print("data ready")
    ISource, noise = ISource.to(deviceType), noise.to(deviceType)
    INoisy = ISource + noise
    # Run the unwrapped network: DataParallel with device_ids=[0] requires its
    # parameters on cuda:0, so calling it with deviceType='cpu' fails whenever
    # CUDA is available. The weights were already loaded through the wrapper.
    net = model.module.to(deviceType)
    net.eval()
    print("model evl")
    with torch.no_grad():  # inference only: skip autograd bookkeeping to save memory
        # The network output is subtracted from the noisy input (treated as the
        # noise residual); the result is clipped to the valid [0, 1] range.
        Out = torch.clamp(INoisy - net(INoisy), 0., 1.)
        psnr = batch_PSNR(Out, ISource, 1.)
    print("psnr {:.3f}".format(psnr))
# Run the single-image DnCNN evaluation when executed as a script.
if __name__ == "__main__":
    main()
可以看看训练源码里我是如何保存的:
# Checkpoint dict layout; the loading code above reads checkpoint['net'].
state={'net':model.state_dict(),'optimizer':optimiser.state_dict(),'epoch':epochs}
对应地,如果保存时不用字典,而是直接保存 state_dict:
保存
# Save only the model's state_dict, not wrapped in a dict.
torch.save(model.state_dict(), os.path.join(opt.outf, 'netH.pth'))
加载
pathPth="/home/wangyuanming/DnCNN-PyTorch/logs/netH.pth"
# Load the model
checkpoint = torch.load(pathPth)
# The file holds the state_dict itself, so pass it directly (no ['net'] lookup).
model.load_state_dict(checkpoint)
print("model get")