(即插即用模块-特征处理部分) 三十一、(2024) CDFA 对比度驱动的特征聚合模块

在这里插入图片描述

paper:ConDSeg: A General Medical Image Segmentation Framework via Contrast-Driven Feature Enhancement

Code:https://github.com/Mengqi-Lei/ConDSeg


1、Contrast-Driven Feature Aggregation module

医学图像中普遍存在共现现象,例如内镜息肉图像中,大小相似的息肉往往会同时出现。现有模型容易学习到与目标实体本身无关的共现特征,导致在目标实体单独出现时无法准确预测。而本论文提出一种 对比度驱动的特征聚合模块(Contrast-Driven Feature Aggregation module),CDFA 模块是 ConDSeg 框架中用于克服医学图像分割中“共现”现象的关键模块。它通过引导多尺度特征融合和关键特征增强,从而帮助模型更好地区分目标实体和复杂背景环境。

CDFA 模块通过利用解耦出的前景和背景特征,引导多尺度特征融合。通过对比前景和背景特征,CDFA 模块能够增强关键特征,从而帮助模型更好地区分目标实体和背景。

实现过程:

  1. 特征预增强: 首先,将编码器输出的 f1 到 f4 特征图通过一系列扩张卷积层进行预增强,得到 e1 到 e4 特征图。
  2. 特征融合: 将 e1 到 e4 特征图与对应尺度的特征图进行拼接,得到 F 特征图。
  3. 对比特征提取: 将 ffg 和 fbg 特征图通过卷积层和双线性上采样调整尺寸,使其与 F 特征图匹配。
  4. 注意力计算: 在每个空间位置,CDFA 模块计算一个 K×K 窗口内的注意力权重,该权重综合考虑了前景和背景特征。
  5. 特征加权: 使用计算得到的注意力权重对 F 特征图进行加权,从而增强关键特征。
  6. 特征聚合: 将加权后的特征图进行密集聚合,得到最终的输出特征图。

Contrast-Driven Feature Aggregation module 结构图:
在这里插入图片描述


2、代码实现

import torch
import math
import torch.nn as nn
import torch.nn.functional as F


class CBR(nn.Module):
    def __init__(self, in_c, out_c, kernel_size=3, padding=1, dilation=1, stride=1, act=True):
        super().__init__()
        self.act = act

        self.conv = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size, padding=padding, dilation=dilation, bias=False, stride=stride),
            nn.BatchNorm2d(out_c)
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        if self.act == True:
            x = self.relu(x)
        return x


class ContrastDrivenFeatureAggregation(nn.Module):
    def __init__(self, in_c, dim, num_heads=8, kernel_size=3, padding=1, stride=1,
                 attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.head_dim = dim // num_heads

        self.scale = self.head_dim ** -0.5

        self.v = nn.Linear(dim, dim)
        self.attn_fg = nn.Linear(dim, kernel_size ** 4 * num_heads)
        self.attn_bg = nn.Linear(dim, kernel_size ** 4 * num_heads)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.unfold = nn.Unfold(kernel_size=kernel_size, padding=padding, stride=stride)
        self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True)

        self.input_cbr = nn.Sequential(
            CBR(in_c, dim, kernel_size=3, padding=1),
            CBR(dim, dim, kernel_size=3, padding=1),
        )
        self.output_cbr = nn.Sequential(
            CBR(dim, dim, kernel_size=3, padding=1),
            CBR(dim, dim, kernel_size=3, padding=1),
        )

    def forward(self, x, fg, bg):
        x = self.input_cbr(x)

        x = x.permute(0, 2, 3, 1)
        fg = fg.permute(0, 2, 3, 1)
        bg = bg.permute(0, 2, 3, 1)

        B, H, W, C = x.shape

        v = self.v(x).permute(0, 3, 1, 2)

        v_unfolded = self.unfold(v).reshape(B, self.num_heads, self.head_dim,
                                            self.kernel_size * self.kernel_size,
                                            -1).permute(0, 1, 4, 3, 2)
        attn_fg = self.compute_attention(fg, B, H, W, C, 'fg')

        x_weighted_fg = self.apply_attention(attn_fg, v_unfolded, B, H, W, C)

        v_unfolded_bg = self.unfold(x_weighted_fg.permute(0, 3, 1, 2)).reshape(B, self.num_heads, self.head_dim,
                                                                               self.kernel_size * self.kernel_size,
                                                                               -1).permute(0, 1, 4, 3, 2)
        attn_bg = self.compute_attention(bg, B, H, W, C, 'bg')

        x_weighted_bg = self.apply_attention(attn_bg, v_unfolded_bg, B, H, W, C)

        x_weighted_bg = x_weighted_bg.permute(0, 3, 1, 2)

        out = self.output_cbr(x_weighted_bg)

        return out

    def compute_attention(self, feature_map, B, H, W, C, feature_type):

        attn_layer = self.attn_fg if feature_type == 'fg' else self.attn_bg
        h, w = math.ceil(H / self.stride), math.ceil(W / self.stride)

        feature_map_pooled = self.pool(feature_map.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)

        attn = attn_layer(feature_map_pooled).reshape(B, h * w, self.num_heads,
                                                      self.kernel_size * self.kernel_size,
                                                      self.kernel_size * self.kernel_size).permute(0, 2, 1, 3, 4)
        attn = attn * self.scale
        attn = F.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)
        return attn

    def apply_attention(self, attn, v, B, H, W, C):

        x_weighted = (attn @ v).permute(0, 1, 4, 3, 2).reshape(
            B, self.dim * self.kernel_size * self.kernel_size, -1)
        x_weighted = F.fold(x_weighted, output_size=(H, W), kernel_size=self.kernel_size,
                            padding=self.padding, stride=self.stride)
        x_weighted = self.proj(x_weighted.permute(0, 2, 3, 1))
        x_weighted = self.proj_drop(x_weighted)
        return x_weighted


if __name__ == '__main__':
    x = torch.randn(4, 128, 32, 32).cuda()
    y = torch.randn(4, 128, 32, 32).cuda()
    z = torch.randn(4, 128, 32, 32).cuda()
    model = ContrastDrivenFeatureAggregation(128, 128).cuda()
    out = model(x, y, z)
    print(out.shape)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

御宇w

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值