UNet 复现
时间: 2025-04-27 14:28:37 浏览: 24
### UNet 模型的复现
#### PyTorch 实现
为了在 PyTorch 中构建和训练 UNet 模型,可以按照以下方式定义网络结构:
```python
import torch.nn as nn
import torch
class DoubleConv(nn.Module):
    """Two stacked (3x3 convolution => BatchNorm => ReLU) stages.

    Both convolutions use ``padding=1`` with a 3x3 kernel, so the spatial
    size of the input is preserved; only the channel count changes.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the block.
        mid_channels: Number of channels between the two convolutions.
            Defaults to ``out_channels``, which reproduces the original
            two-argument behaviour.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if mid_channels is None:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply both conv/BN/ReLU stages to ``x`` and return the result."""
        return self.double_conv(x)
class Down(nn.Module):
    """Encoder stage: 2x2 max pooling followed by a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels),
        )

    def forward(self, x):
        """Halve the spatial size of ``x`` and map it to ``out_channels``."""
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Decoder stage: upsample, concatenate the skip tensor, then DoubleConv.

    Args:
        in_channels: Channel count AFTER concatenation with the skip tensor.
        out_channels: Channel count produced by the stage.
        bilinear: If true, upsample with parameter-free bilinear
            interpolation (channel count unchanged); otherwise use a
            transposed convolution that also halves the channel count.
    """

    def __init__(self, in_channels, out_channels, bilinear=False):
        super().__init__()
        if bilinear:
            self.up = nn.Upsample(
                scale_factor=2, mode="bilinear", align_corners=True
            )
        else:
            self.up = nn.ConvTranspose2d(
                in_channels, in_channels // 2, kernel_size=2, stride=2
            )
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it with skip tensor ``x2`` and fuse them."""
        x1 = self.up(x1)
        # Pooling floors odd sizes, so the upsampled map can be smaller than
        # the skip tensor; zero-pad it so the two can be concatenated.
        diff_y = x2.size(2) - x1.size(2)
        diff_x = x2.size(3) - x1.size(3)
        x1 = nn.functional.pad(
            x1,
            [diff_x // 2, diff_x - diff_x // 2,
             diff_y // 2, diff_y - diff_y // 2],
        )
        return self.conv(torch.cat([x2, x1], dim=1))


class OutConv(nn.Module):
    """Final 1x1 convolution mapping features to per-class scores."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Return the 1x1-convolved tensor (no activation is applied)."""
        return self.conv(x)


class UNet(nn.Module):
    """U-shaped encoder/decoder segmentation network with skip connections.

    Args:
        n_channels: Number of channels of the input image.
        n_classes: Number of output channels (one score map per class).
        bilinear: Selects the upsampling operator of the decoder. With the
            default ``False`` a transposed convolution is used and
            ``factor`` is 1; with ``True`` bilinear upsampling is used and
            the affected channel widths are halved (``factor`` is 2) so the
            concatenated tensors keep the expected channel counts.
    """

    def __init__(self, n_channels, n_classes, bilinear=False):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        # The original snippet used ``factor`` without ever defining it.
        factor = 2 if bilinear else 1

        # Encoder: DoubleConv comes from the definition earlier in this file.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)

        # Decoder: each stage receives the concatenated channel count.
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Run the network on ``x`` and return unnormalised class scores."""
        # Contracting path; every intermediate map is kept as a skip tensor.
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        # Expanding path; each stage fuses the matching encoder output.
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits
```
此代码片段展示了如何创建一个基本版本的 UNet 架构;该架构在医学影像分析等领域应用广泛。
对于 TensorFlow 的实现,则可以采用另一种风格的写法。
#### TensorFlow 实现
下面是一个简单的基于 Keras API 的 UNet 定义:
```python
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Dropout, concatenate, UpSampling2D
def unet(input_size=(256, 256, 1)):
    """Build a compact U-Net as a Keras functional ``Model``.

    The network uses one 3x3 convolution per level, four 2x2 max-pooling
    steps on the way down and four 2x upsampling steps on the way up, with
    the encoder output of each level concatenated into the decoder.

    Args:
        input_size: Shape of one input sample, passed straight to ``Input``.
            The default is ``(256, 256, 1)``. Because of the four pooling /
            upsampling steps, height and width presumably need to be
            divisible by 16 for the concatenations to line up.

    Returns:
        An uncompiled ``Model`` whose output is a single-channel map
        produced by a 1x1 convolution with sigmoid activation.
    """
    inputs = Input(input_size)

    # Contracting path.
    conv1 = Conv2D(64, (3, 3), activation='relu', padding='same')(inputs)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = Conv2D(128, (3, 3), activation='relu', padding='same')(pool1)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = Conv2D(256, (3, 3), activation='relu', padding='same')(pool2)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    conv4 = Conv2D(512, (3, 3), activation='relu', padding='same')(pool3)
    drop4 = Dropout(0.5)(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)

    # Bottleneck.
    conv5 = Conv2D(1024, (3, 3), activation='relu', padding='same')(pool4)
    drop5 = Dropout(0.5)(conv5)

    # Expanding path: upsample, 2x2 conv, concatenate the skip, 3x3 conv.
    # axis=3 is the last axis of a (batch, H, W, C) tensor; this assumes the
    # channels-last layout implied by the default input_size.
    up6 = Conv2D(512, (2, 2), activation='relu', padding='same')(UpSampling2D(size=(2, 2))(drop5))
    merge6 = concatenate([drop4, up6], axis=3)
    conv6 = Conv2D(512, (3, 3), activation='relu', padding='same')(merge6)
    up7 = Conv2D(256, (2, 2), activation='relu', padding='same')(UpSampling2D(size=(2, 2))(conv6))
    merge7 = concatenate([conv3, up7], axis=3)
    conv7 = Conv2D(256, (3, 3), activation='relu', padding='same')(merge7)
    up8 = Conv2D(128, (2, 2), activation='relu', padding='same')(UpSampling2D(size=(2, 2))(conv7))
    merge8 = concatenate([conv2, up8], axis=3)
    conv8 = Conv2D(128, (3, 3), activation='relu', padding='same')(merge8)
    up9 = Conv2D(64, (2, 2), activation='relu', padding='same')(UpSampling2D(size=(2, 2))(conv8))
    merge9 = concatenate([conv1, up9], axis=3)
    conv9 = Conv2D(64, (3, 3), activation='relu', padding='same')(merge9)

    # Output head: reduce to 2 feature maps, then a 1x1 sigmoid convolution.
    conv9 = Conv2D(2, (3, 3), activation='relu', padding='same')(conv9)
    conv10 = Conv2D(1, (1, 1), activation='sigmoid')(conv9)

    model = Model(inputs=[inputs], outputs=[conv10])
    return model
```
这段代码同样实现了经典的 U 形编码解码器架构,并利用了跳跃连接来保留空间信息。
阅读全文
相关推荐


















