【医疗图像分割】UNETR++论文笔记及代码跑通实践

在医疗图像分割任务中,transformer模型获得了巨大的成功,UNETR提出了efficient paired attention (EPA) 模块,利用了空间和通道注意力来有效地学习通道和空间的特征,该模型在Synapse,BTCV,ACDC,BRaTs数据集上都获得了很好地效果。

论文:UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation

代码:https://github.com/Amshaker/unetr_plus_plus

tips:博主有低价AutoDL RTX4090资源,有需要可以联系。

一、论文笔记

首先看一下模型架构,整体还是UNet结构,在其中引入了提出的EPA模块。

该论文的核心就是EPA模块,EPA的提出主要是解决2个问题:

1、计算更有效率:传统的self-attention计算成本很高,对于3D的医疗图像来说更高,EPA将self-attention的K和V投影到低纬度再计算,降低了计算复杂度;

2、增强了空间和通道特征表示能力:transformer本身就是一种空间注意力机制,但是它忽略了通道特征,EPA将空间和通道特征融合在了一起。

再仔细看一下EPA的结构图,上方蓝底部分式空间注意力,下方绿底部分式通道注意力。再空间注意部分,为了降低self-attention计算量,将HWDXC的K和V降维到pXC维度。

代码如下(类中的self.EF用于降低K和V的维度,空间注意力和通道注意力的K和Q是共享的):


class EPA(nn.Module):
    """
        Efficient Paired Attention Block, based on: "Shaker et al.,
        UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation"
        """
    def __init__(self, input_size, hidden_size, proj_size, num_heads=4, qkv_bias=False,
                 channel_attn_drop=0.1, spatial_attn_drop=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.temperature2 = nn.Parameter(torch.ones(num_heads, 1, 1))

        # qkvv are 4 linear layers (query_shared, key_shared, value_spatial, value_channel)
        self.qkvv = nn.Linear(hidden_size, hidden_size * 4, bias=qkv_bias)

        # E and F are projection matrices with shared weights used in spatial attention module to project
        # keys and values from HWD-dimension to P-dimension
        self.EF = nn.Parameter(init_(torch.zeros(input_size, proj_size)))

        self.attn_drop = nn.Dropout(channel_attn_drop)
        self.attn_drop_2 = nn.Dropout(spatial_attn_drop)

    def forward(self, x):
        B, N, C = x.shape

        qkvv = self.qkvv(x).reshape(B, N, 4, self.num_heads, C // self.num_heads)
        qkvv = qkvv.permute(2, 0, 3, 1, 4)
        q_shared, k_shared, v_CA, v_SA = qkvv[0], qkvv[1], qkvv[2], qkvv[3]

        q_shared = q_shared.transpose(-2, -1)
        k_shared = k_shared.transpose(-2, -1)
        v_CA = v_CA.transpose(-2, -1)
        v_SA = v_SA.transpose(-2, -1)

        proj_e_f = lambda args: torch.einsum('bhdn,nk->bhdk', *args)
        k_shared_projected, v_SA_projected = map(proj_e_f, zip((k_shared, v_SA), (self.EF, self.EF)))

        q_shared = torch.nn.functional.normalize(q_shared, dim=-1)
        k_shared = torch.nn.functional.normalize(k_shared, dim=-1)

        attn_CA = (q_shared @ k_shared.transpose(-2, -1)) * self.temperature
        attn_CA = attn_CA.softmax(dim=-1)
        attn_CA = self.attn_drop(attn_CA)
        x_CA = (attn_CA @ v_CA).permute(0, 3, 1, 2).reshape(B, N, C)

        attn_SA = (q_shared.permute(0, 1, 3, 2) @ k_shared_projected) * self.temperature2
        attn_SA = attn_SA.softmax(dim=-1)
        attn_SA = self.attn_drop_2(attn_SA)
        x_SA = (attn_SA @ v_SA_projected.transpose(-2, -1)).permute(0, 3, 1, 2).reshape(B, N, C)

        return x_CA + x_SA

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'temperature', 'temperature2'}

二、代码实践

官方给出了Synapse,BTCV,ACDC,BRaTs数据集的跑通实例,我这里只跑一个BRaTs数据集,其他的是一样的步骤。

1、安装环境

使用conda安装环境:

conda create --name unetr_pp python=3.10
conda activate unetr_pp

安装torch:

conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia

安装依赖:

pip install -r requirements.txt

2、准备数据集

官方给出了处理好的数据集地址,直接下载即可:

数据集链接
SynapseOneDrive
ACDCOneDrive
Decathon-LungOneDrive
BRaTsOneDrive

本文下载好了BraTs数据集作为实例,将其放入以下目录:

3、训练

因为只是跑通一下,把unetr_plus_plus/unetr_pp/training/network_training/unetr_pp_trainer_tumor.py中的epoch改成10:

训练就非常简单了,进入训练集脚本目录并运行脚本:

cd training_scripts
bash run_training_tumor.sh

训练起来了:

4、评估

首先将自己训练的权重放到指定位置(原来output_tumor的unetr_pp文件夹放到unetr_plus_plus\unetr_pp\evaluation\unetr_pp_tumor_checkpoint里面去):

修改代码unetr_plus_plus/unetr_pp/inference/predict.py,共有两处:

进入评估脚本目录并运行脚本:

cd evaluation_scripts

修改run_evaluation_tumor.sh脚本,相关路径替换为自己的路径(自带的脚本我没成功,大家可以自行尝试):

#!/bin/sh

DATASET_PATH=../DATASET_Tumor

export PYTHONPATH=.././
export RESULTS_FOLDER=../unetr_pp/evaluation/unetr_pp_tumor_checkpoint
export unetr_pp_preprocessed="$DATASET_PATH"/unetr_pp_raw/unetr_pp_raw_data/Task03_tumor
export unetr_pp_raw_data_base="$DATASET_PATH"/unetr_pp_raw

# Only for Tumor, it is recommended to train unetr_plus_plus first, and then use the provided checkpoint to evaluate. It might raise issues regarding the pickle files if you evaluated without training

python /deeplearning/medicalseg/unetr_plus_plus/unetr_pp/inference/predict_simple.py -i ../DATASET_Tumor/unetr_pp_raw/unetr_pp_raw_data/Task003_tumor/imagesTs -o ../unetr_pp/evaluation/unetr_pp_tumor_checkpoint/inferTs -m 3d_fullres  -t 3 -f 0 -chk model_final_checkpoint -tr unetr_pp_trainer_tumor


python /deeplearning/medicalseg/unetr_plus_plus/unetr_pp/inference_tumor.py 0

修改unetr_plus_plus/unetr_pp/inference_tumor.py的数据集路径,可以根据自己的情况改:

运行脚本:

bash run_evaluation_tumor.sh

在推理结果的目录unetr_plus_plus/unetr_pp/evaluation/unetr_pp_tumor_checkpoint/下多了一个

dice_five.txt文件,里面有相关精度,如下(因为就训练了10个epoch,效果不行):

本文到此结束。

03-08
### UNETR 架构 UNETR 是一种基于 Transformer 的新型网络架构,专门用于高效的 3D 医学图像分割任务[^3]。该模型结合了 U-Net 和 Transformer 的优势特性: #### 1. 三维感知能力 传统的二维卷积神经网络难以有效处理复杂的三维医学影像数据,而 UNETR过引入完整的三维感受野来解决这个问题。这种设计使得模型能够更好地保持原始的空间信息,避免因降维带来的细节损失。 #### 2. Transformer 核心组件 不同于经典的 CNN 结构依赖局部特征提取的方式,UNETR 使用了强大的自注意力机制作为主要计算单元。这允许模型在整个输入体素范围内建立全局关联,从而显著提高了对于复杂解剖结构的理解能力和边界识别精度。 ```python import torch.nn as nn class SelfAttention(nn.Module): def __init__(self, dim_in, num_heads=8): super().__init__() self.num_heads = num_heads head_dim = dim_in // num_heads self.scale = head_dim ** -0.5 self.qkv = nn.Linear(dim_in, dim_in * 3) self.proj_out = nn.Linear(dim_in, dim_in) def forward(self, x): B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv.chunk(3, dim=0) attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) out = (attn @ v).transpose(1, 2).reshape(B, N, C) return self.proj_out(out) ``` #### 3. 高效融合策略 为了进一步提高效率并减少参数量,UNETR 还采用了轻量化的设计思路,在编码器部分采用稀疏采样技术,并在解码阶段逐步恢复分辨率。这样的安排既保证了足够的表达力又降低了训练难度。 #### 4. 医疗领域定制化调整 考虑到实际应用场景的需求差异较大,开发团队针对不同类型的病变进行了针对性改进。比如增加了多尺度上下文建模模块以适应各种尺寸的目标;加入了对抗学习框架用来缓解类不平衡现象等措施,最终实现了更加鲁棒可靠的预测效果。 --- ### 应用案例展示 UNETR 已经被成功应用于多个重要的临床场景当中,特别是在 CT 或 MRI 扫描得到的大规模体积数据集上的表现尤为突出。具体来说: - **肿瘤检测**:可以精准地标记出恶性细胞群的位置范围,辅助医生做出诊断决策; - **脑部结构分割**:有助于深入探究大脑内部细微变化规律,支持神经系统疾病的早期预警; - **心脏功能评估**:过对心肌壁厚度及运动轨迹的精确测量,为心血管健康状况提供科学依据。 这些进步不仅改善了医疗服务的质量水平,同时也促进了基础科学研究的发展进程。
评论 41
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

justld

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值