u-net 模型训练

博客介绍了使用数据增强imgaug库,包含该库的安装方法以及训练代码,与深度学习相关,借助Python实现数据增强操作。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

使用数据增强 imgaug库

安装imgaug:

pip install imgaug

训练代码:

import torch
from torch.utils.data import Dataset
import numpy as np

import glob
# 数据增强相关包
import imgaug.augmenters as iaa
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
from x1 import UNet


class SegmentDataset(Dataset):
    def __init__(self, where1='image_', where2='mask_', seq=None):
        # 获取numpy文件数据
        # 图片列表
        self.img_list = glob.glob('dui/{}/*'.format(where1))
        self.mask_list = glob.glob('dui/{}/*'.format(where2))
        # 数据增强的处理流程
        self.seq = seq

    def __len__(self):
        # 获取数据集大小
        return len(self.img_list)

    def __getitem__(self, idx):
        # 获取具体某一个数据

        # 获取图片文件名
        img_file = self.img_list[idx]
        mask_list = self.mask_list[idx]
        # 获取标注文件名
        # mask_file = img_file.replace('img', 'label')

        # 加载数据
        img = np.load(img_file)
        mask = np.load(mask_list)

        # 数据增强处理
        if self.seq:
            segmap = SegmentationMapsOnImage(mask, shape=mask.shape)
            img, mask = self.seq(image=img, segmentation_maps=segmap)
            # 获取数组内容
            mask = mask.get_arr()

        # 扩张维度变成张量
        return np.expand_dims(img, 0), np.expand_dims(mask, 0)


# 数据增强的处理流程
seq = iaa.Sequential([
    iaa.Affine(
        scale=(0.8, 1.2),  # 缩放
        rotate=(-45, 45)  # 旋转
    ),
    iaa.ElasticTransformation()  # 弹性形变
])

# 使用dataloader加载数据
batch_size = 8
num_workers = 0

train_dataset = SegmentDataset(where1='img', where2='mask', seq=seq)
test_dataset = SegmentDataset(where1='img', where2='mask', seq=None)

# dataloader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)

# device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 模型实例化
model = UNet()
# # 断点续训 加载以前的权重参数
model.load_state_dict(torch.load("saved_model/unet_course_best.pt"), strict=True)
model.to(device)

loss_fn = torch.nn.BCEWithLogitsLoss()
# 定义优化器
optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)
# 动态减少学习率
from torch.optim.lr_scheduler import ReduceLROnPlateau

scheduler = ReduceLROnPlateau(optimizer, 'min', patience=2)
# # 使用tensorboard可视化
# from torch.utils.tensorboard import SummaryWriter
#
# writer = SummaryWriter(log_dir='./log')
import time


# 计算测试集的loss
def check_test_loss(loader, model):
    # 记录loss
    loss = 0
    # 不记录梯度
    with torch.no_grad():
        # 遍历测试数据
        for i, (x, y) in enumerate(loader):
            # 获取图像
            x = x.to(device, dtype=torch.float32)
            # 获取标注
            y = y.to(device, dtype=torch.float32)

            # 获取预测值
            y_pred = model(x)

            # 计算loss
            loss_batch = loss_fn(y_pred, y)

            # 累加
            loss += loss_batch

    loss = loss / len(loader)
    return loss


# 开始训练
EPOCH_NUM = 200
# 记录最小的测试loss
best_test_loss = 100

for epoch in range(EPOCH_NUM):
    # 获取每一批次图像信息
    # 计算整批数据的loss
    loss = 0
    # 记录一个epoch运行的时间
    start_time = time.time()
    for i, (x, y) in enumerate(train_loader):
        # 每次update清空梯度
        model.zero_grad()
        # 获取图像
        x = x.to(device, dtype=torch.float32)
        # 获取标注
        y = y.to(device, dtype=torch.float32)

        # 获取预测值
        y_pred = model(x)

        # 计算loss
        loss_batch = loss_fn(y_pred, y)

        # 计算梯度
        loss_batch.backward()
        optimizer.step()
        optimizer.zero_grad()

        # 获取每个batch的训练loss
        loss_batch = loss_batch.detach().cpu()
        # print
        print(loss_batch)
        loss += loss_batch

    # 计算loss
    loss = loss / len(train_loader)
    # 降低LR:如果loss连续10个epoch不再下降,则降低LR
    scheduler.step(loss)

    # 计算测试集loss
    test_loss = check_test_loss(test_loader, model)

    # # 记录到tensorboard可视化
    # writer.add_scalar('LOSS/train', loss, epoch)
    # writer.add_scalar('LOSS/test', test_loss, epoch)

    # 保存最佳模型
    if best_test_loss > test_loss:
        # 赋值
        best_test_loss = test_loss
        # 保存模型
        torch.save(model.state_dict(), 'saved_model/unet_course_best.pt')
        # 输出信息
        print('第{}个EPOCH达到最低的测试LOSS'.format(epoch))

    print('第{}个epoch执行时间{}s,train loss为{},test loss 为{}'.format(
        epoch,
        time.time() - start_time,
        loss,
        test_loss
    ))

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

默执_

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值