修改PyTorch预训练模型的下载路径(以及对现有模型vgg16修改)

目录

引言        

修改PyTorch预训练模型的下载路径

对vgg模型的训练

  • 引言        

        在下载一些pytorch的训练模型的时候,往往 默认路径是 C:\Users\<username>\.cache,而我们不想下载到C盘,我们就要对下载路径进行修改,然后本文还介绍了对现有模型vgg16的修改: 包括增加一些处理(增加一个线性层)和修改一些处理(修改最后一个线性层的参数)。

  • 修改PyTorch预训练模型的下载路径

import os
os.environ['TORCH_HOME'] = 'D:\APPDATA\Data\Modules'

D:\APPDATA\Data\Modules  改为 “你要下载到的路径名”,这样就下载到指定的路径了

  • 对vgg模型的训练

  • 我们打印一下VGG16的模型:
vgg16_False = torchvision.models.vgg16(pretrained=False)
vgg16_True = torchvision.models.vgg16(pretrained=True)
print(vgg16_True)

VGG16训练的数据集是Image_Net ,但是太大了!!

train_data = torchvision.datasets.ImageNet("../data_img_net", split='train',
                          download=True,transform=torchvision.transforms.ToTensor())
                                          

我们就下载CIFAR10, 这个只有10个分类!

train_data = torchvision.datasets.CIFAR10("../dataset", train=True, transform=torchvision.transforms.ToTensor(), download=False)
vgg16_True = torchvision.models.vgg16(pretrained=True)

# 对vgg16模型的使用和修改 : 对vgg16_True 加一层线性层,然后从 1000变成10

vgg16_True.add_module('add_linear', nn.Linear(1000, 10))
print(vgg16_True)

我们可以看见,最后加了一层的线性层, Linear, 从1000个feature 变成了10

上面是对整个模型进行加,而如果想对classifier里面加的话我们这样写:

vgg16_True.classifier.add_module('add_linear', nn.Linear(1000, 10))
print(vgg16_True)

就变成了以下模型:

如果想对模型进行修改:

# 对现有模型的修改,而不是添加!
print(vgg16_False)
vgg16_False.classifier[6] = nn.Linear(4096, 10)
print(vgg16_False)

修改之前:

修改之后:

就到这里了,谢谢大家~~

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

睡个好觉(努力提升自己版)

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值