暗光增强IAT网络训练

项目地址:https://github.com/cuiziteng/Illumination-Adaptive-Transformer

1.训练暗光增强 train_lol_v2.py 代码一些修改:添加resize,保存指定epoch.pth,loss记录

import torch.nn as nn
import torch.optim
import torch.nn.functional as F
import os
import argparse
from torchvision.models import vgg16
from data_loaders.lol import lowlight_loader
from model.IAT_main import IAT
from IQA_pytorch import SSIM
from utils import PSNR, validation, LossNetwork

parser = argparse.ArgumentParser()
parser.add_argument('--gpu_id', type=str, default=0)
parser.add_argument('--img_path', type=str, default="data/LOLv2/train/Low/")
parser.add_argument('--img_val_path', type=str, default="data/LOLv2/val/Low/")
parser.add_argument("--normalize", action="store_false", help="Default Normalize in LOL training.")
parser.add_argument("--img_size_wh", type=tuple, default=(600, 400), help="Resize image in LOL training.")
parser.add_argument('--model_type', type=str, default='s')

parser.add_argument('--batch_size', type=int, default=2)
parser.add_argument('--lr', type=float, default=1e-4)  # 2e-4
parser.add_argument('--weight_decay', type=float, default=0.0005)
parser.add_argument('--pretrain_dir', type=str, default=None)

parser.add_argument('--num_epochs', type=int, default=100)
parser.add_argument('--save_per_epoch', type=int, default=10)
parser.add_argument('--display_iter', type=int, default=10)
parser.add_argument('--snapshots_folder', type=str, default="workdirs/lol_train2")


if __name__ == '__main__':
    config = parser.parse_args()

    print(config)
    os.environ['CUDA_VISIBLE_DEVICES'] = str(config.gpu_id)

    if not os.path.exists(config.snapshots_folder):
        os.makedirs(config.snapshots_folder)

    # Model Setting
    model = IAT(type=config.model_type).cuda()
    if config.pretrain_dir is not None:
        model.load_state_dict(torch.load(config.pretrain_dir))

    # Data Setting
    train_dataset = lowlight_loader(images_path=config.img_path, normalize=config.normalize, img_size_wh=config.img_size_wh)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=1,pin_memory=True)
    val_dataset = lowlight_loader(images_path=config.img_val_path, mode='test', normalize=config.normalize)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=1, pin_memory=True)

    # Loss & Optimizer Setting & Metric
    vgg_model = vgg16(pretrained=True).features[:16]
    vgg_model = vgg_model.cuda()

    for param in vgg_model.parameters():
        param.requires_grad = False

    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.num_epochs)

    device = next(model.parameters()).device
    print('the device is:', device)

    L1_loss = nn.L1Loss()
    L1_smooth_loss = F.smooth_l1_loss

    loss_network = LossNetwork(vgg_model)
    loss_network.eval()

    ssim = SSIM()
    psnr = PSNR()
    ssim_high = 0
    psnr_high = 0

    model.train()
    print('######## Start IAT Training #########')
    for epoch in range(config.num_epochs):
        print('the epoch is:', epoch)
        for iteration, imgs in enumerate(train_loader):
            low_img, high_img = imgs[0].cuda(), imgs[1].cuda()
            optimizer.zero_grad()
            model.train()
            mul, add, enhance_img = model(low_img)
            loss = L1_smooth_loss(enhance_img, high_img)+0.04*loss_network(enhance_img, high_img)
            loss.backward()
            optimizer.step()
            scheduler.step()

            if ((iteration + 1) % config.display_iter) == 0:
                print("Loss at iteration", iteration + 1, ":", loss.item())
                with open(config.snapshots_folder + '/loss.txt', 'a+') as f:
                    f.write('Epoch_' + str(epoch) + ': ' + 'The iteration_' + str(iteration + 1) + ' loss is: ' + str(loss.item()) + '\n')

        # Evaluation Model
        model.eval()
        PSNR_mean, SSIM_mean = validation(model, val_loader)

        with open(config.snapshots_folder + '/log.txt', 'a+') as f:
            f.write('Epoch_' + str(epoch) + ': ' + 'The SSIM is ' + str(SSIM_mean) + '  , The PSNR is ' + str(PSNR_mean) + '\n')
        if (epoch+1) % config.save_per_epoch == 0:
            torch.save(model.state_dict(), os.path.join(config.snapshots_folder, f"{epoch+1}_Epoch" + '.pth'))
        if SSIM_mean > ssim_high:
            ssim_high = SSIM_mean
            print('the highest SSIM value is:', str(ssim_high))
            torch.save(model.state_dict(), os.path.join(config.snapshots_folder, "best_Epoch" + '.pth'))

        f.close()

2.lol.py文件修改

2.1 添加多个格式的图片

def populate_train_list(images_path, mode='train'):
    # print(images_path)
    image_list_lowlight = []
    img_extensions = ['*.png', '*.jpg', '*.jpeg', '*.bmp', '*.tiff']
    for ext in img_extensions:
        image_list_lowlight.extend(glob.glob(os.path.join(images_path, ext)))
    train_list = image_list_lowlight
    if mode == 'train':
        random.shuffle(train_list)

    return train_list

2.2 实现resize

class lowlight_loader(data.Dataset):
    def __init__(self, images_path, mode='train', normalize=True, img_size_wh=None):

        self.size = img_size_wh
        ...


    def get_params(self, low):
        if self.size is None:
            self.w, self.h = low.size
        else:
            if not isinstance(self.size, tuple) or len(self.size) != 2:
                raise ValueError("img_size must be a tuple of two integers (w,h)")
            self.w, self.h = self.size
        ...

        return i, j

2.3 防止灰度图shape不匹配

class lowlight_loader(data.Dataset):
    def __init__(self, images_path, mode='train', normalize=True, img_size_wh=None):
    ...


    def __getitem__(self, index):
        data_lowlight_path = self.data_list[index]

        if self.mode == 'train':
            data_lowlight = Image.open(data_lowlight_path).convert('RGB')
            data_highlight = Image.open(data_lowlight_path.replace('low', 'normal').replace('Low', 'Normal')).convert('RGB')
        ...


        elif self.mode == 'test':
            data_lowlight = Image.open(data_lowlight_path).convert('RGB')
            data_highlight = Image.open(data_lowlight_path.replace('low', 'normal').replace('Low', 'Normal')).convert('RGB')
        ...

### YOLOv8IAT 暗光图像增强的技术实现 为了提升YOLOv8在低光环境下的目标检测性能,引入了一种名为 **Image Augmentation Transformer (IAT)** 的轻量级图像增强网络。以下是关于如何实现该技术的具体方法: #### 1. IAT的工作原理 IAT的核心在于通过分解图像信号处理器(ISP)管道到局部和全局图像组件,从而恢复在低光或过/欠曝光条件下的正常光照sRGB图像[^2]。它利用注意力机制来动态调整ISP的相关参数,例如颜色校正矩阵、伽马校正曲线以及亮度增益。 这种设计使得IAT能够有效地增强低照度图像中的细节信息,同时保持较高的计算效率。具体而言,IAT具备以下特点: - 参数数量约为90k,适合嵌入式设备部署。 - 处理时间仅需约0.004秒,满足实时性需求。 #### 2. 集成IAT至YOLOv8的流程 要将IAT集成到YOLOv8中以改善其暗光检测能力,可以按照以下方式操作: ##### (1)创建自定义模块文件 在 `yolov8-main/models` 文件夹下新增一个子目录命名为 `modules`,并将IAT的核心代码保存为Python脚本文件(如 `iat.py`)。此步骤确保了模型架构清晰分离并便于维护[^3]。 ```python # iat.py 示例代码片段 import torch.nn as nn class ImageAugmentationTransformer(nn.Module): def __init__(self, num_channels=3): super(ImageAugmentationTransformer, self).__init__() # 定义IAT层... def forward(self, x): # 实现前向传播逻辑... return enhanced_image ``` ##### (2)修改YOLOv8配置文件 编辑YOLOv8对应的`.yaml`配置文件,在输入阶段加入调用IAT的操作。这一步骤通常涉及在网络主干之前插入预处理部分。 ```yaml # yolov8.yaml 修改示例 backbone: - [-1, 'IA_Transformer', [num_channels]] # 添加IAT作为首个组件 neck: ... head: ... ``` 此处需要注意的是,“IA_Transformer”应指向实际导入路径中的类名或者函数名称;而 `[num_channels]` 则代表传给初始化方法所需的超参列表。 ##### (3)训练与验证过程 完成上述更改之后,继续遵循标准的数据准备规程执行常规训练循环即可。由于IAT已经内置到了整个流水线当中,因此无需额外编写专门针对它的优化器设置项。 #### 3. 性能评估指标 当采用IAT辅助后的YOLOv8应用于黑暗场景测试集合时,预期可以获得更优的结果表现形式体现在以下几个方面: - [email protected] IoU阈值上的显著增长; - 减少误报率(false positives),特别是在背景复杂区域附近; - 增强小物体识别精度,尤其是在极端弱光线条件下。 以上改进均得益于高质量特征提取所带来的优势[^1]。 --- ###
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值