在语义分割中,预训练模型一般是3通道的,但是有时候我们需要改变输入通道数,方法如下:
首先加载预训练模型,打印模型第一层
import torchvision.models as models
backbone = models.resnet101(pretrained=False)
print(backbone.conv1)
打印得到第一层的结构
Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
可以看到输入模型的通道数为3,只要修改这个3就可以了,若期望的输入是9通道,则进行如下操作即可:
backbone.conv1= nn.Conv2d(9, 64, kernel_size=7, stride=2, padding=3,bias=False)