UNet for Retinal Fundus Vessel Segmentation
Date: 2025-06-07 17:39:30
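### Retinal Fundus Vessel Segmentation with UNet
#### Preparation
Before starting, install the required libraries and obtain a dataset. Python and its scientific computing ecosystem are the most common choice for this kind of task.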
```bash
pip install torch torchvision matplotlib numpy scikit-image
```
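#### Loading and Preprocessing the Data
High-quality training samples matter for fundus vessel segmentation. Commonly used public datasets include DRIVE, STARE, and CHASE_DB1[^3]. Each provides fundus images together with annotated vessel label maps.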
```python
import os

import matplotlib.pyplot as plt
import numpy as np
from skimage import io, transform


def load_data(data_dir, size=(256, 256)):
    """Load image/mask pairs from data_dir/images and data_dir/masks."""
    images = []
    masks = []
    # Sort for a reproducible order; assumes each mask shares its image's filename.
    for filename in sorted(os.listdir(os.path.join(data_dir, 'images'))):
        img_path = os.path.join(data_dir, 'images', filename)
        mask_path = os.path.join(data_dir, 'masks', filename)
        image = io.imread(img_path)
        mask = io.imread(mask_path)
        # Resize to a fixed size; other preprocessing steps could go here.
        resized_image = transform.resize(image, size)
        # Nearest-neighbour interpolation (order=0) avoids blurring the mask labels.
        resized_mask = transform.resize(mask, size, order=0, anti_aliasing=False)
        images.append(resized_image)
        masks.append(resized_mask)
    return np.array(images), np.array(masks)


data_directory = './path_to_your_dataset'
X_train, y_train = load_data(data_directory)

plt.imshow(X_train[0])
plt.show()
```
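This code reads the images and their corresponding masks from a folder and resizes them to a common size for later processing. It assumes a specific directory layout (`images/` and `masks/` subfolders with matching filenames); your actual paths may differ.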
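The arrays returned by `load_data` are in height-width-channel order, while PyTorch convolution layers expect channel-first input. The following is a minimal sketch of that conversion, not part of the original code. The helper name `to_model_arrays` and the 0.5 threshold are illustrative assumptions, and it assumes single-channel masks already scaled to the range [0, 1].

```python
import numpy as np


def to_model_arrays(images, masks, threshold=0.5):
    """Convert (N, H, W, C) images and (N, H, W) masks to float32 model inputs.

    Hypothetical helper; the threshold value is an assumption.
    """
    images = np.asarray(images, dtype=np.float32)
    masks = np.asarray(masks, dtype=np.float32)
    # Move the channel axis forward: (N, H, W, C) -> (N, C, H, W).
    images_nchw = np.transpose(images, (0, 3, 1, 2))
    # Map every mask pixel to exactly 0.0 or 1.0.
    masks_binary = (masks > threshold).astype(np.float32)
    return images_nchw, masks_binary


# Synthetic arrays stand in for the output of load_data.
demo_images = np.random.rand(2, 256, 256, 3)
demo_masks = np.random.rand(2, 256, 256)
x, y = to_model_arrays(demo_images, demo_masks)
print(x.shape, y.shape)
```

Binarizing the masks keeps the targets compatible with a binary cross-entropy loss even if resizing or file decoding left intermediate values.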
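#### Building the UNet Architecture
UNet is a classic encoder-decoder network design that is well suited to semantic segmentation in medical image analysis. Below is a simplified PyTorch implementation: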
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class UNet(nn.Module):
    def __init__(self, n_channels, n_classes):
        super(UNet, self).__init__()
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        # Input channels = upsampled decoder channels + skip-connection channels.
        self.up1 = Up(384, 128)
        self.up2 = Up(192, 64)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x = self.up1(x3, x2)
        x = self.up2(x, x1)
        logits = self.outc(x)
        return logits


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Pad x1 so its spatial size matches the skip-connection tensor x2.
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
```
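This defines a simplified UNet class with two downsampling stages (`Down`) and two upsampling stages (`Up`). Skip connections concatenate the high-resolution encoder features with the upsampled decoder features, which helps preserve spatial position information.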
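The input channel counts passed to `Up` (384 and 192) follow from this concatenation. The sketch below traces the first skip connection with random tensors. It assumes a 256×256 input, so the tensor shapes are illustrative rather than taken from a real forward pass.

```python
import torch
import torch.nn.functional as F

# Hypothetical feature maps for a 256x256 input, using the channel widths of the model.
x2 = torch.randn(1, 128, 128, 128)  # encoder output after down1
x3 = torch.randn(1, 256, 64, 64)    # encoder output after down2

# Upsample the deeper map to the spatial size of the shallower one.
x3_up = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=True)

# Skip connection: concatenate along the channel axis (dim=1).
merged = torch.cat([x2, x3_up], dim=1)
print(merged.shape)  # torch.Size([1, 384, 128, 128])
```

The 128 skip channels plus the 256 upsampled channels give the 384 input channels of `up1`. The same arithmetic (64 + 128) gives 192 for `up2`.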
#### Training
With the pieces above in place, you can write the training loop. When GPU memory is limited, use a small batch size and iterate until the loss converges.
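The training loop in this section expects a `train_loader` whose batches are dicts with `'image'` and `'mask'` keys, but the text never defines it. One possible way to build it is sketched below. The class name `FundusDataset` and the batch size of 2 are assumptions for illustration, and synthetic arrays stand in for `X_train` and `y_train`.

```python
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


class FundusDataset(Dataset):
    """Hypothetical wrapper around in-memory image and mask arrays."""

    def __init__(self, images, masks):
        # images: (N, H, W, C) floats; masks: (N, H, W) floats in [0, 1]
        self.images = images
        self.masks = masks

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Convert to channel-first layout, as convolution layers expect.
        image = torch.from_numpy(self.images[idx]).permute(2, 0, 1).float()
        mask = torch.from_numpy(self.masks[idx]).float()
        return {'image': image, 'mask': mask}


# Synthetic arrays stand in for X_train and y_train.
demo_images = np.random.rand(6, 64, 64, 3)
demo_masks = (np.random.rand(6, 64, 64) > 0.5).astype(np.float32)

# A small batch size keeps memory use low on limited GPUs.
train_loader = DataLoader(FundusDataset(demo_images, demo_masks),
                          batch_size=2, shuffle=True)
batch = next(iter(train_loader))
print(batch['image'].shape, batch['mask'].shape)
```

The default collate function stacks the per-sample dicts into one dict of batched tensors, which matches the `data['image']` and `data['mask']` lookups in the loop.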
```python
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet(n_channels=3, n_classes=1).to(device=device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.BCEWithLogitsLoss()

# num_epochs and train_loader must be defined beforehand; train_loader is
# expected to yield dicts with 'image' (N, 3, H, W) and 'mask' (N, H, W) tensors.
model.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, data in enumerate(train_loader):
        inputs, labels = data['image'].to(device), data['mask'].to(device)

        optimizer.zero_grad()
        outputs = model(inputs.float())
        # Drop the channel axis so the logits match the (N, H, W) mask shape.
        loss = loss_fn(outputs.squeeze(dim=1), labels.float())
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f'[Epoch {epoch + 1}] Loss: {running_loss / len(train_loader)}')

# Save the weights from the final epoch.
torch.save(model.state_dict(), "./unet_model.pth")
print("Model has been saved.")
```
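The script above implements a basic training loop: for every batch in each epoch it runs the forward pass, backpropagates, and updates the weights. After the final epoch it saves the model's state dict for use at prediction time. Note that this is the last set of weights, not necessarily the best one.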
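If the goal is to keep the best weights rather than the last ones, one common option is to track a validation loss and copy the state dict whenever it improves. The sketch below is an assumption-laden illustration and not the original author's method. A single `nn.Conv2d` stands in for the UNet, random tensors stand in for the training and validation data, and the output path in the final comment is hypothetical.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Conv2d(3, 1, kernel_size=1)  # stand-in for the UNet
loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

# Random stand-ins; real code would iterate over training and validation loaders.
train_x = torch.rand(4, 3, 16, 16)
train_y = (torch.rand(4, 16, 16) > 0.5).float()
val_x = torch.rand(4, 3, 16, 16)
val_y = (torch.rand(4, 16, 16) > 0.5).float()

best_val_loss = float('inf')
best_state = None
for epoch in range(5):
    model.train()
    optimizer.zero_grad()
    loss = loss_fn(model(train_x).squeeze(dim=1), train_y)
    loss.backward()
    optimizer.step()

    # Evaluate without tracking gradients.
    model.eval()
    with torch.no_grad():
        val_loss = loss_fn(model(val_x).squeeze(dim=1), val_y).item()

    # Keep an independent copy of the weights whenever validation improves.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_state = copy.deepcopy(model.state_dict())

# torch.save(best_state, './unet_best.pth')  # hypothetical output path
```

The `deepcopy` matters because `state_dict()` returns references to the live parameters, which later optimizer steps would otherwise overwrite.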
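#### Testing and Validation
The last step is to evaluate performance on the test set. Visualizing a few randomly selected examples gives an intuitive sense of how good the segmentation is.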
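Alongside visual inspection, a numeric overlap score can help compare predictions against ground truth. The text does not name a metric, so the Dice coefficient used below is an assumed choice, as are the 0.5 threshold and the helper names `logits_to_mask` and `dice_score`. The sketch works on NumPy arrays, so model outputs would first need to be moved to the CPU and converted.

```python
import numpy as np


def logits_to_mask(logits, threshold=0.5):
    """Apply a sigmoid to raw logits and threshold the probabilities."""
    probs = 1.0 / (1.0 + np.exp(-logits))
    return (probs > threshold).astype(np.uint8)


def dice_score(pred, target, eps=1e-7):
    """Dice coefficient between two binary masks of the same shape."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    intersection = np.logical_and(pred, target).sum()
    # eps avoids division by zero when both masks are empty.
    return (2.0 * intersection + eps) / (pred.sum() + target.sum() + eps)


# Toy example: a 2x2 logit map and its ground-truth mask.
logits = np.array([[4.0, -3.0], [2.0, -1.0]])
target = np.array([[1, 0], [0, 0]])
pred = logits_to_mask(logits)
score = dice_score(pred, target)
print(pred)
print(score)
```

In this toy case the prediction marks two pixels and one of them overlaps the single ground-truth pixel, so the score is about 2/3.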
```python
test_img_idx = ...  # Choose an
```