import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init


def _weights_init(m):
    # Kaiming-initialize the weights of every Linear and Conv2d layer.
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        init.kaiming_normal_(m.weight)


class LambdaLayer(nn.Module):
    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


class BasicBlock_s(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A'):
        super(BasicBlock_s, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == 'A':
                # Parameter-free shortcut: subsample spatially and zero-pad the channel dimension.
                self.shortcut = LambdaLayer(
                    lambda x: F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes // 4, planes // 4), "constant", 0))
            elif option == 'B':
                # Projection shortcut: 1x1 convolution followed by batch norm.
                self.shortcut = nn.Sequential(
                    nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(self.expansion * planes)
                )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


def conv1x1(in_planes, planes, stride=1):
    return nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False)


def branchBottleNeck(channel_in, channel_out, kernel_size):
    middle_channel = channel_out // 4
    return nn.Sequential(
        nn.Conv2d(channel_in, middle_channel, kernel_size=1, stride=1),
        nn.BatchNorm2d(middle_channel),
        nn.ReLU(),
        nn.Conv2d(middle_channel, middle_channel, kernel_size=kernel_size, stride=kernel_size),
        nn.BatchNorm2d(middle_channel),
        nn.ReLU(),
        nn.Conv2d(middle_channel, channel_out, kernel_size=1, stride=1),
        nn.BatchNorm2d(channel_out),
        nn.ReLU(),
    )


class ResNet_modify(nn.Module):
    def __init__(self, block, num_blocks, num_classes=100, nf=64):
        super(ResNet_modify, self).__init__()
        self.in_planes = nf

        self.conv1 = nn.Conv2d(3, self.in_planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_planes)
        self.layer1 = self._make_layer(block, 1 * nf, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 2 * nf, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 4 * nf, num_blocks[2], stride=2)

        self.out_dim = 4 * nf * block.expansion
        self.fc = nn.Linear(self.out_dim, num_classes)

        hidden_dim = 128
        self.fc_cb = nn.Linear(self.out_dim, num_classes)
        self.contrast_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.projection_head = nn.Sequential(
            nn.Linear(self.out_dim, hidden_dim),
        )

        # Auxiliary "be your own teacher" branches, currently commented out.
        # Note that forward() still references these modules when it is called with
        # Use_be_your_own_teacher=True, so that path requires them to be re-enabled.
        # self.downsample1_1 = nn.Sequential(
        #     conv1x1(1 * nf * block.expansion, 4 * nf * block.expansion, stride=8),
        #     nn.BatchNorm2d(4 * nf * block.expansion),
        # )
        # self.bottleneck1_1 = branchBottleNeck(1 * nf * block.expansion, 4 * nf * block.expansion, kernel_size=8)
        # self.avgpool1 = nn.AdaptiveAvgPool2d((1, 1))
        # self.middle_fc1 = nn.Linear(4 * nf * block.expansion, num_classes)

        # self.downsample2_1 = nn.Sequential(
        #     conv1x1(2 * nf * block.expansion, 4 * nf * block.expansion, stride=4),
        #     nn.BatchNorm2d(4 * nf * block.expansion),
        # )
        # self.bottleneck2_1 = branchBottleNeck(2 * nf * block.expansion, 4 * nf * block.expansion, kernel_size=4)
        # self.avgpool2 = nn.AdaptiveAvgPool2d((1, 1))
        # self.middle_fc2 = nn.Linear(4 * nf * block.expansion, num_classes)

        # self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, train=False, Use_be_your_own_teacher=False):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)

        if Use_be_your_own_teacher:
            # First auxiliary branch; requires the commented-out modules in __init__.
            middle_output1 = self.bottleneck1_1(x)
            middle_output1 = self.avgpool1(middle_output1)
            middle1_fea = middle_output1
            middle_output1 = torch.flatten(middle_output1, 1)
            middle_output1 = self.middle_fc1(middle_output1)

        x = self.layer2(x)

        if Use_be_your_own_teacher:
            # Second auxiliary branch; requires the commented-out modules in __init__.
            middle_output2 = self.bottleneck2_1(x)
            middle_output2 = self.avgpool2(middle_output2)
            middle2_fea = middle_output2
            middle_output2 = torch.flatten(middle_output2, 1)
            middle_output2 = self.middle_fc2(middle_output2)

        x = self.layer3(x)
        x = F.avg_pool2d(x, x.size()[3])
        feature = x.view(x.size(0), -1)
        final_fea = feature

        if train:
            out = self.fc(feature)
            out_cb = self.fc_cb(feature)
            z = self.projection_head(feature)
            p = self.contrast_head(z)
            if Use_be_your_own_teacher:
                return out, out_cb, z, p, middle_output1, middle_output2, final_fea, middle1_fea, middle2_fea
            else:
                return out, out_cb, z, p
        else:
            out = self.fc_cb(feature)
            return out


def resnet32(num_classes=10):
    return ResNet_modify(BasicBlock_s, [5, 5, 5], num_classes=num_classes)
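For reference, here is a minimal usage sketch of the model above. It is my own illustration (the shapes assume CIFAR-style 32x32 RGB inputs) and only exercises the paths that work while the branch modules remain commented out:

model = resnet32(num_classes=10)
images = torch.randn(8, 3, 32, 32)

# Inference path: only the fc_cb head is applied.
logits = model(images)                           # shape (8, 10)

# Training path without the auxiliary branches.
out, out_cb, z, p = model(images, train=True)    # (8, 10), (8, 10), (8, 128), (8, 128)

# model(images, train=True, Use_be_your_own_teacher=True) would additionally
# require re-enabling the commented-out bottleneck/avgpool/middle_fc modules in __init__.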
Note that I have commented out this block of code:
self.downsample1_1 = nn.Sequential(
    conv1x1(1 * nf * block.expansion, 4 * nf * block.expansion, stride=8),
    nn.BatchNorm2d(4 * nf * block.expansion),
)
self.bottleneck1_1 = branchBottleNeck(1 * nf * block.expansion, 4 * nf * block.expansion, kernel_size=8)
self.avgpool1 = nn.AdaptiveAvgPool2d((1, 1))
self.middle_fc1 = nn.Linear(4 * nf * block.expansion, num_classes)

self.downsample2_1 = nn.Sequential(
    conv1x1(2 * nf * block.expansion, 4 * nf * block.expansion, stride=4),
    nn.BatchNorm2d(4 * nf * block.expansion),
)
self.bottleneck2_1 = branchBottleNeck(2 * nf * block.expansion, 4 * nf * block.expansion, kernel_size=4)
self.avgpool2 = nn.AdaptiveAvgPool2d((1, 1))
self.middle_fc2 = nn.Linear(4 * nf * block.expansion, num_classes)

self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
Commenting this code in or out of the model also changes the model's initialized parameters. Adding or removing code, even code that is never actually used in the forward method, can affect the parameter-initialization process. Here is a detailed explanation:

- Module initialization order: when the model is instantiated, every submodule (convolution layers, fully connected layers, batch-norm layers, and so on) is constructed in the order in which it appears in the code; memory is allocated for each submodule and its weights and biases are initialized. Submodules such as these branch networks are constructed and initialized even though they are never used in the forward method, so adding or removing them changes the sequence in which the layers are initialized and, with it, the parameter initialization of the whole network.
- Random-number-generator state: weight initialization typically draws from a random number generator (for example, Kaiming initialization). The generator's state advances with every draw, so adding or removing any submodule changes the state the generator is in when the later layers are initialized, and those layers therefore receive different random weights.

As a concrete example, suppose you have two versions of the code, one that still constructs the branch networks and one in which they are commented out. Then:

- In the version that constructs the branch networks, their weight initialization consumes random-number-generator state even though the branches are never used in the forward method.
- By the time the remaining layers are initialized, the generator state therefore differs from that of the other version, so those layers receive different initial weights, as the sketch below demonstrates.
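The random-number-generator effect can be reproduced with a small self-contained sketch (my own illustration, separate from the model above): a module that is constructed but never called still draws from the global RNG, so every layer created after it ends up initialized differently.

import torch
import torch.nn as nn

def build(with_extra_branch):
    torch.manual_seed(0)                     # identical seed for both variants
    first = nn.Linear(16, 16)                # constructed before the optional branch
    if with_extra_branch:
        unused = nn.Conv2d(16, 16, 3)        # never called, but its init draws from the RNG
    last = nn.Linear(16, 16)                 # constructed after the optional branch
    return first, last

first_a, last_a = build(with_extra_branch=False)
first_b, last_b = build(with_extra_branch=True)

print(torch.equal(first_a.weight, first_b.weight))   # True: identical draws up to this point
print(torch.equal(last_a.weight, last_b.weight))     # False: the RNG state has diverged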
To make sure the two versions of the code initialize their parameters identically, you can verify and adjust this with one of the following methods:
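For instance, here is a minimal sketch of one possible check (my own example, assuming a fixed global seed; the helper shared_params_match is a name I introduce for illustration): build a model from each version of the code under the same seed and compare the parameters the two models have in common.

def shared_params_match(model_a, model_b):
    # Compare every parameter/buffer name that exists in both models.
    sd_a, sd_b = model_a.state_dict(), model_b.state_dict()
    return {k: torch.equal(sd_a[k], sd_b[k]) for k in set(sd_a) & set(sd_b)}

torch.manual_seed(42)
model_a = resnet32(num_classes=10)   # built from one version of the code
torch.manual_seed(42)
model_b = resnet32(num_classes=10)   # in practice, built from the other version
mismatched = [k for k, ok in shared_params_match(model_a, model_b).items() if not ok]
print(mismatched)  # empty only if the two builds consumed RNG state identically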