UNETR++模型搭建

模型结构的打印

'''
===========================================================================
Layer (type:depth-idx)                             Param #
===========================================================================
├─UnetrPPEncoder: 1-1                              --
|    └─ModuleList: 2-1                             --
|    |    └─Sequential: 3-1                        --
|    |    |    └─Convolution: 4-1                  --
|    |    |    |    └─Conv3d: 5-1                  1,024
|    |    |    └─GroupNorm: 4-2                    64
|    |    └─Sequential: 3-2                        --
|    |    |    └─Convolution: 4-3                  --
|    |    |    |    └─Conv3d: 5-2                  16,384
|    |    |    └─GroupNorm: 4-4                    128
|    |    └─Sequential: 3-3                        --
|    |    |    └─Convolution: 4-5                  --
|    |    |    |    └─Conv3d: 5-3                  65,536
|    |    |    └─GroupNorm: 4-6                    256
|    |    └─Sequential: 3-4                        --
|    |    |    └─Convolution: 4-7                  --
|    |    |    |    └─Conv3d: 5-4                  262,144
|    |    |    └─GroupNorm: 4-8                    512
|    └─ModuleList: 2-2                             --
|    |    └─Sequential: 3-5                        --
|    |    |    └─TransformerBlock: 4-9             --
|    |    |    |    └─LayerNorm: 5-5               64
|    |    |    |    └─EPA: 5-6                     --
|    |    |    |    |    └─Linear: 6-1             4,096
|    |    |    |    |    └─Linear: 6-2             2,097,216
|    |    |    |    |    └─Dropout: 6-3            --
|    |    |    |    |    └─Dropout: 6-4            --
|    |    |    |    |    └─Linear: 6-5             528
|    |    |    |    |    └─Linear: 6-6             528
|    |    |    |    └─UnetResBlock: 5-7            --
|    |    |    |    |    └─Convolution: 6-7        --
|    |    |    |    |    |    └─Conv3d: 7-1        27,648
|    |    |    |    |    └─Convolution: 6-8        --
|    |    |    |    |    |    └─Conv3d: 7-2        27,648
|    |    |    |    |    └─LeakyReLU: 6-9          --
|    |    |    |    |    └─BatchNorm3d: 6-10       64
|    |    |    |    |    └─BatchNorm3d: 6-11       64
|    |    |    |    └─Sequential: 5-8              --
|    |    |    |    |    └─Dropout3d: 6-12         --
|    |    |    |    |    └─Conv3d: 6-13            1,056
|    |    |    └─TransformerBlock: 4-10            --
|    |    |    |    └─LayerNorm: 5-9               64
|    |    |    |    └─EPA: 5-10                    --
|    |    |    |    |    └─Linear: 6-14            4,096
|    |    |    |    |    └─Linear: 6-15            2,097,216
|    |    |    |    |    └─Dropout: 6-16           --
|    |    |    |    |    └─Dropout: 6-17           --
|    |    |    |    |    └─Linear: 6-18            528
|    |    |    |    |    └─Linear: 6-19            528
|    |    |    |    └─UnetResBlock: 5-11           --
|    |    |    |    |    └─Convolution: 6-20       --
|    |    |    |    |    |    └─Conv3d: 7-3        27,648
|    |    |    |    |    └─Convolution: 6-21       --
|    |    |    |    |    |    └─Conv3d: 7-4        27,648
|    |    |    |    |    └─LeakyReLU: 6-22         --
|    |    |    |    |    └─BatchNorm3d: 6-23       64
|    |    |    |    |    └─BatchNorm3d: 6-24       64
|    |    |    |    └─Sequential: 5-12             --
|    |    |    |    |    └─Dropout3d: 6-25         --
|    |    |    |    |    └─Conv3d: 6-26            1,056
|    |    |    └─TransformerBlock: 4-11            --
|    |    |    |    └─LayerNorm: 5-13              64
|    |    |    |    └─EPA: 5-14                    --
|    |    |    |    |    └─Linear: 6-27            4,096
|    |    |    |    |    └─Linear: 6-28            2,097,216
|    |    |    |    |    └─Dropout: 6-29           --
|    |    |    |    |    └─Dropout: 6-30           --
|    |    |    |    |    └─Linear: 6-31            528
|    |    |    |    |    └─Linear: 6-32            528
|    |    |    |    └─UnetResBlock: 5-15           --
|    |    |    |    |    └─Convolution: 6-33       --
|    |    |    |    |    |    └─Conv3d: 7-5        27,648
|    |    |    |    |    └─Convolution: 6-34       --
|    |    |    |    |    |    └─Conv3d: 7-6        27,648
|    |    |    |    |    └─LeakyReLU: 6-35         --
|    |    |    |    |    └─BatchNorm3d: 6-36       64
|    |    |    |    |    └─BatchNorm3d: 6-37       64
|    |    |    |    └─Sequential: 5-16             --
|    |    |    |    |    └─Dropout3d: 6-38         --
|    |    |    |    |    └─Conv3d: 6-39            1,056
|    |    └─Sequential: 3-6                        --
|    |    |    └─TransformerBlock: 4-12            --
|    |    |    |    └─LayerNorm: 5-17              128
|    |    |    |    └─EPA: 5-18                    --
|    |    |    |    |    └─Linear: 6-40            16,384
|    |    |    |    |    └─Linear: 6-41            262,208
|    |    |    |    |    └─Dropout: 6-42           --
|    |    |    |    |    └─Dropout: 6-43           --
|    |    |    |    |    └─Linear: 6-44            2,080
|    |    |    |    |    └─Linear: 6-45            2,080
|    |    |    |    └─UnetResBlock: 5-19           --
|    |    |    |    |    └─Convolution: 6-46       --
|    |    |    |    |    |    └─Conv3d: 7-7        110,592
|    |    |    |    |    └─Convolution: 6-47       --
|    |    |    |    |    |    └─Conv3d: 7-8        110,592
|    |    |    |    |    └─LeakyReLU: 6-48         --
|    |    |    |    |    └─BatchNorm3d: 6-49       128
|    |    |    |    |    └─BatchNorm3d: 6-50       128
|    |    |    |    └─Sequential: 5-20             --
|    |    |    |    |    └─Dropout3d: 6-51         --
|    |    |    |    |    └─Conv3d: 6-52            4,160
|    |    |    └─TransformerBlock: 4-13            --
|    |    |    |    └─LayerNorm: 5-21              128
|    |    |    |    └─EPA: 5-22                    --
|    |    |    |    |    └─Linear: 6-53            16,384
|    |    |    |    |    └─Linear: 6-54            262,208
|    |    |    |    |    └─Dropout: 6-55           --
|    |    |    |    |    └─Dropout: 6-56           --
|    |    |    |    |    └─Linear: 6-57            2,080
|    |    |    |    |    └─Linear: 6-58            2,080
|    |    |    |    └─UnetResBlock: 5-23           --
|    |    |    |    |    └─Convolution: 6-59       --
|    |    |    |    |    |    └─Conv3d: 7-9        110,592
|    |    |    |    |    └─Convolution: 6-60       --
|    |    |    |    |    |    └─Conv3d: 7-10       110,592
|    |    |    |    |    └─LeakyReLU: 6-61         --
|    |    |    |    |    └─BatchNorm3d: 6-62       128
|    |    |    |    |    └─BatchNorm3d: 6-63       128
|    |    |    |    └─Sequential: 5-24             --
|    |    |    |    |    └─Dropout3d: 6-64         --
|    |    |    |    |    └─Conv3d: 6-65            4,160
|    |    |    └─TransformerBlock: 4-14            --
|    |    |    |    └─LayerNorm: 5-25              128
|    |    |    |    └─EPA: 5-26                    --
|    |    |    |    |    └─Linear: 6-66            16,384
|    |    |    |    |    └─Linear: 6-67            262,208
|    |    |    |    |    └─Dropout: 6-68           --
|    |    |    |    |    └─Dropout: 6-69           --
|    |    |    |    |    └─Linear: 6-70            2,080
|    |    |    |    |    └─Linear: 6-71            2,080
|    |    |    |    └─UnetResBlock: 5-27           --
|    |    |    |    |    └─Convolution: 6-72       --
|    |    |    |    |    |    └─Conv3d: 7-11       110,592
|    |    |    |    |    └─Convolution: 6-73       --
|    |    |    |    |    |    └─Conv3d: 7-12       110,592
|    |    |    |    |    └─LeakyReLU: 6-74         --
|    |    |    |    |    └─BatchNorm3d: 6-75       128
|    |    |    |    |    └─BatchNorm3d: 6-76       128
|    |    |    |    └─Sequential: 5-28             --
|    |    |    |    |    └─Dropout3d: 6-77         --
|    |    |    |    |    └─Conv3d: 6-78            4,160
|    |    └─Sequential: 3-7                        --
|    |    |    └─TransformerBlock: 4-15            --
|    |    |    |    └─LayerNorm: 5-29              256
|    |    |    |    └─EPA: 5-30                    --
|    |    |    |    |    └─Linear: 6-79            65,536
|    |    |    |    |    └─Linear: 6-80            32,832
|    |    |    |    |    └─Dropout: 6-81           --
|    |    |    |    |    └─Dropout: 6-82           --
|    |    |    |    |    └─Linear: 6-83            8,256
|    |    |    |    |    └─Linear: 6-84            8,256
|    |    |    |    └─UnetResBlock: 5-31           --
|    |    |    |    |    └─Convolution: 6-85       --
|    |    |    |    |    |    └─Conv3d: 7-13       442,368
|    |    |    |    |    └─Convolution: 6-86       --
|    |    |    |    |    |    └─Conv3d: 7-14       442,368
|    |    |    |    |    └─LeakyReLU: 6-87         --
|    |    |    |    |    └─BatchNorm3d: 6-88       256
|    |    |    |    |    └─BatchNorm3d: 6-89       256
|    |    |    |    └─Sequential: 5-32             --
|    |    |    |    |    └─Dropout3d: 6-90         --
|    |    |    |    |    └─Conv3d: 6-91            16,512
|    |    |    └─TransformerBlock: 4-16            --
|    |    |    |    └─LayerNorm: 5-33              256
|    |    |    |    └─EPA: 5-34                    --
|    |    |    |    |    └─Linear: 6-92            65,536
|    |    |    |    |    └─Linear: 6-93            32,832
|    |    |    |    |    └─Dropout: 6-94           --
|    |    |    |    |    └─Dropout: 6-95           --
|    |    |    |    |    └─Linear: 6-96            8,256
|    |    |    |    |    └─Linear: 6-97            8,256
|    |    |    |    └─UnetResBlock: 5-35           --
|    |    |    |    |    └─Convolution: 6-98       --
|    |    |    |    |    |    └─Conv3d: 7-15       442,368
|    |    |    |    |    └─Convolution: 6-99       --
|    |    |    |    |    |    └─Conv3d: 7-16       442,368
|    |    |    |    |    └─LeakyReLU: 6-100        --
|    |    |    |    |    └─BatchNorm3d: 6-101      256
|    |    |    |    |    └─BatchNorm3d: 6-102      256
|    |    |    |    └─Sequential: 5-36             --
|    |    |    |    |    └─Dropout3d: 6-103        --
|    |    |    |    |    └─Conv3d: 6-104           16,512
|    |    |    └─TransformerBlock: 4-17            --
|    |    |    |    └─LayerNorm: 5-37              256
|    |    |    |    └─EPA: 5-38                    --
|    |    |    |    |    └─Linear: 6-105           65,536
|    |    |    |    |    └─Linear: 6-106           32,832
|    |    |    |    |    └─Dropout: 6-107          --
|    |    |    |    |    └─Dropout: 6-108          --
|    |    |    |    |    └─Linear: 6-109           8,256
|    |    |    |    |    └─Linear: 6-110           8,256
|    |    |    |    └─UnetResBlock: 5-39           --
|    |    |    |    |    └─Convolution: 6-111      --
|    |    |    |    |    |    └─Conv3d: 7-17       442,368
|    |    |    |    |    └─Convolution: 6-112      --
|    |    |    |    |    |    └─Conv3d: 7-18       442,368
|    |    |    |    |    └─LeakyReLU: 6-113        --
|    |    |    |    |    └─BatchNorm3d: 6-114      256
|    |    |    |    |    └─BatchNorm3d: 6-115      256
|    |    |    |    └─Sequential: 5-40             --
|    |    |    |    |    └─Dropout3d: 6-116        --
|    |    |    |    |    └─Conv3d: 6-117           16,512
|    |    └─Sequential: 3-8                        --
|    |    |    └─TransformerBlock: 4-18            --
|    |    |    |    └─LayerNorm: 5-41              512
|    |    |    |    └─EPA: 5-42                    --
|    |    |    |    |    └─Linear: 6-118           262,144
|    |    |    |    |    └─Linear: 6-119           2,080
|    |    |    |    |    └─Dropout: 6-120          --
|    |    |    |    |    └─Dropout: 6-121          --
|    |    |    |    |    └─Linear: 6-122           32,896
|    |    |    |    |    └─Linear: 6-123           32,896
|    |    |    |    └─UnetResBlock: 5-43           --
|    |    |    |    |    └─Convolution: 6-124      --
|    |    |    |    |    |    └─Conv3d: 7-19       1,769,472
|    |    |    |    |    └─Convolution: 6-125      --
|    |    |    |    |    |    └─Conv3d: 7-20       1,769,472
|    |    |    |    |    └─LeakyReLU: 6-126        --
|    |    |    |    |    └─BatchNorm3d: 6-127      512
|    |    |    |    |    └─BatchNorm3d: 6-128      512
|    |    |    |    └─Sequential: 5-44             --
|    |    |    |    |    └─Dropout3d: 6-129        --
|    |    |    |    |    └─Conv3d: 6-130           65,792
|    |    |    └─TransformerBlock: 4-19            --
|    |    |    |    └─LayerNorm: 5-45              512
|    |    |    |    └─EPA: 5-46                    --
|    |    |    |    |    └─Linear: 6-131           262,144
|    |    |    |    |    └─Linear: 6-132           2,080
|    |    |    |    |    └─Dropout: 6-133          --
|    |    |    |    |    └─Dropout: 6-134          --
|    |    |    |    |    └─Linear: 6-135           32,896
|    |    |    |    |    └─Linear: 6-136           32,896
|    |    |    |    └─UnetResBlock: 5-47           --
|    |    |    |    |    └─Convolution: 6-137      --
|    |    |    |    |    |    └─Conv3d: 7-21       1,769,472
|    |    |    |    |    └─Convolution: 6-138      --
|    |    |    |    |    |    └─Conv3d: 7-22       1,769,472
|    |    |    |    |    └─LeakyReLU: 6-139        --
|    |    |    |    |    └─BatchNorm3d: 6-140      512
|    |    |    |    |    └─BatchNorm3d: 6-141      512
|    |    |    |    └─Sequential: 5-48             --
|    |    |    |    |    └─Dropout3d: 6-142        --
|    |    |    |    |    └─Conv3d: 6-143           65,792
|    |    |    └─TransformerBlock: 4-20            --
|    |    |    |    └─LayerNorm: 5-49              512
|    |    |    |    └─EPA: 5-50                    --
|    |    |    |    |    └─Linear: 6-144           262,144
|    |    |    |    |    └─Linear: 6-145           2,080
|    |    |    |    |    └─Dropout: 6-146          --
|    |    |    |    |    └─Dropout: 6-147          --
|    |    |    |    |    └─Linear: 6-148           32,896
|    |    |    |    |    └─Linear: 6-149           32,896
|    |    |    |    └─UnetResBlock: 5-51           --
|    |    |    |    |    └─Convolution: 6-150      --
|    |    |    |    |    |    └─Conv3d: 7-23       1,769,472
|    |    |    |    |    └─Convolution: 6-151      --
|    |    |    |    |    |    └─Conv3d: 7-24       1,769,472
|    |    |    |    |    └─LeakyReLU: 6-152        --
|    |    |    |    |    └─BatchNorm3d: 6-153      512
|    |    |    |    |    └─BatchNorm3d: 6-154      512
|    |    |    |    └─Sequential: 5-52             --
|    |    |    |    |    └─Dropout3d: 6-155        --
|    |    |    |    |    └─Conv3d: 6-156           65,792
├─UnetResBlock: 1-2                                --
|    └─Convolution: 2-3                            --
|    |    └─Conv3d: 3-9                            432
|    └─Convolution: 2-4                            --
|    |    └─Conv3d: 3-10                           6,912
|    └─LeakyReLU: 2-5                              --
|    └─InstanceNorm3d: 2-6                         --
|    └─InstanceNorm3d: 2-7                         --
|    └─Convolution: 2-8                            --
|    |    └─Conv3d: 3-11                           16
|    └─InstanceNorm3d: 2-9                         --
├─UnetrUpBlock: 1-3                                --
|    └─Convolution: 2-10                           --
|    |    └─ConvTranspose3d: 3-12                  262,144
|    └─ModuleList: 2-11                            --
|    |    └─Sequential: 3-13                       --
|    |    |    └─TransformerBlock: 4-21            --
|    |    |    |    └─LayerNorm: 5-53              256
|    |    |    |    └─EPA: 5-54                    --
|    |    |    |    |    └─Linear: 6-157           65,536
|    |    |    |    |    └─Linear: 6-158           32,832
|    |    |    |    |    └─Dropout: 6-159          --
|    |    |    |    |    └─Dropout: 6-160          --
|    |    |    |    |    └─Linear: 6-161           8,256
|    |    |    |    |    └─Linear: 6-162           8,256
|    |    |    |    └─UnetResBlock: 5-55           --
|    |    |    |    |    └─Convolution: 6-163      --
|    |    |    |    |    |    └─Conv3d: 7-25       442,368
|    |    |    |    |    └─Convolution: 6-164      --
|    |    |    |    |    |    └─Conv3d: 7-26       442,368
|    |    |    |    |    └─LeakyReLU: 6-165        --
|    |    |    |    |    └─BatchNorm3d: 6-166      256
|    |    |    |    |    └─BatchNorm3d: 6-167      256
|    |    |    |    └─Sequential: 5-56             --
|    |    |    |    |    └─Dropout3d: 6-168        --
|    |    |    |    |    └─Conv3d: 6-169           16,512
|    |    |    └─TransformerBlock: 4-22            --
|    |    |    |    └─LayerNorm: 5-57              256
|    |    |    |    └─EPA: 5-58                    --
|    |    |    |    |    └─Linear: 6-170           65,536
|    |    |    |    |    └─Linear: 6-171           32,832
|    |    |    |    |    └─Dropout: 6-172          --
|    |    |    |    |    └─Dropout: 6-173          --
|    |    |    |    |    └─Linear: 6-174           8,256
|    |    |    |    |    └─Linear: 6-175           8,256
|    |    |    |    └─UnetResBlock: 5-59           --
|    |    |    |    |    └─Convolution: 6-176      --
|    |    |    |    |    |    └─Conv3d: 7-27       442,368
|    |    |    |    |    └─Convolution: 6-177      --
|    |    |    |    |    |    └─Conv3d: 7-28       442,368
|    |    |    |    |    └─LeakyReLU: 6-178        --
|    |    |    |    |    └─BatchNorm3d: 6-179      256
|    |    |    |    |    └─BatchNorm3d: 6-180      256
|    |    |    |    └─Sequential: 5-60             --
|    |    |    |    |    └─Dropout3d: 6-181        --
|    |    |    |    |    └─Conv3d: 6-182           16,512
|    |    |    └─TransformerBlock: 4-23            --
|    |    |    |    └─LayerNorm: 5-61              256
|    |    |    |    └─EPA: 5-62                    --
|    |    |    |    |    └─Linear: 6-183           65,536
|    |    |    |    |    └─Linear: 6-184           32,832
|    |    |    |    |    └─Dropout: 6-185          --
|    |    |    |    |    └─Dropout: 6-186          --
|    |    |    |    |    └─Linear: 6-187           8,256
|    |    |    |    |    └─Linear: 6-188           8,256
|    |    |    |    └─UnetResBlock: 5-63           --
|    |    |    |    |    └─Convolution: 6-189      --
|    |    |    |    |    |    └─Conv3d: 7-29       442,368
|    |    |    |    |    └─Convolution: 6-190      --
|    |    |    |    |    |    └─Conv3d: 7-30       442,368
|    |    |    |    |    └─LeakyReLU: 6-191        --
|    |    |    |    |    └─BatchNorm3d: 6-192      256
|    |    |    |    |    └─BatchNorm3d: 6-193      256
|    |    |    |    └─Sequential: 5-64             --
|    |    |    |    |    └─Dropout3d: 6-194        --
|    |    |    |    |    └─Conv3d: 6-195           16,512
├─UnetrUpBlock: 1-4                                --
|    └─Convolution: 2-12                           --
|    |    └─ConvTranspose3d: 3-14                  65,536
|    └─ModuleList: 2-13                            --
|    |    └─Sequential: 3-15                       --
|    |    |    └─TransformerBlock: 4-24            --
|    |    |    |    └─LayerNorm: 5-65              128
|    |    |    |    └─EPA: 5-66                    --
|    |    |    |    |    └─Linear: 6-196           16,384
|    |    |    |    |    └─Linear: 6-197           262,208
|    |    |    |    |    └─Dropout: 6-198          --
|    |    |    |    |    └─Dropout: 6-199          --
|    |    |    |    |    └─Linear: 6-200           2,080
|    |    |    |    |    └─Linear: 6-201           2,080
|    |    |    |    └─UnetResBlock: 5-67           --
|    |    |    |    |    └─Convolution: 6-202      --
|    |    |    |    |    |    └─Conv3d: 7-31       110,592
|    |    |    |    |    └─Convolution: 6-203      --
|    |    |    |    |    |    └─Conv3d: 7-32       110,592
|    |    |    |    |    └─LeakyReLU: 6-204        --
|    |    |    |    |    └─BatchNorm3d: 6-205      128
|    |    |    |    |    └─BatchNorm3d: 6-206      128
|    |    |    |    └─Sequential: 5-68             --
|    |    |    |    |    └─Dropout3d: 6-207        --
|    |    |    |    |    └─Conv3d: 6-208           4,160
|    |    |    └─TransformerBlock: 4-25            --
|    |    |    |    └─LayerNorm: 5-69              128
|    |    |    |    └─EPA: 5-70                    --
|    |    |    |    |    └─Linear: 6-209           16,384
|    |    |    |    |    └─Linear: 6-210           262,208
|    |    |    |    |    └─Dropout: 6-211          --
|    |    |    |    |    └─Dropout: 6-212          --
|    |    |    |    |    └─Linear: 6-213           2,080
|    |    |    |    |    └─Linear: 6-214           2,080
|    |    |    |    └─UnetResBlock: 5-71           --
|    |    |    |    |    └─Convolution: 6-215      --
|    |    |    |    |    |    └─Conv3d: 7-33       110,592
|    |    |    |    |    └─Convolution: 6-216      --
|    |    |    |    |    |    └─Conv3d: 7-34       110,592
|    |    |    |    |    └─LeakyReLU: 6-217        --
|    |    |    |    |    └─BatchNorm3d: 6-218      128
|    |    |    |    |    └─BatchNorm3d: 6-219      128
|    |    |    |    └─Sequential: 5-72             --
|    |    |    |    |    └─Dropout3d: 6-220        --
|    |    |    |    |    └─Conv3d: 6-221           4,160
|    |    |    └─TransformerBlock: 4-26            --
|    |    |    |    └─LayerNorm: 5-73              128
|    |    |    |    └─EPA: 5-74                    --
|    |    |    |    |    └─Linear: 6-222           16,384
|    |    |    |    |    └─Linear: 6-223           262,208
|    |    |    |    |    └─Dropout: 6-224          --
|    |    |    |    |    └─Dropout: 6-225          --
|    |    |    |    |    └─Linear: 6-226           2,080
|    |    |    |    |    └─Linear: 6-227           2,080
|    |    |    |    └─UnetResBlock: 5-75           --
|    |    |    |    |    └─Convolution: 6-228      --
|    |    |    |    |    |    └─Conv3d: 7-35       110,592
|    |    |    |    |    └─Convolution: 6-229      --
|    |    |    |    |    |    └─Conv3d: 7-36       110,592
|    |    |    |    |    └─LeakyReLU: 6-230        --
|    |    |    |    |    └─BatchNorm3d: 6-231      128
|    |    |    |    |    └─BatchNorm3d: 6-232      128
|    |    |    |    └─Sequential: 5-76             --
|    |    |    |    |    └─Dropout3d: 6-233        --
|    |    |    |    |    └─Conv3d: 6-234           4,160
├─UnetrUpBlock: 1-5                                --
|    └─Convolution: 2-14                           --
|    |    └─ConvTranspose3d: 3-16                  16,384
|    └─ModuleList: 2-15                            --
|    |    └─Sequential: 3-17                       --
|    |    |    └─TransformerBlock: 4-27            --
|    |    |    |    └─LayerNorm: 5-77              64
|    |    |    |    └─EPA: 5-78                    --
|    |    |    |    |    └─Linear: 6-235           4,096
|    |    |    |    |    └─Linear: 6-236           2,097,216
|    |    |    |    |    └─Dropout: 6-237          --
|    |    |    |    |    └─Dropout: 6-238          --
|    |    |    |    |    └─Linear: 6-239           528
|    |    |    |    |    └─Linear: 6-240           528
|    |    |    |    └─UnetResBlock: 5-79           --
|    |    |    |    |    └─Convolution: 6-241      --
|    |    |    |    |    |    └─Conv3d: 7-37       27,648
|    |    |    |    |    └─Convolution: 6-242      --
|    |    |    |    |    |    └─Conv3d: 7-38       27,648
|    |    |    |    |    └─LeakyReLU: 6-243        --
|    |    |    |    |    └─BatchNorm3d: 6-244      64
|    |    |    |    |    └─BatchNorm3d: 6-245      64
|    |    |    |    └─Sequential: 5-80             --
|    |    |    |    |    └─Dropout3d: 6-246        --
|    |    |    |    |    └─Conv3d: 6-247           1,056
|    |    |    └─TransformerBlock: 4-28            --
|    |    |    |    └─LayerNorm: 5-81              64
|    |    |    |    └─EPA: 5-82                    --
|    |    |    |    |    └─Linear: 6-248           4,096
|    |    |    |    |    └─Linear: 6-249           2,097,216
|    |    |    |    |    └─Dropout: 6-250          --
|    |    |    |    |    └─Dropout: 6-251          --
|    |    |    |    |    └─Linear: 6-252           528
|    |    |    |    |    └─Linear: 6-253           528
|    |    |    |    └─UnetResBlock: 5-83           --
|    |    |    |    |    └─Convolution: 6-254      --
|    |    |    |    |    |    └─Conv3d: 7-39       27,648
|    |    |    |    |    └─Convolution: 6-255      --
|    |    |    |    |    |    └─Conv3d: 7-40       27,648
|    |    |    |    |    └─LeakyReLU: 6-256        --
|    |    |    |    |    └─BatchNorm3d: 6-257      64
|    |    |    |    |    └─BatchNorm3d: 6-258      64
|    |    |    |    └─Sequential: 5-84             --
|    |    |    |    |    └─Dropout3d: 6-259        --
|    |    |    |    |    └─Conv3d: 6-260           1,056
|    |    |    └─TransformerBlock: 4-29            --
|    |    |    |    └─LayerNorm: 5-85              64
|    |    |    |    └─EPA: 5-86                    --
|    |    |    |    |    └─Linear: 6-261           4,096
|    |    |    |    |    └─Linear: 6-262           2,097,216
|    |    |    |    |    └─Dropout: 6-263          --
|    |    |    |    |    └─Dropout: 6-264          --
|    |    |    |    |    └─Linear: 6-265           528
|    |    |    |    |    └─Linear: 6-266           528
|    |    |    |    └─UnetResBlock: 5-87           --
|    |    |    |    |    └─Convolution: 6-267      --
|    |    |    |    |    |    └─Conv3d: 7-41       27,648
|    |    |    |    |    └─Convolution: 6-268      --
|    |    |    |    |    |    └─Conv3d: 7-42       27,648
|    |    |    |    |    └─LeakyReLU: 6-269        --
|    |    |    |    |    └─BatchNorm3d: 6-270      64
|    |    |    |    |    └─BatchNorm3d: 6-271      64
|    |    |    |    └─Sequential: 5-88             --
|    |    |    |    |    └─Dropout3d: 6-272        --
|    |    |    |    |    └─Conv3d: 6-273           1,056
├─UnetrUpBlock: 1-6                                --
|    └─Convolution: 2-16                           --
|    |    └─ConvTranspose3d: 3-18                  16,384
|    └─ModuleList: 2-17                            --
|    |    └─UnetResBlock: 3-19                     --
|    |    |    └─Convolution: 4-30                 --
|    |    |    |    └─Conv3d: 5-89                 6,912
|    |    |    └─Convolution: 4-31                 --
|    |    |    |    └─Conv3d: 5-90                 6,912
|    |    |    └─LeakyReLU: 4-32                   --
|    |    |    └─InstanceNorm3d: 4-33              --
|    |    |    └─InstanceNorm3d: 4-34              --
├─UnetOutBlock: 1-7                                --
|    └─Convolution: 2-18                           --
|    |    └─Conv3d: 3-20                           238
├─UnetOutBlock: 1-8                                --
|    └─Convolution: 2-19                           --
|    |    └─Conv3d: 3-21                           462
├─UnetOutBlock: 1-9                                --
|    └─Convolution: 2-20                           --
|    |    └─Conv3d: 3-22                           910
===========================================================================
Total params: 34,643,882
Trainable params: 34,643,882
Non-trainable params: 0
===========================================================================
'''

模型代码部分搭建

模型类

class UNETR_PP(SegmentationNetwork):
    """
    UNETR++ based on: "Shaker et al.,
    UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation"

    The network is built from a hierarchical ``UnetrPPEncoder`` whose four hidden
    states serve as skip connections, a convolutional full-resolution branch
    (``encoder1``), four ``UnetrUpBlock`` decoders and one to three
    ``UnetOutBlock`` heads. With ``do_ds=True`` ``forward`` returns a list of three
    logit maps (full resolution first) for deep supervision; otherwise it returns
    a single logit map.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            img_size: Union[Tuple[int, int, int], list],
            feature_size: int = 16,
            hidden_size: int = 256,
            num_heads: int = 4,
            pos_embed: str = "perceptron",  # TODO: Remove the argument
            norm_name: Union[Tuple, str] = "instance",
            dropout_rate: float = 0.0,
            depths=None,
            dims=None,
            conv_op=nn.Conv3d,
            do_ds=True,

    ) -> None:
        """
        Args:
            in_channels: number of input channels.
            out_channels: number of output channels (segmentation classes).
            img_size: spatial size (D, H, W) of the input volume.
            feature_size: base channel width; the decoders use
                ``feature_size * {1, 2, 4, 8, 16}`` channels.
            hidden_size: channel width of the last encoder stage. It should equal
                ``feature_size * 16``, the input width of the deepest decoder.
            num_heads: number of attention heads.
            pos_embed: position embedding layer type ("conv" or "perceptron").
                It is only validated and otherwise unused.
            norm_name: feature normalization type and arguments.
            dropout_rate: fraction of the input units to drop. It is validated to
                lie in [0, 1] but is not forwarded to any sub-module in this class.
            depths: number of blocks for each encoder stage (default ``[3, 3, 3, 3]``).
            dims: number of channel maps for the four encoder stages (default
                ``[feature_size * 2, feature_size * 4, feature_size * 8, feature_size * 16]``).
            conv_op: type of convolution operation.
            do_ds: use deep supervision to compute the loss.

        Raises:
            AssertionError: if ``dropout_rate`` is not in [0, 1].
            KeyError: if ``pos_embed`` is not "conv" or "perceptron".

        Examples::

            # single-channel input, 14-class output, patch size (64, 128, 128), feature size 16,
            # batch norm, depths [3, 3, 3, 3], encoder channels [32, 64, 128, 256], 4 heads,
            # with deep supervision:
            # >>> net = UNETR_PP(in_channels=1, out_channels=14, img_size=(64, 128, 128), feature_size=16, num_heads=4,
            # >>>                 norm_name='batch', depths=[3, 3, 3, 3], dims=[32, 64, 128, 256], do_ds=True)
        """

        super().__init__()
        if depths is None:
            depths = [3, 3, 3, 3]
        if dims is None:
            # The decoders below take enc1..enc4 alongside feature maps of
            # feature_size * 2/4/8/16 channels, so derive the stage widths from
            # feature_size rather than handing None to the encoder.
            dims = [feature_size * 2, feature_size * 4, feature_size * 8, feature_size * 16]
        self.do_ds = do_ds
        self.conv_op = conv_op
        self.num_classes = out_channels
        if not (0 <= dropout_rate <= 1):
            raise AssertionError("dropout_rate should be between 0 and 1.")

        if pos_embed not in ["conv", "perceptron"]:
            raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.")

        self.patch_size = (2, 4, 4)
        self.feat_size = (
            img_size[0] // self.patch_size[0] // 8,  # 8 is the downsampling happened through the four encoders stages
            img_size[1] // self.patch_size[1] // 8,  # 8 is the downsampling happened through the four encoders stages
            img_size[2] // self.patch_size[2] // 8,  # 8 is the downsampling happened through the four encoders stages
        )
        self.hidden_size = hidden_size

        # Number of voxels (tokens) produced by each decoder. decoder5/4/3 each double
        # every spatial dim (x8 tokens per step); decoder2 undoes the patch embedding.
        # For img_size=(64, 128, 128) this gives 8**3, 16**3, 32**3 and 64*128*128.
        feat_d, feat_h, feat_w = self.feat_size
        bottleneck_tokens = feat_d * feat_h * feat_w
        full_res_tokens = (
            (feat_d * 8 * self.patch_size[0])
            * (feat_h * 8 * self.patch_size[1])
            * (feat_w * 8 * self.patch_size[2])
        )

        # NOTE(review): the encoder is given neither img_size nor in_channels, so its
        # internal sizes and input width come from UnetrPPEncoder's own defaults.
        # Confirm they match img_size / in_channels for non-default configurations.
        self.unetr_pp_encoder = UnetrPPEncoder(dims=dims, depths=depths, num_heads=num_heads)

        self.encoder1 = UnetResBlock(
            spatial_dims=3,
            in_channels=in_channels,
            out_channels=feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
        )
        self.decoder5 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 16,
            out_channels=feature_size * 8,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            out_size=bottleneck_tokens * 8,
        )
        self.decoder4 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 8,
            out_channels=feature_size * 4,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            out_size=bottleneck_tokens * 8 ** 2,
        )
        self.decoder3 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 4,
            out_channels=feature_size * 2,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            out_size=bottleneck_tokens * 8 ** 3,
        )
        self.decoder2 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 2,
            out_channels=feature_size,
            kernel_size=3,
            upsample_kernel_size=self.patch_size,
            norm_name=norm_name,
            out_size=full_res_tokens,
            conv_decoder=True,
        )
        self.out1 = UnetOutBlock(spatial_dims=3, in_channels=feature_size, out_channels=out_channels)
        if self.do_ds:
            # Auxiliary heads on the two coarser decoder outputs for deep supervision.
            self.out2 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 2, out_channels=out_channels)
            self.out3 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 4, out_channels=out_channels)

    def proj_feat(self, x, hidden_size, feat_size):
        """Reshape ``x`` to (B, feat_d, feat_h, feat_w, hidden_size), then move channels to dim 1.

        NOTE(review): ``view`` reinterprets the underlying memory as channel-last
        before permuting. If the encoder's last hidden state is already
        channel-first (B, C, D, H, W), this shuffles values instead of transposing
        them. Confirm the encoder's output layout. Behavior is kept as-is so
        existing checkpoints remain compatible.
        """
        x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size)
        x = x.permute(0, 4, 1, 2, 3).contiguous()
        return x

    def forward(self, x_in):
        """Run the encoder/decoder and return logits.

        Returns a list ``[full-res, 1/2-res-ish, 1/4-res-ish]`` of logits when
        ``do_ds`` is set (from ``decoder2``, ``decoder3`` and ``decoder4``
        outputs respectively), otherwise the full-resolution logits only.
        """
        _, hidden_states = self.unetr_pp_encoder(x_in)

        # Full-resolution convolutional skip branch.
        convBlock = self.encoder1(x_in)

        # Four encoders
        enc1 = hidden_states[0]
        enc2 = hidden_states[1]
        enc3 = hidden_states[2]
        enc4 = hidden_states[3]

        # Four decoders
        dec4 = self.proj_feat(enc4, self.hidden_size, self.feat_size)
        dec3 = self.decoder5(dec4, enc3)
        dec2 = self.decoder4(dec3, enc2)
        dec1 = self.decoder3(dec2, enc1)

        out = self.decoder2(dec1, convBlock)
        if self.do_ds:
            logits = [self.out1(out), self.out2(dec1), self.out3(dec2)]
        else:
            logits = self.out1(out)

        return logits

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值