目录
-
引言
在下载一些pytorch的训练模型的时候,往往 默认路径是 C:\Users\<username>\.cache,而我们不想下载到C盘,我们就要对下载路径进行修改,然后本文还介绍了对现有模型vgg16的修改: 包括增加一些处理(增加一个线性层)和修改一些处理(修改最后一个线性层的参数)。
-
修改PyTorch预训练模型的下载路径
import os
os.environ['TORCH_HOME'] = 'D:\APPDATA\Data\Modules'
D:\APPDATA\Data\Modules 改为 “你要下载到的路径名”,这样就下载到指定的路径了
-
对vgg模型的训练
- 我们打印一下VGG16的模型:
vgg16_False = torchvision.models.vgg16(pretrained=False)
vgg16_True = torchvision.models.vgg16(pretrained=True)
print(vgg16_True)
VGG16训练的数据集是Image_Net ,但是太大了!!
train_data = torchvision.datasets.ImageNet("../data_img_net", split='train',
download=True,transform=torchvision.transforms.ToTensor())
我们就下载CIFAR10, 这个只有10个分类!
train_data = torchvision.datasets.CIFAR10("../dataset", train=True, transform=torchvision.transforms.ToTensor(), download=False)
vgg16_True = torchvision.models.vgg16(pretrained=True)
# 对vgg16模型的使用和修改 : 对vgg16_True 加一层线性层,然后从 1000变成10
vgg16_True.add_module('add_linear', nn.Linear(1000, 10))
print(vgg16_True)
我们可以看见,最后加了一层的线性层, Linear, 从1000个feature 变成了10
上面是对整个模型进行加,而如果想对classifier里面加的话我们这样写:
vgg16_True.classifier.add_module('add_linear', nn.Linear(1000, 10))
print(vgg16_True)
就变成了以下模型:
如果想对模型进行修改:
# 对现有模型的修改,而不是添加!
print(vgg16_False)
vgg16_False.classifier[6] = nn.Linear(4096, 10)
print(vgg16_False)
修改之前:
修改之后:
就到这里了,谢谢大家~~