Paper: ConDSeg: A General Medical Image Segmentation Framework via Contrast-Driven Feature Enhancement
Code: https://github.com/Mengqi-Lei/ConDSeg
1. Contrast-Driven Feature Aggregation module
Co-occurrence is pervasive in medical images: in endoscopic polyp images, for instance, polyps of similar size often appear together. Existing models easily pick up co-occurrence features that are unrelated to the target entity itself, and consequently fail to predict accurately when the target entity appears alone. This paper proposes the Contrast-Driven Feature Aggregation (CDFA) module, the key component of the ConDSeg framework for overcoming the co-occurrence phenomenon in medical image segmentation. By guiding multi-scale feature fusion and enhancing key features, CDFA helps the model distinguish target entities from complex background environments.
Concretely, CDFA uses the decoupled foreground and background features to guide multi-scale feature fusion. By contrasting foreground against background features, the module amplifies the features that actually belong to the target, so that entities are separated from the background more reliably.
Implementation steps (the code in Section 2 implements the attention core, steps 4-6; a hedged sketch of steps 1-3 follows this list):
- Feature pre-enhancement: the encoder outputs f1 to f4 are first passed through a series of dilated convolution layers, yielding the enhanced feature maps e1 to e4.
- Feature fusion: e1 to e4 are concatenated with the feature maps at the corresponding scales to obtain the fused feature map F.
- Contrast feature extraction: the decoupled foreground and background feature maps ffg and fbg are resized with convolution layers and bilinear upsampling so that they match F.
- Attention computation: at every spatial position, CDFA computes attention weights over a K×K window, conditioned jointly on the foreground and background features.
- Feature weighting: the computed attention weights are applied to F, enhancing the key features.
- Feature aggregation: the weighted feature maps are densely aggregated to produce the final output feature map.
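Steps 1-3 live outside the CDFA class shown below. As a rough illustration only, here is a minimal sketch of what the dilated-convolution pre-enhancement and the ffg/fbg resizing could look like; the dilation rates (1, 2, 4), channel sizes, and the helper names DilatedPreEnhance and resize_to are assumptions for this sketch, not the paper's actual implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F

class DilatedPreEnhance(nn.Module):
    # Hypothetical pre-enhancement block (step 1): stacked dilated convolutions
    # turn an encoder feature map f_i into an enhanced map e_i at the same
    # resolution. Dilation rates and depth are illustrative assumptions.
    def __init__(self, in_c, out_c):
        super().__init__()
        layers = []
        for d in (1, 2, 4):
            layers += [
                nn.Conv2d(in_c, out_c, 3, padding=d, dilation=d, bias=False),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True),
            ]
            in_c = out_c
        self.body = nn.Sequential(*layers)

    def forward(self, f):
        return self.body(f)

def resize_to(f_mask, conv1x1, target_hw):
    # Step 3: align a decoupled fg/bg feature map with F via a 1x1 conv
    # (channel match) followed by bilinear upsampling (spatial match).
    return F.interpolate(conv1x1(f_mask), size=target_hw,
                         mode='bilinear', align_corners=False)

# Usage sketch:
f1 = torch.randn(2, 64, 64, 64)           # one encoder feature map
e1 = DilatedPreEnhance(64, 64)(f1)        # step 1: pre-enhancement
F_fused = torch.cat([e1, f1], dim=1)      # step 2: concat at the matching scale
fbg = torch.randn(2, 32, 16, 16)          # a decoupled background feature map
align = nn.Conv2d(32, 128, kernel_size=1)
fbg_r = resize_to(fbg, align, (64, 64))   # step 3: now matches F_fused spatially
print(F_fused.shape, fbg_r.shape)         # (2, 128, 64, 64) for both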
Contrast-Driven Feature Aggregation module structure diagram: (image not reproduced in this text version; see the paper or repository)
2. Code Implementation
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
class CBR(nn.Module):
    """Conv2d + BatchNorm + (optional) ReLU building block."""
    def __init__(self, in_c, out_c, kernel_size=3, padding=1, dilation=1, stride=1, act=True):
        super().__init__()
        self.act = act
        self.conv = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size, padding=padding, dilation=dilation, bias=False, stride=stride),
            nn.BatchNorm2d(out_c)
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        if self.act:
            x = self.relu(x)
        return x
class ContrastDrivenFeatureAggregation(nn.Module):
    """CDFA: fuses features under the guidance of decoupled foreground (fg)
    and background (bg) features via local windowed attention."""
    def __init__(self, in_c, dim, num_heads=8, kernel_size=3, padding=1, stride=1,
                 attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.v = nn.Linear(dim, dim)
        # Each linear layer predicts, per head and per position, a
        # (k*k) x (k*k) attention matrix conditioned on fg or bg features.
        self.attn_fg = nn.Linear(dim, kernel_size ** 4 * num_heads)
        self.attn_bg = nn.Linear(dim, kernel_size ** 4 * num_heads)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        # Extract a K x K window of values around every output position.
        self.unfold = nn.Unfold(kernel_size=kernel_size, padding=padding, stride=stride)
        self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True)

        self.input_cbr = nn.Sequential(
            CBR(in_c, dim, kernel_size=3, padding=1),
            CBR(dim, dim, kernel_size=3, padding=1),
        )
        self.output_cbr = nn.Sequential(
            CBR(dim, dim, kernel_size=3, padding=1),
            CBR(dim, dim, kernel_size=3, padding=1),
        )
    def forward(self, x, fg, bg):
        x = self.input_cbr(x)

        # Work in (B, H, W, C) layout for the linear layers.
        x = x.permute(0, 2, 3, 1)
        fg = fg.permute(0, 2, 3, 1)
        bg = bg.permute(0, 2, 3, 1)
        B, H, W, C = x.shape

        # Values: unfold into K x K windows -> (B, heads, L, k*k, head_dim).
        v = self.v(x).permute(0, 3, 1, 2)
        v_unfolded = self.unfold(v).reshape(B, self.num_heads, self.head_dim,
                                            self.kernel_size * self.kernel_size,
                                            -1).permute(0, 1, 4, 3, 2)

        # Stage 1: attention weights predicted from the foreground features.
        attn_fg = self.compute_attention(fg, B, H, W, C, 'fg')
        x_weighted_fg = self.apply_attention(attn_fg, v_unfolded, B, H, W, C)

        # Stage 2: re-unfold the fg-weighted features and apply
        # background-guided attention on top.
        v_unfolded_bg = self.unfold(x_weighted_fg.permute(0, 3, 1, 2)).reshape(
            B, self.num_heads, self.head_dim,
            self.kernel_size * self.kernel_size,
            -1).permute(0, 1, 4, 3, 2)
        attn_bg = self.compute_attention(bg, B, H, W, C, 'bg')
        x_weighted_bg = self.apply_attention(attn_bg, v_unfolded_bg, B, H, W, C)

        # Back to (B, C, H, W) and refine with the output CBR stack.
        x_weighted_bg = x_weighted_bg.permute(0, 3, 1, 2)
        out = self.output_cbr(x_weighted_bg)
        return out
    def compute_attention(self, feature_map, B, H, W, C, feature_type):
        # Pick the fg- or bg-conditioned attention predictor.
        attn_layer = self.attn_fg if feature_type == 'fg' else self.attn_bg
        h, w = math.ceil(H / self.stride), math.ceil(W / self.stride)

        # Predict a (k*k) x (k*k) attention matrix per head and position:
        # (B, h*w, heads, k*k, k*k) -> (B, heads, h*w, k*k, k*k).
        feature_map_pooled = self.pool(feature_map.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        attn = attn_layer(feature_map_pooled).reshape(B, h * w, self.num_heads,
                                                      self.kernel_size * self.kernel_size,
                                                      self.kernel_size * self.kernel_size).permute(0, 2, 1, 3, 4)
        attn = attn * self.scale
        attn = F.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)
        return attn
    def apply_attention(self, attn, v, B, H, W, C):
        # Weighted sum over each window, then fold the overlapping windows
        # back onto the (H, W) grid.
        x_weighted = (attn @ v).permute(0, 1, 4, 3, 2).reshape(
            B, self.dim * self.kernel_size * self.kernel_size, -1)
        x_weighted = F.fold(x_weighted, output_size=(H, W), kernel_size=self.kernel_size,
                            padding=self.padding, stride=self.stride)
        x_weighted = self.proj(x_weighted.permute(0, 2, 3, 1))
        x_weighted = self.proj_drop(x_weighted)
        return x_weighted
if __name__ == '__main__':
    # Run on GPU if available, otherwise fall back to CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    x = torch.randn(4, 128, 32, 32).to(device)   # fused multi-scale features
    fg = torch.randn(4, 128, 32, 32).to(device)  # decoupled foreground features
    bg = torch.randn(4, 128, 32, 32).to(device)  # decoupled background features
    model = ContrastDrivenFeatureAggregation(128, 128).to(device)
    out = model(x, fg, bg)
    print(out.shape)  # expected: torch.Size([4, 128, 32, 32])
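With the defaults (kernel_size=3, padding=1, stride=1), nn.Unfold and F.fold preserve the spatial size, so this test should print torch.Size([4, 128, 32, 32]). As a design note, compute_attention and apply_attention follow much the same pattern as outlook attention in VOLO: instead of query-key dot products, a linear layer predicts the K×K window attention directly, here conditioned on the foreground or background features, and F.fold sums the re-weighted overlapping windows back onto the feature grid.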