模型结构的打印
'''
===========================================================================
Layer (type:depth-idx) Param #
===========================================================================
├─UnetrPPEncoder: 1-1 --
| └─ModuleList: 2-1 --
| | └─Sequential: 3-1 --
| | | └─Convolution: 4-1 --
| | | | └─Conv3d: 5-1 1,024
| | | └─GroupNorm: 4-2 64
| | └─Sequential: 3-2 --
| | | └─Convolution: 4-3 --
| | | | └─Conv3d: 5-2 16,384
| | | └─GroupNorm: 4-4 128
| | └─Sequential: 3-3 --
| | | └─Convolution: 4-5 --
| | | | └─Conv3d: 5-3 65,536
| | | └─GroupNorm: 4-6 256
| | └─Sequential: 3-4 --
| | | └─Convolution: 4-7 --
| | | | └─Conv3d: 5-4 262,144
| | | └─GroupNorm: 4-8 512
| └─ModuleList: 2-2 --
| | └─Sequential: 3-5 --
| | | └─TransformerBlock: 4-9 --
| | | | └─LayerNorm: 5-5 64
| | | | └─EPA: 5-6 --
| | | | | └─Linear: 6-1 4,096
| | | | | └─Linear: 6-2 2,097,216
| | | | | └─Dropout: 6-3 --
| | | | | └─Dropout: 6-4 --
| | | | | └─Linear: 6-5 528
| | | | | └─Linear: 6-6 528
| | | | └─UnetResBlock: 5-7 --
| | | | | └─Convolution: 6-7 --
| | | | | | └─Conv3d: 7-1 27,648
| | | | | └─Convolution: 6-8 --
| | | | | | └─Conv3d: 7-2 27,648
| | | | | └─LeakyReLU: 6-9 --
| | | | | └─BatchNorm3d: 6-10 64
| | | | | └─BatchNorm3d: 6-11 64
| | | | └─Sequential: 5-8 --
| | | | | └─Dropout3d: 6-12 --
| | | | | └─Conv3d: 6-13 1,056
| | | └─TransformerBlock: 4-10 --
| | | | └─LayerNorm: 5-9 64
| | | | └─EPA: 5-10 --
| | | | | └─Linear: 6-14 4,096
| | | | | └─Linear: 6-15 2,097,216
| | | | | └─Dropout: 6-16 --
| | | | | └─Dropout: 6-17 --
| | | | | └─Linear: 6-18 528
| | | | | └─Linear: 6-19 528
| | | | └─UnetResBlock: 5-11 --
| | | | | └─Convolution: 6-20 --
| | | | | | └─Conv3d: 7-3 27,648
| | | | | └─Convolution: 6-21 --
| | | | | | └─Conv3d: 7-4 27,648
| | | | | └─LeakyReLU: 6-22 --
| | | | | └─BatchNorm3d: 6-23 64
| | | | | └─BatchNorm3d: 6-24 64
| | | | └─Sequential: 5-12 --
| | | | | └─Dropout3d: 6-25 --
| | | | | └─Conv3d: 6-26 1,056
| | | └─TransformerBlock: 4-11 --
| | | | └─LayerNorm: 5-13 64
| | | | └─EPA: 5-14 --
| | | | | └─Linear: 6-27 4,096
| | | | | └─Linear: 6-28 2,097,216
| | | | | └─Dropout: 6-29 --
| | | | | └─Dropout: 6-30 --
| | | | | └─Linear: 6-31 528
| | | | | └─Linear: 6-32 528
| | | | └─UnetResBlock: 5-15 --
| | | | | └─Convolution: 6-33 --
| | | | | | └─Conv3d: 7-5 27,648
| | | | | └─Convolution: 6-34 --
| | | | | | └─Conv3d: 7-6 27,648
| | | | | └─LeakyReLU: 6-35 --
| | | | | └─BatchNorm3d: 6-36 64
| | | | | └─BatchNorm3d: 6-37 64
| | | | └─Sequential: 5-16 --
| | | | | └─Dropout3d: 6-38 --
| | | | | └─Conv3d: 6-39 1,056
| | └─Sequential: 3-6 --
| | | └─TransformerBlock: 4-12 --
| | | | └─LayerNorm: 5-17 128
| | | | └─EPA: 5-18 --
| | | | | └─Linear: 6-40 16,384
| | | | | └─Linear: 6-41 262,208
| | | | | └─Dropout: 6-42 --
| | | | | └─Dropout: 6-43 --
| | | | | └─Linear: 6-44 2,080
| | | | | └─Linear: 6-45 2,080
| | | | └─UnetResBlock: 5-19 --
| | | | | └─Convolution: 6-46 --
| | | | | | └─Conv3d: 7-7 110,592
| | | | | └─Convolution: 6-47 --
| | | | | | └─Conv3d: 7-8 110,592
| | | | | └─LeakyReLU: 6-48 --
| | | | | └─BatchNorm3d: 6-49 128
| | | | | └─BatchNorm3d: 6-50 128
| | | | └─Sequential: 5-20 --
| | | | | └─Dropout3d: 6-51 --
| | | | | └─Conv3d: 6-52 4,160
| | | └─TransformerBlock: 4-13 --
| | | | └─LayerNorm: 5-21 128
| | | | └─EPA: 5-22 --
| | | | | └─Linear: 6-53 16,384
| | | | | └─Linear: 6-54 262,208
| | | | | └─Dropout: 6-55 --
| | | | | └─Dropout: 6-56 --
| | | | | └─Linear: 6-57 2,080
| | | | | └─Linear: 6-58 2,080
| | | | └─UnetResBlock: 5-23 --
| | | | | └─Convolution: 6-59 --
| | | | | | └─Conv3d: 7-9 110,592
| | | | | └─Convolution: 6-60 --
| | | | | | └─Conv3d: 7-10 110,592
| | | | | └─LeakyReLU: 6-61 --
| | | | | └─BatchNorm3d: 6-62 128
| | | | | └─BatchNorm3d: 6-63 128
| | | | └─Sequential: 5-24 --
| | | | | └─Dropout3d: 6-64 --
| | | | | └─Conv3d: 6-65 4,160
| | | └─TransformerBlock: 4-14 --
| | | | └─LayerNorm: 5-25 128
| | | | └─EPA: 5-26 --
| | | | | └─Linear: 6-66 16,384
| | | | | └─Linear: 6-67 262,208
| | | | | └─Dropout: 6-68 --
| | | | | └─Dropout: 6-69 --
| | | | | └─Linear: 6-70 2,080
| | | | | └─Linear: 6-71 2,080
| | | | └─UnetResBlock: 5-27 --
| | | | | └─Convolution: 6-72 --
| | | | | | └─Conv3d: 7-11 110,592
| | | | | └─Convolution: 6-73 --
| | | | | | └─Conv3d: 7-12 110,592
| | | | | └─LeakyReLU: 6-74 --
| | | | | └─BatchNorm3d: 6-75 128
| | | | | └─BatchNorm3d: 6-76 128
| | | | └─Sequential: 5-28 --
| | | | | └─Dropout3d: 6-77 --
| | | | | └─Conv3d: 6-78 4,160
| | └─Sequential: 3-7 --
| | | └─TransformerBlock: 4-15 --
| | | | └─LayerNorm: 5-29 256
| | | | └─EPA: 5-30 --
| | | | | └─Linear: 6-79 65,536
| | | | | └─Linear: 6-80 32,832
| | | | | └─Dropout: 6-81 --
| | | | | └─Dropout: 6-82 --
| | | | | └─Linear: 6-83 8,256
| | | | | └─Linear: 6-84 8,256
| | | | └─UnetResBlock: 5-31 --
| | | | | └─Convolution: 6-85 --
| | | | | | └─Conv3d: 7-13 442,368
| | | | | └─Convolution: 6-86 --
| | | | | | └─Conv3d: 7-14 442,368
| | | | | └─LeakyReLU: 6-87 --
| | | | | └─BatchNorm3d: 6-88 256
| | | | | └─BatchNorm3d: 6-89 256
| | | | └─Sequential: 5-32 --
| | | | | └─Dropout3d: 6-90 --
| | | | | └─Conv3d: 6-91 16,512
| | | └─TransformerBlock: 4-16 --
| | | | └─LayerNorm: 5-33 256
| | | | └─EPA: 5-34 --
| | | | | └─Linear: 6-92 65,536
| | | | | └─Linear: 6-93 32,832
| | | | | └─Dropout: 6-94 --
| | | | | └─Dropout: 6-95 --
| | | | | └─Linear: 6-96 8,256
| | | | | └─Linear: 6-97 8,256
| | | | └─UnetResBlock: 5-35 --
| | | | | └─Convolution: 6-98 --
| | | | | | └─Conv3d: 7-15 442,368
| | | | | └─Convolution: 6-99 --
| | | | | | └─Conv3d: 7-16 442,368
| | | | | └─LeakyReLU: 6-100 --
| | | | | └─BatchNorm3d: 6-101 256
| | | | | └─BatchNorm3d: 6-102 256
| | | | └─Sequential: 5-36 --
| | | | | └─Dropout3d: 6-103 --
| | | | | └─Conv3d: 6-104 16,512
| | | └─TransformerBlock: 4-17 --
| | | | └─LayerNorm: 5-37 256
| | | | └─EPA: 5-38 --
| | | | | └─Linear: 6-105 65,536
| | | | | └─Linear: 6-106 32,832
| | | | | └─Dropout: 6-107 --
| | | | | └─Dropout: 6-108 --
| | | | | └─Linear: 6-109 8,256
| | | | | └─Linear: 6-110 8,256
| | | | └─UnetResBlock: 5-39 --
| | | | | └─Convolution: 6-111 --
| | | | | | └─Conv3d: 7-17 442,368
| | | | | └─Convolution: 6-112 --
| | | | | | └─Conv3d: 7-18 442,368
| | | | | └─LeakyReLU: 6-113 --
| | | | | └─BatchNorm3d: 6-114 256
| | | | | └─BatchNorm3d: 6-115 256
| | | | └─Sequential: 5-40 --
| | | | | └─Dropout3d: 6-116 --
| | | | | └─Conv3d: 6-117 16,512
| | └─Sequential: 3-8 --
| | | └─TransformerBlock: 4-18 --
| | | | └─LayerNorm: 5-41 512
| | | | └─EPA: 5-42 --
| | | | | └─Linear: 6-118 262,144
| | | | | └─Linear: 6-119 2,080
| | | | | └─Dropout: 6-120 --
| | | | | └─Dropout: 6-121 --
| | | | | └─Linear: 6-122 32,896
| | | | | └─Linear: 6-123 32,896
| | | | └─UnetResBlock: 5-43 --
| | | | | └─Convolution: 6-124 --
| | | | | | └─Conv3d: 7-19 1,769,472
| | | | | └─Convolution: 6-125 --
| | | | | | └─Conv3d: 7-20 1,769,472
| | | | | └─LeakyReLU: 6-126 --
| | | | | └─BatchNorm3d: 6-127 512
| | | | | └─BatchNorm3d: 6-128 512
| | | | └─Sequential: 5-44 --
| | | | | └─Dropout3d: 6-129 --
| | | | | └─Conv3d: 6-130 65,792
| | | └─TransformerBlock: 4-19 --
| | | | └─LayerNorm: 5-45 512
| | | | └─EPA: 5-46 --
| | | | | └─Linear: 6-131 262,144
| | | | | └─Linear: 6-132 2,080
| | | | | └─Dropout: 6-133 --
| | | | | └─Dropout: 6-134 --
| | | | | └─Linear: 6-135 32,896
| | | | | └─Linear: 6-136 32,896
| | | | └─UnetResBlock: 5-47 --
| | | | | └─Convolution: 6-137 --
| | | | | | └─Conv3d: 7-21 1,769,472
| | | | | └─Convolution: 6-138 --
| | | | | | └─Conv3d: 7-22 1,769,472
| | | | | └─LeakyReLU: 6-139 --
| | | | | └─BatchNorm3d: 6-140 512
| | | | | └─BatchNorm3d: 6-141 512
| | | | └─Sequential: 5-48 --
| | | | | └─Dropout3d: 6-142 --
| | | | | └─Conv3d: 6-143 65,792
| | | └─TransformerBlock: 4-20 --
| | | | └─LayerNorm: 5-49 512
| | | | └─EPA: 5-50 --
| | | | | └─Linear: 6-144 262,144
| | | | | └─Linear: 6-145 2,080
| | | | | └─Dropout: 6-146 --
| | | | | └─Dropout: 6-147 --
| | | | | └─Linear: 6-148 32,896
| | | | | └─Linear: 6-149 32,896
| | | | └─UnetResBlock: 5-51 --
| | | | | └─Convolution: 6-150 --
| | | | | | └─Conv3d: 7-23 1,769,472
| | | | | └─Convolution: 6-151 --
| | | | | | └─Conv3d: 7-24 1,769,472
| | | | | └─LeakyReLU: 6-152 --
| | | | | └─BatchNorm3d: 6-153 512
| | | | | └─BatchNorm3d: 6-154 512
| | | | └─Sequential: 5-52 --
| | | | | └─Dropout3d: 6-155 --
| | | | | └─Conv3d: 6-156 65,792
├─UnetResBlock: 1-2 --
| └─Convolution: 2-3 --
| | └─Conv3d: 3-9 432
| └─Convolution: 2-4 --
| | └─Conv3d: 3-10 6,912
| └─LeakyReLU: 2-5 --
| └─InstanceNorm3d: 2-6 --
| └─InstanceNorm3d: 2-7 --
| └─Convolution: 2-8 --
| | └─Conv3d: 3-11 16
| └─InstanceNorm3d: 2-9 --
├─UnetrUpBlock: 1-3 --
| └─Convolution: 2-10 --
| | └─ConvTranspose3d: 3-12 262,144
| └─ModuleList: 2-11 --
| | └─Sequential: 3-13 --
| | | └─TransformerBlock: 4-21 --
| | | | └─LayerNorm: 5-53 256
| | | | └─EPA: 5-54 --
| | | | | └─Linear: 6-157 65,536
| | | | | └─Linear: 6-158 32,832
| | | | | └─Dropout: 6-159 --
| | | | | └─Dropout: 6-160 --
| | | | | └─Linear: 6-161 8,256
| | | | | └─Linear: 6-162 8,256
| | | | └─UnetResBlock: 5-55 --
| | | | | └─Convolution: 6-163 --
| | | | | | └─Conv3d: 7-25 442,368
| | | | | └─Convolution: 6-164 --
| | | | | | └─Conv3d: 7-26 442,368
| | | | | └─LeakyReLU: 6-165 --
| | | | | └─BatchNorm3d: 6-166 256
| | | | | └─BatchNorm3d: 6-167 256
| | | | └─Sequential: 5-56 --
| | | | | └─Dropout3d: 6-168 --
| | | | | └─Conv3d: 6-169 16,512
| | | └─TransformerBlock: 4-22 --
| | | | └─LayerNorm: 5-57 256
| | | | └─EPA: 5-58 --
| | | | | └─Linear: 6-170 65,536
| | | | | └─Linear: 6-171 32,832
| | | | | └─Dropout: 6-172 --
| | | | | └─Dropout: 6-173 --
| | | | | └─Linear: 6-174 8,256
| | | | | └─Linear: 6-175 8,256
| | | | └─UnetResBlock: 5-59 --
| | | | | └─Convolution: 6-176 --
| | | | | | └─Conv3d: 7-27 442,368
| | | | | └─Convolution: 6-177 --
| | | | | | └─Conv3d: 7-28 442,368
| | | | | └─LeakyReLU: 6-178 --
| | | | | └─BatchNorm3d: 6-179 256
| | | | | └─BatchNorm3d: 6-180 256
| | | | └─Sequential: 5-60 --
| | | | | └─Dropout3d: 6-181 --
| | | | | └─Conv3d: 6-182 16,512
| | | └─TransformerBlock: 4-23 --
| | | | └─LayerNorm: 5-61 256
| | | | └─EPA: 5-62 --
| | | | | └─Linear: 6-183 65,536
| | | | | └─Linear: 6-184 32,832
| | | | | └─Dropout: 6-185 --
| | | | | └─Dropout: 6-186 --
| | | | | └─Linear: 6-187 8,256
| | | | | └─Linear: 6-188 8,256
| | | | └─UnetResBlock: 5-63 --
| | | | | └─Convolution: 6-189 --
| | | | | | └─Conv3d: 7-29 442,368
| | | | | └─Convolution: 6-190 --
| | | | | | └─Conv3d: 7-30 442,368
| | | | | └─LeakyReLU: 6-191 --
| | | | | └─BatchNorm3d: 6-192 256
| | | | | └─BatchNorm3d: 6-193 256
| | | | └─Sequential: 5-64 --
| | | | | └─Dropout3d: 6-194 --
| | | | | └─Conv3d: 6-195 16,512
├─UnetrUpBlock: 1-4 --
| └─Convolution: 2-12 --
| | └─ConvTranspose3d: 3-14 65,536
| └─ModuleList: 2-13 --
| | └─Sequential: 3-15 --
| | | └─TransformerBlock: 4-24 --
| | | | └─LayerNorm: 5-65 128
| | | | └─EPA: 5-66 --
| | | | | └─Linear: 6-196 16,384
| | | | | └─Linear: 6-197 262,208
| | | | | └─Dropout: 6-198 --
| | | | | └─Dropout: 6-199 --
| | | | | └─Linear: 6-200 2,080
| | | | | └─Linear: 6-201 2,080
| | | | └─UnetResBlock: 5-67 --
| | | | | └─Convolution: 6-202 --
| | | | | | └─Conv3d: 7-31 110,592
| | | | | └─Convolution: 6-203 --
| | | | | | └─Conv3d: 7-32 110,592
| | | | | └─LeakyReLU: 6-204 --
| | | | | └─BatchNorm3d: 6-205 128
| | | | | └─BatchNorm3d: 6-206 128
| | | | └─Sequential: 5-68 --
| | | | | └─Dropout3d: 6-207 --
| | | | | └─Conv3d: 6-208 4,160
| | | └─TransformerBlock: 4-25 --
| | | | └─LayerNorm: 5-69 128
| | | | └─EPA: 5-70 --
| | | | | └─Linear: 6-209 16,384
| | | | | └─Linear: 6-210 262,208
| | | | | └─Dropout: 6-211 --
| | | | | └─Dropout: 6-212 --
| | | | | └─Linear: 6-213 2,080
| | | | | └─Linear: 6-214 2,080
| | | | └─UnetResBlock: 5-71 --
| | | | | └─Convolution: 6-215 --
| | | | | | └─Conv3d: 7-33 110,592
| | | | | └─Convolution: 6-216 --
| | | | | | └─Conv3d: 7-34 110,592
| | | | | └─LeakyReLU: 6-217 --
| | | | | └─BatchNorm3d: 6-218 128
| | | | | └─BatchNorm3d: 6-219 128
| | | | └─Sequential: 5-72 --
| | | | | └─Dropout3d: 6-220 --
| | | | | └─Conv3d: 6-221 4,160
| | | └─TransformerBlock: 4-26 --
| | | | └─LayerNorm: 5-73 128
| | | | └─EPA: 5-74 --
| | | | | └─Linear: 6-222 16,384
| | | | | └─Linear: 6-223 262,208
| | | | | └─Dropout: 6-224 --
| | | | | └─Dropout: 6-225 --
| | | | | └─Linear: 6-226 2,080
| | | | | └─Linear: 6-227 2,080
| | | | └─UnetResBlock: 5-75 --
| | | | | └─Convolution: 6-228 --
| | | | | | └─Conv3d: 7-35 110,592
| | | | | └─Convolution: 6-229 --
| | | | | | └─Conv3d: 7-36 110,592
| | | | | └─LeakyReLU: 6-230 --
| | | | | └─BatchNorm3d: 6-231 128
| | | | | └─BatchNorm3d: 6-232 128
| | | | └─Sequential: 5-76 --
| | | | | └─Dropout3d: 6-233 --
| | | | | └─Conv3d: 6-234 4,160
├─UnetrUpBlock: 1-5 --
| └─Convolution: 2-14 --
| | └─ConvTranspose3d: 3-16 16,384
| └─ModuleList: 2-15 --
| | └─Sequential: 3-17 --
| | | └─TransformerBlock: 4-27 --
| | | | └─LayerNorm: 5-77 64
| | | | └─EPA: 5-78 --
| | | | | └─Linear: 6-235 4,096
| | | | | └─Linear: 6-236 2,097,216
| | | | | └─Dropout: 6-237 --
| | | | | └─Dropout: 6-238 --
| | | | | └─Linear: 6-239 528
| | | | | └─Linear: 6-240 528
| | | | └─UnetResBlock: 5-79 --
| | | | | └─Convolution: 6-241 --
| | | | | | └─Conv3d: 7-37 27,648
| | | | | └─Convolution: 6-242 --
| | | | | | └─Conv3d: 7-38 27,648
| | | | | └─LeakyReLU: 6-243 --
| | | | | └─BatchNorm3d: 6-244 64
| | | | | └─BatchNorm3d: 6-245 64
| | | | └─Sequential: 5-80 --
| | | | | └─Dropout3d: 6-246 --
| | | | | └─Conv3d: 6-247 1,056
| | | └─TransformerBlock: 4-28 --
| | | | └─LayerNorm: 5-81 64
| | | | └─EPA: 5-82 --
| | | | | └─Linear: 6-248 4,096
| | | | | └─Linear: 6-249 2,097,216
| | | | | └─Dropout: 6-250 --
| | | | | └─Dropout: 6-251 --
| | | | | └─Linear: 6-252 528
| | | | | └─Linear: 6-253 528
| | | | └─UnetResBlock: 5-83 --
| | | | | └─Convolution: 6-254 --
| | | | | | └─Conv3d: 7-39 27,648
| | | | | └─Convolution: 6-255 --
| | | | | | └─Conv3d: 7-40 27,648
| | | | | └─LeakyReLU: 6-256 --
| | | | | └─BatchNorm3d: 6-257 64
| | | | | └─BatchNorm3d: 6-258 64
| | | | └─Sequential: 5-84 --
| | | | | └─Dropout3d: 6-259 --
| | | | | └─Conv3d: 6-260 1,056
| | | └─TransformerBlock: 4-29 --
| | | | └─LayerNorm: 5-85 64
| | | | └─EPA: 5-86 --
| | | | | └─Linear: 6-261 4,096
| | | | | └─Linear: 6-262 2,097,216
| | | | | └─Dropout: 6-263 --
| | | | | └─Dropout: 6-264 --
| | | | | └─Linear: 6-265 528
| | | | | └─Linear: 6-266 528
| | | | └─UnetResBlock: 5-87 --
| | | | | └─Convolution: 6-267 --
| | | | | | └─Conv3d: 7-41 27,648
| | | | | └─Convolution: 6-268 --
| | | | | | └─Conv3d: 7-42 27,648
| | | | | └─LeakyReLU: 6-269 --
| | | | | └─BatchNorm3d: 6-270 64
| | | | | └─BatchNorm3d: 6-271 64
| | | | └─Sequential: 5-88 --
| | | | | └─Dropout3d: 6-272 --
| | | | | └─Conv3d: 6-273 1,056
├─UnetrUpBlock: 1-6 --
| └─Convolution: 2-16 --
| | └─ConvTranspose3d: 3-18 16,384
| └─ModuleList: 2-17 --
| | └─UnetResBlock: 3-19 --
| | | └─Convolution: 4-30 --
| | | | └─Conv3d: 5-89 6,912
| | | └─Convolution: 4-31 --
| | | | └─Conv3d: 5-90 6,912
| | | └─LeakyReLU: 4-32 --
| | | └─InstanceNorm3d: 4-33 --
| | | └─InstanceNorm3d: 4-34 --
├─UnetOutBlock: 1-7 --
| └─Convolution: 2-18 --
| | └─Conv3d: 3-20 238
├─UnetOutBlock: 1-8 --
| └─Convolution: 2-19 --
| | └─Conv3d: 3-21 462
├─UnetOutBlock: 1-9 --
| └─Convolution: 2-20 --
| | └─Conv3d: 3-22 910
===========================================================================
Total params: 34,643,882
Trainable params: 34,643,882
Non-trainable params: 0
===========================================================================
'''
模型代码部分搭建
模型类
class UNETR_PP(SegmentationNetwork):
    """
    UNETR++ based on: "Shaker et al.,
    UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation"

    As built in ``__init__``, the network consists of a ``UnetrPPEncoder``, a
    full-resolution ``UnetResBlock`` skip branch (``encoder1``), four
    ``UnetrUpBlock`` decoders and one ``UnetOutBlock`` head (three heads when
    ``do_ds`` is enabled, for deep supervision).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        img_size: Tuple[int, int, int],
        feature_size: int = 16,
        hidden_size: int = 256,
        num_heads: int = 4,
        pos_embed: str = "perceptron",  # TODO: Remove the argument
        norm_name: Union[Tuple, str] = "instance",
        dropout_rate: float = 0.0,
        depths=None,
        dims=None,
        conv_op=nn.Conv3d,
        do_ds=True,
    ) -> None:
        """
        Args:
            in_channels: number of input channels.
            out_channels: number of output channels (segmentation classes);
                also stored as ``self.num_classes``.
            img_size: spatial size (D, H, W) of the input volume. Each entry
                must be divisible by the matching ``patch_size`` entry times 8.
                Any indexable sequence of three ints is accepted.
            feature_size: base channel width. ``encoder1`` outputs
                ``feature_size`` channels and the decoders work at multiples of
                it (x16 -> x8 -> x4 -> x2 -> x1).
            hidden_size: channel width of the deepest encoder output, used by
                ``proj_feat``. It must equal ``feature_size * 16`` (the input
                width of ``decoder5``).
            num_heads: number of attention heads, passed to ``UnetrPPEncoder``.
            pos_embed: position embedding layer type. It is only validated
                here (must be "conv" or "perceptron") and is otherwise unused.
            norm_name: feature normalization type and arguments, passed to
                ``encoder1`` and to every decoder.
            dropout_rate: fraction of the input units to drop. It is only
                validated here (must lie in [0, 1]) and is not passed to any
                submodule built in this class.
            depths: number of blocks for each encoder stage. Defaults to
                ``[3, 3, 3, 3]``.
            dims: number of channel maps for each encoder stage. It is
                forwarded unchanged to ``UnetrPPEncoder``, including when None.
            conv_op: type of convolution operation. It is stored as
                ``self.conv_op`` and not used to build layers in this class.
                Presumably it is consumed by ``SegmentationNetwork`` helpers
                (TODO confirm).
            do_ds: if True, build the extra output heads ``out2``/``out3`` and
                return deep-supervision logits from ``forward``.

        Raises:
            AssertionError: if ``dropout_rate`` is outside [0, 1].
            KeyError: if ``pos_embed`` is not "conv" or "perceptron".

        Examples::
            # Single-channel input, 14 output classes, patch size (64, 128, 128),
            # feature size 16, batch norm, depths [3, 3, 3, 3] with stage
            # channels [32, 64, 128, 256], 4 heads, and deep supervision:
            # >>> net = UNETR_PP(in_channels=1, out_channels=14, img_size=(64, 128, 128), feature_size=16, num_heads=4,
            # >>>                norm_name='batch', depths=[3, 3, 3, 3], dims=[32, 64, 128, 256], do_ds=True)
        """
        super().__init__()
        if depths is None:
            depths = [3, 3, 3, 3]
        self.do_ds = do_ds
        self.conv_op = conv_op
        self.num_classes = out_channels

        if not (0 <= dropout_rate <= 1):
            raise AssertionError("dropout_rate should be between 0 and 1.")
        if pos_embed not in ["conv", "perceptron"]:
            raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.")

        self.patch_size = (2, 4, 4)
        # Bottleneck grid: img_size is divided by the patch-embedding size and
        # then by 8. The decoders below undo this with three 2x upsamplings
        # followed by one ``patch_size`` upsampling.
        self.feat_size = tuple(img_size[i] // self.patch_size[i] // 8 for i in range(3))
        self.hidden_size = hidden_size

        # Number of voxels produced by each decoder. These are derived from the
        # bottleneck grid instead of being hard-coded for (64, 128, 128); for
        # that default they are 8**3, 16**3, 32**3 and 64 * 128 * 128.
        # NOTE(review): UnetrPPEncoder is not given img_size here. If it
        # assumes the default input size internally, other sizes may still
        # fail there. Confirm before relying on non-default img_size.
        d, h, w = self.feat_size
        pd, ph, pw = self.patch_size
        out_size5 = (2 * d) * (2 * h) * (2 * w)
        out_size4 = (4 * d) * (4 * h) * (4 * w)
        out_size3 = (8 * d) * (8 * h) * (8 * w)
        out_size2 = (8 * d * pd) * (8 * h * ph) * (8 * w * pw)

        self.unetr_pp_encoder = UnetrPPEncoder(dims=dims, depths=depths, num_heads=num_heads)
        # Full-resolution convolutional branch, used as the skip input of decoder2.
        self.encoder1 = UnetResBlock(
            spatial_dims=3,
            in_channels=in_channels,
            out_channels=feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
        )
        self.decoder5 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 16,
            out_channels=feature_size * 8,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            out_size=out_size5,
        )
        self.decoder4 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 8,
            out_channels=feature_size * 4,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            out_size=out_size4,
        )
        self.decoder3 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 4,
            out_channels=feature_size * 2,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            out_size=out_size3,
        )
        # The last decoder upsamples by the patch-embedding size, so it reuses
        # self.patch_size rather than repeating (2, 4, 4).
        self.decoder2 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 2,
            out_channels=feature_size,
            kernel_size=3,
            upsample_kernel_size=self.patch_size,
            norm_name=norm_name,
            out_size=out_size2,
            conv_decoder=True,
        )
        self.out1 = UnetOutBlock(spatial_dims=3, in_channels=feature_size, out_channels=out_channels)
        if self.do_ds:
            # Deep-supervision heads on the outputs of decoder3 and decoder4.
            self.out2 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 2, out_channels=out_channels)
            self.out3 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 4, out_channels=out_channels)

    def proj_feat(self, x, hidden_size, feat_size):
        """Reshape a token tensor into a channel-first 3D feature map.

        ``x`` is viewed as ``(batch, feat_size[0], feat_size[1], feat_size[2],
        hidden_size)``. It is then permuted to ``(batch, hidden_size,
        feat_size[0], feat_size[1], feat_size[2])`` and made contiguous.
        Presumably ``x`` arrives as (batch, num_tokens, hidden_size) from the
        encoder (TODO confirm).
        """
        x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size)
        x = x.permute(0, 4, 1, 2, 3).contiguous()
        return x

    def forward(self, x_in):
        """Encode ``x_in``, decode with skip connections and apply the output head(s).

        Args:
            x_in: input volume. It is fed both to ``unetr_pp_encoder`` and to
                the full-resolution ``encoder1`` branch. Presumably shaped
                (B, in_channels, *img_size) (TODO confirm).

        Returns:
            If ``do_ds`` is True, the list ``[out1(out), out2(dec1), out3(dec2)]``:
            the full-resolution logits followed by logits from two coarser
            decoder stages. Otherwise, only the full-resolution logits.
        """
        # The encoder's first return value is not needed; only its per-stage
        # hidden states are used as skip connections.
        _, hidden_states = self.unetr_pp_encoder(x_in)
        conv_block = self.encoder1(x_in)

        # Four encoder stages, shallowest first.
        enc1 = hidden_states[0]
        enc2 = hidden_states[1]
        enc3 = hidden_states[2]
        enc4 = hidden_states[3]

        # Four decoders: bring the deepest stage back to a volume, then upsample
        # while fusing the matching skip connection at each step.
        dec4 = self.proj_feat(enc4, self.hidden_size, self.feat_size)
        dec3 = self.decoder5(dec4, enc3)
        dec2 = self.decoder4(dec3, enc2)
        dec1 = self.decoder3(dec2, enc1)
        out = self.decoder2(dec1, conv_block)

        if self.do_ds:
            logits = [self.out1(out), self.out2(dec1), self.out3(dec2)]
        else:
            logits = self.out1(out)
        return logits