简单的 PyTorch CNN 二分类器示例

本文展示了如何使用PyTorch构建一个简单的卷积神经网络(CNN)模型,用于猫狗分类任务。通过对比TensorFlow实现,讨论了PyTorch在数据处理、模型初始化、训练与推理等方面的特点,并提供了完整的代码实现。

学完了CNN的基本构件,看完了用TensorFlow实现的CNN,让我们再用PyTorch来搭建一个CNN,并用这个网络完成之前那个简单的猫狗分类任务。

这份PyTorch实现会尽量和TensorFlow实现等价。同时,我也会分享编写此项目过程中发现的PyTorch与TensorFlow的区别。

项目网址:https://github.com/SingleZombie/DL-Demos/tree/master/dldemos/BasicCNN

获取数据集

和之前几次的代码实战任务一样,我们这次还用的是Kaggle上的猫狗数据集。我已经写好了数据预处理的函数。使用如下的接口即可获取数据集:

train_X, train_Y, test_X, test_Y = get_cat_set(
    'dldemos/LogisticRegression/data/archive/dataset',
    train_size=1500,
    format='nchw')
print(train_X.shape)  # (m, 3, 224, 224)
print(train_Y.shape)  # (m, 1)

这次的数据格式和之前项目中的有一些区别。

在使用全连接网络时,每一个输入样本都是一个一维向量。之前在预处理数据集时,我做了一个flatten操作,把图片的所有颜色值塞进了一维向量中。而在CNN中,对于卷积操作,每一个输入样本都是一个三维张量。用OpenCV读取完图片后,不用对图片Resize,直接拿过来用就可以了。

另外,在用NumPy实现时,我们把数据集大小N当作了最后一个参数;在用TensorFlow时,张量格式是"NHWC(数量-高度-宽度-通道数)“。而PyTorch中默认的张量格式是"NCHW(数量-通道数-高度-宽度)”。因此,在预处理数据集时,我令format='nchw'

初始化模型

根据课堂里讲的CNN构建思路,我搭了一个这样的网络。

由于这个二分类任务比较简单,我在设计时尽可能让可训练参数更少。刚开始用一个大步幅、大卷积核的卷积快速缩小图片边长,之后逐步让图片边长减半、深度翻倍。

这样一个网络用PyTorch实现如下:

def init_model(device='cpu'):
    model = nn.Sequential(nn.Conv2d(3, 16, 11, 3), nn.BatchNorm2d(16),
                          nn.ReLU(True), nn.MaxPool2d(2, 2),
                          nn.Conv2d(16, 32, 5), nn.BatchNorm2d(32),
                          nn.ReLU(True), nn.MaxPool2d(2, 2),
                          nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64),
                          nn.ReLU(True), nn.Conv2d(64, 64, 3),
                          nn.BatchNorm2d(64), nn.ReLU(True),
                          nn.MaxPool2d(2, 2), nn.Flatten(),
                          nn.Linear(3136, 2048), nn.ReLU(True),
                          nn.Linear(2048, 1), nn.Sigmoid()).to(device
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值