Image Segmentation Model Code
### Image Segmentation Model Code Examples
#### U-Net in PyTorch
U-Net is a classic image segmentation architecture that is widely used in medical imaging. Below is a PyTorch implementation of a simplified (two-level) U-Net:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels),
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # After concatenation the feature map has in_channels + out_channels
        # channels (upsampled path plus skip connection), so DoubleConv must
        # accept that many input channels.
        self.conv = DoubleConv(in_channels + out_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Pad x1 so its spatial size matches the skip connection x2
        # (needed when the input size is not divisible by 2 at every level).
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    """1x1 convolution mapping features to per-class logits"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):
    def __init__(self, n_channels, n_classes):
        super().__init__()
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.up1 = Up(256, 128)
        self.up2 = Up(128, 64)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x = self.up1(x3, x2)
        x = self.up2(x, x1)
        logits = self.outc(x)
        return logits
```
The code above implements the U-Net structure and is suitable for binary or multi-class semantic segmentation tasks[^3].
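As a quick sanity check, the model can be instantiated and run on a dummy batch (a minimal sketch; the input size and class count below are arbitrary):
```python
# Arbitrary example: 3-channel input, 2 output classes.
model = UNet(n_channels=3, n_classes=2)
x = torch.randn(1, 3, 128, 128)   # (batch, channels, height, width)
logits = model(x)
print(logits.shape)               # torch.Size([1, 2, 128, 128])
```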
---
#### U-Net in TensorFlow/Keras
TensorFlow users can build a U-Net quickly with the Keras functional API:
```python
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, concatenate, UpSampling2D


def unet(input_shape=(None, None, 1)):
    inputs = Input(input_shape)

    # Encoder (contracting path)
    c1 = Conv2D(64, (3, 3), activation='relu', padding='same')(inputs)
    c1 = Conv2D(64, (3, 3), activation='relu', padding='same')(c1)
    p1 = MaxPooling2D((2, 2))(c1)

    c2 = Conv2D(128, (3, 3), activation='relu', padding='same')(p1)
    c2 = Conv2D(128, (3, 3), activation='relu', padding='same')(c2)
    p2 = MaxPooling2D((2, 2))(c2)

    # Decoder (expansive path)
    u1 = UpSampling2D((2, 2))(p2)
    u1 = concatenate([u1, c2])
    c3 = Conv2D(128, (3, 3), activation='relu', padding='same')(u1)
    c3 = Conv2D(128, (3, 3), activation='relu', padding='same')(c3)

    u2 = UpSampling2D((2, 2))(c3)
    u2 = concatenate([u2, c1])
    c4 = Conv2D(64, (3, 3), activation='relu', padding='same')(u2)
    c4 = Conv2D(64, (3, 3), activation='relu', padding='same')(c4)

    # Sigmoid output for binary segmentation
    outputs = Conv2D(1, (1, 1), activation='sigmoid')(c4)

    model = Model(inputs=[inputs], outputs=[outputs])
    return model
```
This snippet shows how to assemble a U-Net from convolution, pooling, and upsampling layers[^1].
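A typical way to compile and inspect the model is shown below (a sketch; the input size and loss are placeholders for your task). Since each pooling layer halves the spatial size, the input height and width should be divisible by 4 so the skip-connection concatenations align:
```python
# Fixed input size keeps the skip connections aligned; binary segmentation setup.
model = unet(input_shape=(256, 256, 1))
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
model.summary()
```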
---
#### UNETR: Fusing Vision Transformers with CNNs
For a more advanced architecture such as UNETR, which fuses a Vision Transformer (ViT) encoder with a CNN-style decoder, the encoder must be designed by hand. Since no public TensorFlow implementation is currently available, one can read the paper and consult existing PyTorch implementations as a reference[^2].
Below is a simple ViT encoder skeleton (for illustration only):
```python
import tensorflow as tf
from tensorflow.keras.layers import LayerNormalization, Dense, MultiHeadAttention


class PatchEmbedding(tf.keras.layers.Layer):
    """Split an image into non-overlapping patches and project them to embed_dim."""

    def __init__(self, patch_size, embed_dim):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.projection = Dense(embed_dim)

    def call(self, images):
        patches = tf.image.extract_patches(
            images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # Flatten the patch grid into a sequence: (batch, num_patches, patch_dim).
        # Requires statically known image dimensions.
        _, num_patches_x, num_patches_y, channels = patches.shape
        patches = tf.reshape(patches, [-1, num_patches_x * num_patches_y, channels])
        return self.projection(patches)


class TransformerBlock(tf.keras.layers.Layer):
    """Pre-norm Transformer encoder block: self-attention followed by an MLP."""

    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout_rate=0.1):
        super().__init__()
        self.attention_norm = LayerNormalization()
        self.mlp_norm = LayerNormalization()
        self.multihead_attention = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.dense1 = Dense(int(mlp_ratio * embed_dim))
        self.dropout = tf.keras.layers.Dropout(dropout_rate)
        self.dense2 = Dense(embed_dim)

    def call(self, x):
        # Self-attention sub-layer with residual connection
        h = self.attention_norm(x)
        x = x + self.multihead_attention(h, h)
        # MLP sub-layer with residual connection
        h = self.mlp_norm(x)
        h = self.dense1(h)
        h = self.dropout(h)
        h = self.dense2(h)
        return x + h
```
These components cover the core of the ViT encoder and can serve as the foundation of the encoder in UNETR.
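To show how the pieces fit together, here is a minimal forward pass through a hypothetical encoder; the patch size, depth, and dimensions are illustrative values rather than the UNETR paper's settings, and the positional embeddings a real ViT requires are omitted:
```python
# Hypothetical configuration, for illustration only (no positional embeddings).
patch_embed = PatchEmbedding(patch_size=16, embed_dim=256)
blocks = [TransformerBlock(embed_dim=256, num_heads=8) for _ in range(4)]

images = tf.random.normal([2, 224, 224, 3])  # (batch, height, width, channels)
tokens = patch_embed(images)                 # (2, 196, 256): a 14x14 patch grid
for block in blocks:
    tokens = block(tokens)
print(tokens.shape)                          # (2, 196, 256)
```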
---
### Data Storage and Management
For efficient training on large image datasets, it is advisable to store and read minibatches in the HDF5 file format, which can significantly reduce memory usage and improve I/O efficiency.
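A minimal sketch of this pattern with the `h5py` library (the file name, dataset names, and shapes are placeholders):
```python
import h5py

# Write: store images and masks in one chunked HDF5 file (placeholder shapes).
with h5py.File('dataset.h5', 'w') as f:
    f.create_dataset('images', shape=(1000, 256, 256, 1), dtype='float32',
                     chunks=(32, 256, 256, 1))  # chunk size matches the minibatch
    f.create_dataset('masks', shape=(1000, 256, 256, 1), dtype='uint8',
                     chunks=(32, 256, 256, 1))

# Read: slice one minibatch from disk without loading the whole dataset.
with h5py.File('dataset.h5', 'r') as f:
    batch_images = f['images'][0:32]  # only this slice is read into memory
    batch_masks = f['masks'][0:32]
```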