The Subtleties of Model Initialization

import torch.nn as nn
import torch
import torch.nn.init as init
import torch.nn.functional as F

def _weights_init(m):
    # Kaiming-initialize the weights of every Linear and Conv2d layer.
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        init.kaiming_normal_(m.weight)

class LambdaLayer(nn.Module):
    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)

class BasicBlock_s(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A'):
        super(BasicBlock_s, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == 'A':
                # Option A (CIFAR ResNet): subsample spatially with stride-2
                # slicing and zero-pad the channel dimension; adds no parameters.
                self.shortcut = LambdaLayer(lambda x:
                                            F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes // 4, planes // 4), "constant", 0))
            elif option == 'B':
                self.shortcut = nn.Sequential(
                    nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(self.expansion * planes)
                )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

def conv1x1(in_planes, planes, stride=1):
    return nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False)

def branchBottleNeck(channel_in, channel_out, kernel_size):
    middle_channel = channel_out // 4
    return nn.Sequential(
        nn.Conv2d(channel_in, middle_channel, kernel_size=1, stride=1),
        nn.BatchNorm2d(middle_channel),
        nn.ReLU(),
        nn.Conv2d(middle_channel, middle_channel, kernel_size=kernel_size, stride=kernel_size),
        nn.BatchNorm2d(middle_channel),
        nn.ReLU(),
        nn.Conv2d(middle_channel, channel_out, kernel_size=1, stride=1),
        nn.BatchNorm2d(channel_out),
        nn.ReLU(),
    )

class ResNet_modify(nn.Module):
    def __init__(self, block, num_blocks, num_classes=100, nf=64):
        super(ResNet_modify, self).__init__()
        self.in_planes = nf

        self.conv1 = nn.Conv2d(3, self.in_planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_planes)
        self.layer1 = self._make_layer(block, 1 * nf, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 2 * nf, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 4 * nf, num_blocks[2], stride=2)
        self.out_dim = 4 * nf * block.expansion

        self.fc = nn.Linear(self.out_dim, num_classes)
        hidden_dim = 128
        self.fc_cb = nn.Linear(self.out_dim, num_classes)
        self.contrast_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.projection_head = nn.Sequential(
            nn.Linear(self.out_dim, hidden_dim),
        )

        # self.downsample1_1 = nn.Sequential(
        #     conv1x1(1 * nf * block.expansion, 4 * nf * block.expansion, stride=8),
        #     nn.BatchNorm2d(4 * nf * block.expansion),
        # )
        # self.bottleneck1_1 = branchBottleNeck(1 * nf * block.expansion, 4 * nf * block.expansion, kernel_size=8)
        # self.avgpool1 = nn.AdaptiveAvgPool2d((1,1))
        # self.middle_fc1 = nn.Linear(4 * nf * block.expansion, num_classes)

        # self.downsample2_1 = nn.Sequential(
        #     conv1x1(2 * nf * block.expansion, 4 * nf * block.expansion, stride=4),
        #     nn.BatchNorm2d(4 * nf * block.expansion),
        # )
        # self.bottleneck2_1 = branchBottleNeck(2 * nf * block.expansion, 4 * nf * block.expansion, kernel_size=4)
        # self.avgpool2 = nn.AdaptiveAvgPool2d((1,1))
        # self.middle_fc2 = nn.Linear(4 * nf * block.expansion, num_classes)

        # self.avgpool = nn.AdaptiveAvgPool2d((1,1))

        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x, train=False, Use_be_your_own_teacher=False):
        x = F.relu(self.bn1(self.conv1(x)))
        
        x = self.layer1(x)
        if Use_be_your_own_teacher:
            # NOTE: this path needs the branch modules that are commented
            # out in __init__ (bottleneck1_1, avgpool1, middle_fc1).
            middle_output1 = self.bottleneck1_1(x)
            middle_output1 = self.avgpool1(middle_output1)
            middle1_fea = middle_output1
            middle_output1 = torch.flatten(middle_output1, 1)
            middle_output1 = self.middle_fc1(middle_output1)

        x = self.layer2(x)
        if Use_be_your_own_teacher:
            middle_output2 = self.bottleneck2_1(x)
            middle_output2 = self.avgpool2(middle_output2)
            middle2_fea = middle_output2
            middle_output2 = torch.flatten(middle_output2, 1)
            middle_output2 = self.middle_fc2(middle_output2)

        x = self.layer3(x)
        x = F.avg_pool2d(x, x.size()[3])
        feature = x.view(x.size(0), -1)

        final_fea = feature

        if train:
            out = self.fc(feature)
            out_cb = self.fc_cb(feature)
            z = self.projection_head(feature)
            p = self.contrast_head(z)
            if Use_be_your_own_teacher:
                return out, out_cb, z, p, middle_output1, middle_output2, final_fea, middle1_fea, middle2_fea
            else:
                return out, out_cb, z, p
        else:
            out = self.fc_cb(feature)
            return out

def resnet32(num_classes=10):
    return ResNet_modify(BasicBlock_s, [5, 5, 5], num_classes=num_classes)

Note that I have commented out the following block:

        self.downsample1_1 = nn.Sequential(
            conv1x1(1 * nf * block.expansion, 4 * nf * block.expansion, stride=8),
            nn.BatchNorm2d(4 * nf * block.expansion),
        )
        self.bottleneck1_1 = branchBottleNeck(1 * nf * block.expansion, 4 * nf * block.expansion, kernel_size=8)
        self.avgpool1 = nn.AdaptiveAvgPool2d((1,1))
        self.middle_fc1 = nn.Linear(4 * nf * block.expansion, num_classes)

        self.downsample2_1 = nn.Sequential(
            conv1x1(2 * nf * block.expansion, 4 * nf * block.expansion, stride=4),
            nn.BatchNorm2d(4 * nf * block.expansion),
        )
        self.bottleneck2_1 = branchBottleNeck(2 * nf * block.expansion, 4 * nf * block.expansion, kernel_size=4)
        self.avgpool2 = nn.AdaptiveAvgPool2d((1,1))
        self.middle_fc2 = nn.Linear(4 * nf * block.expansion, num_classes)

        self.avgpool = nn.AdaptiveAvgPool2d((1,1))

Commenting these lines out of the model (or adding them back) also changes the model's initialized parameters.

Adding or removing code, even code that is never actually used in the forward method, can affect the model's parameter initialization. Here is a detailed explanation:

  1. Module initialization order: When a model is instantiated, all of its submodules (convolution layers, fully connected layers, batch normalization layers, and so on) are created in the order they appear in the code, which allocates their parameters and initializes their weights and biases. If you add or remove submodules (such as the commented-out branch networks), they are still instantiated and initialized even if the forward method never uses them. Adding or removing a submodule therefore changes the order in which the layers are initialized, and with it the parameter initialization of the whole network.

  2. Random number generator state: Parameter initialization usually draws from a random number generator (for example, Kaiming initialization). The generator's state advances with every draw, so adding or removing any submodule shifts the state that all subsequent layers see, and their weights come out as different random numbers.

For example, suppose you have two versions of the code: one still defines the branch networks shown above, and the other has them commented out. Then:

  • In the version that defines the branch networks, their weight initialization still consumes random-number-generator state, even though the forward method never uses them.
  • By the time the remaining layers are initialized, the generator state has already diverged between the two versions, so those layers receive different initial weights, as the sketch below demonstrates.
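
Here is a minimal sketch of the effect. The make_model helper and the layer sizes are purely illustrative (they are not part of the model above); the point is only that constructing an unused module still advances the random number generator:

import torch
import torch.nn as nn

def make_model(with_branch: bool) -> nn.ModuleDict:
    torch.manual_seed(0)  # identical RNG starting point for both variants
    layers = nn.ModuleDict()
    layers['conv'] = nn.Conv2d(3, 16, kernel_size=3)
    if with_branch:
        # An extra module that forward would never use: constructing it
        # still draws random numbers for its default initialization.
        layers['branch'] = nn.Linear(16, 16)
    layers['fc'] = nn.Linear(16, 10)
    return layers

a = make_model(with_branch=False)
b = make_model(with_branch=True)
print(torch.equal(a['conv'].weight, b['conv'].weight))  # True: created before the branch
print(torch.equal(a['fc'].weight, b['fc'].weight))      # False: the RNG state has diverged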

To guarantee that two versions of the code end up with identical initial parameters, you have to pin down the random number generator state at the moment the weights are written, and then verify by comparing the resulting parameters directly.
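
For instance (a sketch under stated assumptions: the seed value 42 and the same_init helper are illustrative, not from the original code), you can re-seed immediately before the explicit init pass and then confirm equality with a full state_dict comparison:

import torch
import torch.nn as nn

# Inside ResNet_modify.__init__, replace the final line with:
#     torch.manual_seed(42)      # pin the RNG state right before the init pass
#     self.apply(_weights_init)
# Caveat: this only fixes the tensors _weights_init rewrites; parameters it
# skips (e.g. the Linear biases) keep their construction-time defaults and
# can still differ, which is why the comparison below checks everything.

def same_init(m1: nn.Module, m2: nn.Module) -> bool:
    s1, s2 = m1.state_dict(), m2.state_dict()
    if s1.keys() != s2.keys():
        return False
    return all(torch.equal(s1[k], s2[k]) for k in s1)

With that in place, same_init(resnet32(), resnet32()) will still flag any parameter the re-seeding did not cover (here, the Linear biases), which is exactly the kind of silent mismatch this check is meant to catch.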
