UNet Crack Segmentation Code
### UNet Crack Segmentation Code Implementation
The following is a crack segmentation implementation based on the UNet architecture, suitable for different kinds of crack images (e.g. plain cracks, crack-like patterns, and moss-covered cracks). The code covers model training and shows how to load a dataset and save model checkpoints.
#### Data Preparation and Preprocessing
To help the model learn crack features effectively, the input data needs to be standardized (and optionally augmented). Below is a simple preprocessing routine:
```python
from PIL import Image
from torchvision.transforms import Compose, ToTensor, Normalize, Resize

def preprocess_data(image_path, mask_path):
    image_transform = Compose([
        Resize((256, 256)),                   # resize to a fixed size
        ToTensor(),                           # convert to a tensor in [0, 1]
        Normalize(mean=[0.485], std=[0.229])  # normalize the single grayscale channel
    ])
    mask_transform = Compose([
        Resize((256, 256)),                   # resize to the same fixed size
        ToTensor()                            # keep mask values in [0, 1]; no normalization
    ])
    image = Image.open(image_path).convert('L')  # load the image as grayscale
    mask = Image.open(mask_path).convert('L')    # load the mask as grayscale
    # Binarize the mask so it is a valid target for BCEWithLogitsLoss
    return image_transform(image), (mask_transform(mask) > 0.5).float()
```
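As a quick sanity check (not part of the original snippet), calling the helper on a hypothetical image/mask pair should return two tensors of shape `(1, 256, 256)`:
```python
# 'crack_001.png' and 'crack_001_mask.png' are hypothetical file names
image_tensor, mask_tensor = preprocess_data('crack_001.png', 'crack_001_mask.png')
print(image_tensor.shape, mask_tensor.shape)  # torch.Size([1, 256, 256]) torch.Size([1, 256, 256])
```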
#### UNet Model Definition
The following is a simplified version of the UNet architecture, suitable for a binary segmentation task (crack vs. non-crack).
```python
import torch
import torch.nn as nn

class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolutions, each followed by BatchNorm and ReLU."""
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)

class UNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=1):
        super(UNet, self).__init__()
        # Encoder (contracting path)
        self.down1 = DoubleConv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.down2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Bottleneck
        self.middle = DoubleConv(128, 256)
        # Decoder (expanding path); input channels double after concatenating skip connections
        self.upconv1 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.up1 = DoubleConv(256, 128)
        self.upconv2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.up2 = DoubleConv(128, 64)
        # 1x1 convolution producing one logit per pixel
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        # Encoder
        down1_out = self.down1(x)
        pool1_out = self.pool1(down1_out)
        down2_out = self.down2(pool1_out)
        pool2_out = self.pool2(down2_out)
        # Bottleneck
        middle_out = self.middle(pool2_out)
        # Decoder with skip connections to the encoder features
        upconv1_out = self.upconv1(middle_out)
        concat1 = torch.cat([upconv1_out, down2_out], dim=1)
        up1_out = self.up1(concat1)
        upconv2_out = self.upconv2(up1_out)
        concat2 = torch.cat([upconv2_out, down1_out], dim=1)
        up2_out = self.up2(concat2)
        final_out = self.final_conv(up2_out)
        return final_out
```
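A quick shape check (an addition, not part of the original post) confirms that the two pooling and two upsampling stages restore the input resolution, so the output logits align pixel-wise with the mask:
```python
# Pass a dummy grayscale batch through the network and verify the output shape
model = UNet(in_channels=1, out_channels=1)
dummy = torch.randn(2, 1, 256, 256)   # batch of 2 single-channel 256x256 images
with torch.no_grad():
    logits = model(dummy)
print(logits.shape)  # torch.Size([2, 1, 256, 256]): one logit per pixel
```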
#### Training Script
Below is a complete training snippet. It assumes a dataset directory `PATH_TO_THE_DATASET_FOLDER` and a checkpoint directory `PATH_TO_MODEL_DIRECTORY` are already prepared, passed in via `-data_dir` and `-model_dir`.
```python
import os
import argparse
from glob import glob
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm

# Custom dataset: images and masks are assumed to live in the same folder,
# with each mask named '<image_name>_mask.png'
class CrackDataset(Dataset):
    def __init__(self, root):
        self.image_paths = sorted(
            p for p in glob(os.path.join(root, '*.png')) if not p.endswith('_mask.png')
        )

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        image_path = self.image_paths[index]
        mask_path = image_path.replace('.png', '_mask.png')
        return preprocess_data(image_path, mask_path)

# Argument parser
parser = argparse.ArgumentParser()
parser.add_argument('-data_dir', type=str, required=True, help='Path to the dataset folder.')
parser.add_argument('-model_dir', type=str, required=True, help='Path to save model checkpoints.')
args = parser.parse_args()
os.makedirs(args.model_dir, exist_ok=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize the model, loss function and optimizer
model = UNet().to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Build the data loaders (80/20 train/validation split)
dataset = CrackDataset(root=args.data_dir)
n_val = int(0.2 * len(dataset))
train_set, val_set = random_split(dataset, [len(dataset) - n_val, n_val],
                                  generator=torch.Generator().manual_seed(42))
train_loader = DataLoader(train_set, batch_size=8, shuffle=True)
val_loader = DataLoader(val_set, batch_size=8, shuffle=False)  # reserved for a validation pass

# Training loop
for epoch in range(10):  # run 10 epochs as an example
    model.train()
    running_loss = 0.0
    for images, masks in tqdm(train_loader):
        images, masks = images.to(device), masks.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader):.4f}")
    torch.save(model.state_dict(), os.path.join(args.model_dir, f'model_{epoch}.pth'))
```
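Assuming the snippet is saved as, say, `train.py` (the file name is just an example), training can be started with `python train.py -data_dir PATH_TO_THE_DATASET_FOLDER -model_dir PATH_TO_MODEL_DIRECTORY`.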
The code above covers the full pipeline from data preprocessing to model training [^1].
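For completeness, here is a minimal inference sketch (an addition, not from the cited code): it loads a saved checkpoint, applies the same image transform used during training, and thresholds the sigmoid output to obtain a binary crack mask. The file names `model_9.pth` and `test_crack.png` are placeholders.
```python
import torch
from PIL import Image
from torchvision.transforms import Compose, ToTensor, Normalize, Resize

# Placeholder file names -- substitute a real checkpoint and test image
CHECKPOINT = 'model_9.pth'
TEST_IMAGE = 'test_crack.png'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet().to(device)
model.load_state_dict(torch.load(CHECKPOINT, map_location=device))
model.eval()

# Same image transform as in preprocess_data
transform = Compose([Resize((256, 256)), ToTensor(), Normalize(mean=[0.485], std=[0.229])])
image = transform(Image.open(TEST_IMAGE).convert('L')).unsqueeze(0).to(device)  # shape (1, 1, 256, 256)

with torch.no_grad():
    prob = torch.sigmoid(model(image))   # per-pixel crack probability
    pred_mask = (prob > 0.5).float()     # binary crack mask, shape (1, 1, 256, 256)
```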