CSPDarkNet53 code implementation (PyTorch)
import torch
import torch.nn as nn
import torch.nn.functional as F


# Mish activation: Mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^x))
class Mish(nn.Module):
    def __init__(self):
        super(Mish, self).__init__()

    def forward(self, x):
        return x * torch.tanh(F.softplus(x))
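
# Quick sanity check (assumed, not part of the original listing): the module above
# should match the closed form x * tanh(ln(1 + e^x)).
#   x = torch.tensor([-1.0, 0.0, 1.0])
#   torch.allclose(Mish()(x), x * torch.tanh(torch.log1p(torch.exp(x))))  # -> True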

# Basic Conv -> BatchNorm -> Mish building block
class BN_Conv_Mish(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super(BN_Conv_Mish, self).__init__()
        # "Same" padding is derived from the kernel size (kernel_size // 2), so the
        # explicit padding argument passed by the callers below is effectively unused.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
                              kernel_size // 2, bias=False)
        # BatchNorm2d always follows the convolution to normalize the activations, so
        # the inputs to the Mish activation stay in a stable range and overly large
        # values do not destabilize training.
        self.bn = nn.BatchNorm2d(out_channels)
        self.activation = Mish()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.activation(x)
        return x

# Residual block used inside each CSP stage: a 1x1 conv to reduce channels, a 3x3 conv
# to restore them, then a BatchNorm followed by an identity shortcut and Mish.
class ResidualBlock(nn.Module):
    def __init__(self, chnls, inner_chnls=None):
        super(ResidualBlock, self).__init__()
        if inner_chnls is None:
            inner_chnls = chnls
        self.conv1 = BN_Conv_Mish(chnls, inner_chnls, 1)
        self.conv2 = BN_Conv_Mish(inner_chnls, chnls, 3)
        self.bn = nn.BatchNorm2d(chnls)
        self.activation = Mish()

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.bn(out) + x  # identity shortcut
        return self.activation(out)

# First CSP stage: downsample, split into two 1x1-conv branches (one passes through a
# residual block, the other is the cross-stage shortcut), concatenate, then fuse with
# a final 1x1 conv.
class ResBodyFirst(nn.Module):
    def __init__(self, in_chnls, out_chnls):
        super(ResBodyFirst, self).__init__()
        self.dsample = BN_Conv_Mish(in_chnls, out_chnls, 3, 2, 1)       # stride-2 downsampling
        self.trans_0 = BN_Conv_Mish(out_chnls, out_chnls, 1, 1, 0)      # shortcut branch
        self.trans_1 = BN_Conv_Mish(out_chnls, out_chnls, 1, 1, 0)      # residual branch
        self.block = ResidualBlock(out_chnls, out_chnls // 2)
        self.trans_cat = BN_Conv_Mish(2 * out_chnls, out_chnls, 1, 1, 0)  # fuse after concat

    def forward(self, x):
        x = self.dsample(x)
        out_0 = self.trans_0(x)
        out_1 = self.trans_1(x)
        out_1 = self.block(out_1)
        out = torch.cat((out_0, out_1), 1)
        out = self.trans_cat(out)
        print("residual stage 1:", out.shape)
        return out
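
# Shape walk-through for ResBodyFirst(32, 64) on a [1, 32, 416, 416] input (assumed
# example, matching the test input at the end of the listing):
#   dsample -> [1, 64, 208, 208]; trans_0 and trans_1 -> [1, 64, 208, 208] each;
#   concat  -> [1, 128, 208, 208]; trans_cat -> [1, 64, 208, 208].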

# Remaining CSP stages: the same split/merge pattern, but each branch uses half the
# output channels and the residual branch stacks num_block residual blocks.
class ResBodyStem(nn.Module):
    def __init__(self, in_chnls, out_chnls, num_block, index):
        super(ResBodyStem, self).__init__()
        self.dsample = BN_Conv_Mish(in_chnls, out_chnls, 3, 2, 1)           # stride-2 downsampling
        self.trans_0 = BN_Conv_Mish(out_chnls, out_chnls // 2, 1, 1, 0)     # shortcut branch
        self.trans_1 = BN_Conv_Mish(out_chnls, out_chnls // 2, 1, 1, 0)     # residual branch
        self.blocks = nn.Sequential(*[ResidualBlock(out_chnls // 2) for _ in range(num_block)])
        self.trans_cat = BN_Conv_Mish(out_chnls, out_chnls, 1, 1, 0)        # fuse after concat
        self.index = index

    def forward(self, x):
        x = self.dsample(x)
        out_0 = self.trans_0(x)
        out_1 = self.trans_1(x)
        out_1 = self.blocks(out_1)
        out = torch.cat((out_0, out_1), 1)
        out = self.trans_cat(out)
        print("residual stage", self.index + 2, ":", out.shape)
        return out

# CSPDarkNet53 backbone: a 3x3 stem conv followed by five CSP stages with
# [1, 2, 8, 8, 4] residual blocks and [64, 128, 256, 512, 1024] output channels.
class CSP_DarkNet(nn.Module):
    def __init__(self, num_blocks, num_classes=1000):
        super(CSP_DarkNet, self).__init__()
        chnls = [64, 128, 256, 512, 1024]
        self.conv0 = BN_Conv_Mish(3, 32, 3, 1, 1)
        self.resblock1 = ResBodyFirst(32, chnls[0])
        self.resblocks = nn.Sequential(
            *[ResBodyStem(chnls[i], chnls[i + 1], num_blocks[i], i) for i in range(4)])
        # num_classes is kept for API compatibility; this listing stops at the backbone
        # and does not attach a classification head.

    def forward(self, x):
        out = self.conv0(x)
        out = self.resblock1(out)
        out = self.resblocks(out)
        return out

def csp_darknet_53(num_classes=1000):
    return CSP_DarkNet([2, 8, 8, 4], num_classes)


net = csp_darknet_53()
# Generate a random 416x416, 3-channel image (batch size 1)
n = torch.rand(1, 3, 416, 416)
print(net(n).shape)
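
For a 416x416 input the final print should show torch.Size([1, 1024, 13, 13]): the stem conv keeps the resolution, and each of the five CSP stages halves it (416 -> 208 -> 104 -> 52 -> 26 -> 13) while the channels grow to 1024. The num_classes argument is accepted but never used, because the listing stops at the backbone. Below is a minimal sketch of how a classification head could sit on top; the pooling layer, linear layer, and class name here are assumptions for illustration, not part of the original code.

# Hypothetical classification head on top of the CSP_DarkNet backbone (assumed,
# not part of the original listing): global average pooling plus a linear layer.
class CSPDarkNetClassifier(nn.Module):
    def __init__(self, num_classes=1000):
        super(CSPDarkNetClassifier, self).__init__()
        self.backbone = csp_darknet_53()        # produces [N, 1024, H/32, W/32] features
        self.pool = nn.AdaptiveAvgPool2d(1)     # -> [N, 1024, 1, 1]
        self.fc = nn.Linear(1024, num_classes)  # -> [N, num_classes]

    def forward(self, x):
        out = self.backbone(x)
        out = self.pool(out).flatten(1)
        return self.fc(out)

# Usage sketch:
# clf = CSPDarkNetClassifier(num_classes=1000)
# print(clf(torch.rand(1, 3, 416, 416)).shape)  # expected: torch.Size([1, 1000])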