CSPDarkNet代码实现(PyTorch)

本文介绍了一种改进的DarkNet53网络——CSPDarkNet53,并提供了详细的PyTorch实现代码。该网络引入了Mish激活函数及残差块结构,通过高效的特征提取能力提升了模型性能。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

CSPDarkNet53代码实现(PyTorch)

from typing import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F

# Mish activation function
class Mish(nn.Module):
    """Mish activation: ``x * tanh(softplus(x))``, applied element-wise.

    The module holds no parameters or buffers.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the Mish activation of ``x`` (same shape as ``x``)."""
        gate = torch.tanh(F.softplus(x))
        return x * gate

class BN_Conv_Mish(nn.Module):
    """Bias-free Conv2d followed by BatchNorm2d and a Mish activation.

    Args:
        in_channels: Number of input channels of the convolution.
        out_channels: Number of output channels of the convolution.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Accepted for signature compatibility but NOT used; the
            convolution is always padded with ``kernel_size // 2``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()

        # Padding is derived from the kernel size; the ``padding`` argument
        # is intentionally left untouched to preserve existing behavior.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            bias=False,
        )
        # BatchNorm2d always follows the convolution so that the values fed
        # into the activation (Mish) are normalized and do not grow large
        # enough to destabilize the network.
        self.bn = nn.BatchNorm2d(out_channels)
        self.activation = Mish()

    def forward(self, x):
        """Apply convolution, batch normalization, then Mish to ``x``."""
        return self.activation(self.bn(self.conv(x)))

class ResidualBlock(nn.Module):
    """Residual unit: 1x1 conv, 3x3 conv, BatchNorm, shortcut add, Mish.

    The output has the same channel count as the input so the identity
    shortcut can be added element-wise.

    Args:
        chnls: Number of input (and output) channels.
        inner_chnnls: Channel count between the two convolutions. When
            ``None`` it defaults to ``chnls``.
    """

    def __init__(self, chnls, inner_chnnls=None):
        super(ResidualBlock, self).__init__()
        if inner_chnnls is None:
            inner_chnnls = chnls
        self.conv1 = BN_Conv_Mish(chnls, inner_chnnls, 1)
        self.conv2 = BN_Conv_Mish(inner_chnnls, chnls, 3)
        self.bn = nn.BatchNorm2d(chnls)
        # Register the activation once instead of constructing a new module
        # on every forward pass. Mish has no parameters or buffers, so the
        # state_dict keys are unchanged.
        self.activation = Mish()

    def forward(self, x):
        """Return ``Mish(bn(conv2(conv1(x))) + x)``."""
        out = self.conv1(x)
        out = self.conv2(out)
        # NOTE(review): conv2 is a BN_Conv_Mish and therefore already ends
        # with BatchNorm + Mish; this applies a second BatchNorm before the
        # shortcut add. Confirm this is intended.
        out = self.bn(out) + x
        return self.activation(out)

class ResBodyFirst(nn.Module):
    """First CSP stage: stride-2 downsample, two parallel 1x1 branches, merge.

    Both branches keep ``out_chnls`` channels; one of them additionally runs
    through a single ResidualBlock. The branches are concatenated along the
    channel dimension (``2 * out_chnls``) and fused back to ``out_chnls`` by
    a 1x1 convolution.

    Args:
        in_chnnls: Number of input channels.
        out_chnls: Number of output channels.
    """

    def __init__(self, in_chnnls, out_chnls):
        super().__init__()
        # 3x3 convolution with stride 2 performs the downsampling.
        self.dsample = BN_Conv_Mish(in_chnnls, out_chnls, 3, 2, 1)
        # Shortcut branch and residual branch entry, both 1x1.
        self.trans_0 = BN_Conv_Mish(out_chnls, out_chnls, 1, 1, 0)
        self.trans_1 = BN_Conv_Mish(out_chnls, out_chnls, 1, 1, 0)
        self.block = ResidualBlock(out_chnls, out_chnls // 2)
        # Fuses the concatenated branches back to out_chnls channels.
        self.trans_cat = BN_Conv_Mish(2 * out_chnls, out_chnls, 1, 1, 0)

    def forward(self, x):
        """Run the stage on ``x``, print the output shape, and return it."""
        down = self.dsample(x)
        shortcut = self.trans_0(down)
        residual = self.block(self.trans_1(down))
        merged = torch.cat((shortcut, residual), 1)
        fused = self.trans_cat(merged)
        print("第 1 次残差", fused.shape)
        return fused

class ResBodyStem(nn.Module):
    """CSP stage: stride-2 downsample, two half-width branches, merge.

    Each 1x1 branch reduces to ``out_chnls // 2`` channels; one of them runs
    through ``num_block`` ResidualBlocks. The branches are concatenated
    along the channel dimension and fused by a 1x1 convolution.

    Args:
        in_chnls: Number of input channels.
        out_chnls: Number of output channels.
        num_block: Number of ResidualBlocks stacked in the residual branch.
        index: Stage index; only used in the shape message printed by
            ``forward`` (shown as ``index + 2``).
    """

    def __init__(self, in_chnls, out_chnls, num_block, index):
        super().__init__()
        half = out_chnls // 2
        # 3x3 convolution with stride 2 performs the downsampling.
        self.dsample = BN_Conv_Mish(in_chnls, out_chnls, 3, 2, 1)
        self.trans_0 = BN_Conv_Mish(out_chnls, half, 1, 1, 0)
        self.trans_1 = BN_Conv_Mish(out_chnls, half, 1, 1, 0)
        stacked = [ResidualBlock(half) for _ in range(num_block)]
        self.blocks = nn.Sequential(*stacked)
        self.trans_cat = BN_Conv_Mish(out_chnls, out_chnls, 1, 1, 0)
        self.index = index

    def forward(self, x):
        """Run the stage on ``x``, print the output shape, and return it."""
        down = self.dsample(x)
        shortcut = self.trans_0(down)
        residual = self.blocks(self.trans_1(down))
        fused = self.trans_cat(torch.cat((shortcut, residual), 1))
        print("第", self.index + 2, "次残差", fused.shape)
        return fused

class CSP_DarkNet(nn.Module):
    """CSPDarkNet backbone: stem conv, one ResBodyFirst, four ResBodyStem stages.

    Channel widths go 3 -> 32 (stem) -> 64 -> 128 -> 256 -> 512 -> 1024.

    Args:
        num_blocks: Number of residual blocks for each of the four
            ResBodyStem stages; entries 0..3 are read.
        num_classes: Accepted for signature compatibility but currently
            unused. No classification head is built, and ``forward``
            returns the output of the last stage.
    """

    def __init__(self, num_blocks: Sequence[int], num_classes: int = 1000) -> None:
        super(CSP_DarkNet, self).__init__()
        chnls = [64, 128, 256, 512, 1024]
        self.conv0 = BN_Conv_Mish(3, 32, 3, 1, 1)
        self.resblock1 = ResBodyFirst(32, chnls[0])
        # One stage per consecutive channel pair (64->128, ..., 512->1024),
        # so the stage count follows the channel list instead of a literal 4.
        self.resblocks = nn.Sequential(
            *[
                ResBodyStem(c_in, c_out, num_blocks[i], i)
                for i, (c_in, c_out) in enumerate(zip(chnls, chnls[1:]))
            ]
        )

    def forward(self, x):
        """Return the feature map produced by the last stage for input ``x``."""
        out = self.conv0(x)
        out = self.resblock1(out)
        out = self.resblocks(out)
        return out

def csp_darknet_53(num_classes=1000):
    """Build a CSP_DarkNet with per-stage block counts ``[2, 8, 8, 4]``.

    Args:
        num_classes: Forwarded unchanged to ``CSP_DarkNet``.
    """
    stage_depths = [2, 8, 8, 4]
    return CSP_DarkNet(stage_depths, num_classes)

net = csp_darknet_53()
# Create one random 3-channel image of 416x416 pixels (batch size 1)
n = torch.rand(1, 3, 416, 416)

# Run the network once and show the shape of the returned tensor
print(net(n).shape)

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值