项目地址：https://github.com/cuiziteng/Illumination-Adaptive-Transformer
1. 低光照增强训练脚本 train_lol_v2.py 的修改：添加 resize、每隔指定 epoch 保存 .pth 权重、记录 loss
import torch.nn as nn
import torch.optim
import torch.nn.functional as F
import os
import argparse
from torchvision.models import vgg16
from data_loaders.lol import lowlight_loader
from model.IAT_main import IAT
from IQA_pytorch import SSIM
from utils import PSNR, validation, LossNetwork
def _parse_size_wh(text):
    """Parse a ``W,H`` command-line value into an ``(int, int)`` tuple.

    ``type=tuple`` would turn ``"600,400"`` into a tuple of single characters
    (``('6', '0', '0', ',', '4', '0', '0')``), which the data loader's
    two-element check rejects, so the size has to be parsed explicitly.
    ``WxH`` and a surrounding ``(...)`` / ``[...]`` are accepted as well.

    Raises:
        argparse.ArgumentTypeError: if the value is not two positive integers;
            argparse turns this into a normal usage error.
    """
    cleaned = text.strip().strip('()[]').lower().replace('x', ',')
    parts = cleaned.split(',')
    if len(parts) != 2:
        raise argparse.ArgumentTypeError(f"expected W,H such as 600,400, got {text!r}")
    try:
        width, height = int(parts[0]), int(parts[1])
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected integer W,H such as 600,400, got {text!r}") from None
    if width <= 0 or height <= 0:
        raise argparse.ArgumentTypeError(f"W and H must be positive, got {text!r}")
    return width, height


parser = argparse.ArgumentParser()
# Value is exported to CUDA_VISIBLE_DEVICES, so keep it a string.
parser.add_argument('--gpu_id', type=str, default='0')
parser.add_argument('--img_path', type=str, default="data/LOLv2/train/Low/")
parser.add_argument('--img_val_path', type=str, default="data/LOLv2/val/Low/")
# store_false: normalization is ON by default and passing the flag turns it OFF.
parser.add_argument("--normalize", action="store_false",
                    help="Normalization is on by default in LOL training; pass this flag to disable it.")
# Training-time (width, height); the default stays a real tuple, CLI values go through _parse_size_wh.
parser.add_argument("--img_size_wh", type=_parse_size_wh, default=(600, 400),
                    help="Resize image in LOL training, given as W,H (e.g. 600,400).")
parser.add_argument('--model_type', type=str, default='s')
parser.add_argument('--batch_size', type=int, default=2)
parser.add_argument('--lr', type=float, default=1e-4)  # 2e-4
parser.add_argument('--weight_decay', type=float, default=0.0005)
parser.add_argument('--pretrain_dir', type=str, default=None)
parser.add_argument('--num_epochs', type=int, default=100)
# Save a numbered checkpoint every N epochs.
parser.add_argument('--save_per_epoch', type=int, default=10)
# Print / log the training loss every N iterations.
parser.add_argument('--display_iter', type=int, default=10)
parser.add_argument('--snapshots_folder', type=str, default="workdirs/lol_train2")
if __name__ == '__main__':
    config = parser.parse_args()
    print(config)

    # Must be set before the first CUDA call so `.cuda()` below uses the chosen GPU.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(config.gpu_id)
    os.makedirs(config.snapshots_folder, exist_ok=True)

    # Model Setting
    model = IAT(type=config.model_type).cuda()
    if config.pretrain_dir is not None:
        # Load onto CPU first so a checkpoint saved on another GPU index still loads;
        # load_state_dict then copies the weights into the CUDA parameters.
        model.load_state_dict(torch.load(config.pretrain_dir, map_location='cpu'))

    # Data Setting: only the training set receives img_size_wh; validation is
    # built without it.
    train_dataset = lowlight_loader(images_path=config.img_path, normalize=config.normalize,
                                    img_size_wh=config.img_size_wh)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True,
                                               num_workers=1, pin_memory=True)
    val_dataset = lowlight_loader(images_path=config.img_val_path, mode='test', normalize=config.normalize)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, shuffle=False,
                                             num_workers=1, pin_memory=True)

    # Loss & Optimizer Setting
    # Frozen VGG16 feature extractor (first 16 modules of `.features`) handed to LossNetwork.
    # NOTE(review): newer torchvision deprecates `pretrained=`; consider
    # `weights=VGG16_Weights.IMAGENET1K_V1` if your torchvision warns.
    vgg_model = vgg16(pretrained=True).features[:16].cuda()
    for param in vgg_model.parameters():
        param.requires_grad = False
    loss_network = LossNetwork(vgg_model)
    loss_network.eval()

    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
    # scheduler.step() runs once per *iteration* below, so T_max must be the total
    # number of iterations. With T_max=num_epochs the cosine curve would swing
    # between lr and 0 every num_epochs iterations instead of decaying once.
    total_iters = max(1, config.num_epochs * len(train_loader))
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_iters)

    device = next(model.parameters()).device
    print('the device is:', device)

    ssim_high = 0
    loss_log_path = os.path.join(config.snapshots_folder, 'loss.txt')
    metric_log_path = os.path.join(config.snapshots_folder, 'log.txt')

    print('######## Start IAT Training #########')
    for epoch in range(config.num_epochs):
        print('the epoch is:', epoch)
        # Validation at the end of each epoch switches to eval(); switch back here.
        model.train()
        for iteration, imgs in enumerate(train_loader):
            low_img, high_img = imgs[0].cuda(), imgs[1].cuda()
            optimizer.zero_grad()
            _, _, enhance_img = model(low_img)
            # Smooth-L1 reconstruction term plus the 0.04-weighted LossNetwork (VGG) term.
            loss = F.smooth_l1_loss(enhance_img, high_img) + 0.04 * loss_network(enhance_img, high_img)
            loss.backward()
            optimizer.step()
            scheduler.step()

            if (iteration + 1) % config.display_iter == 0:
                loss_value = loss.item()  # single GPU->CPU sync, reused for print and log
                print("Loss at iteration", iteration + 1, ":", loss_value)
                with open(loss_log_path, 'a+') as f:
                    f.write('Epoch_' + str(epoch) + ': ' + 'The iteration_' + str(iteration + 1)
                            + ' loss is: ' + str(loss_value) + '\n')

        # Evaluation Model
        model.eval()
        PSNR_mean, SSIM_mean = validation(model, val_loader)
        with open(metric_log_path, 'a+') as f:
            f.write('Epoch_' + str(epoch) + ': ' + 'The SSIM is ' + str(SSIM_mean)
                    + ' , The PSNR is ' + str(PSNR_mean) + '\n')

        # Periodic checkpoint, named with the 1-based epoch count.
        if (epoch + 1) % config.save_per_epoch == 0:
            torch.save(model.state_dict(), os.path.join(config.snapshots_folder, f"{epoch+1}_Epoch.pth"))

        # The "best" checkpoint is selected by validation SSIM only.
        if SSIM_mean > ssim_high:
            ssim_high = SSIM_mean
            print('the highest SSIM value is:', str(ssim_high))
            torch.save(model.state_dict(), os.path.join(config.snapshots_folder, "best_Epoch.pth"))
2. 对 lol.py（data_loaders/lol.py）的修改
2.1 支持读取多种格式的图片
def populate_train_list(images_path, mode='train'):
    """Return the paths of all image files directly inside ``images_path``.

    Extensions are matched case-insensitively (.png, .jpg, .jpeg, .bmp, .tif,
    .tiff). This means ``IMG_001.JPG`` is found on case-sensitive file systems,
    without the duplicates that extra upper-case glob patterns would produce on
    case-insensitive ones. Hidden files (e.g. macOS ``._*`` resource forks) and
    sub-directories are skipped.

    Args:
        images_path: directory holding the low-light images.
        mode: ``'train'`` shuffles the list in place; any other mode returns it
            sorted so evaluation order is reproducible.

    Returns:
        list of paths of the form ``os.path.join(images_path, file_name)``.

    Raises:
        FileNotFoundError: if ``images_path`` does not exist. Failing here is
            clearer than an empty dataset failing later in the DataLoader.
    """
    img_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')
    # os.listdir instead of glob: a directory path containing glob
    # metacharacters such as "[" would otherwise silently match nothing.
    train_list = sorted(
        os.path.join(images_path, name)
        for name in os.listdir(images_path)
        if not name.startswith('.')
        and name.lower().endswith(img_extensions)
        and os.path.isfile(os.path.join(images_path, name))
    )
    if mode == 'train':
        random.shuffle(train_list)
    return train_list
2.2 实现 resize（通过 img_size_wh 参数指定目标尺寸 (w, h)）
class lowlight_loader(data.Dataset):
    # Excerpt: only the lines changed for resize support are shown;
    # each "..." marks code omitted from this snippet.
    def __init__(self, images_path, mode='train', normalize=True, img_size_wh=None):
        # Target size as (width, height); None means "use each image's own size".
        self.size = img_size_wh
        ...

    def get_params(self, low):
        # Decide the working (w, h): the input image's own size when no target
        # size was configured, otherwise the configured img_size_wh.
        # NOTE(review): presumably `low` is a PIL image (its .size is (w, h)) — confirm.
        if self.size is None:
            self.w, self.h = low.size
        else:
            # NOTE(review): only a tuple passes this check — a list such as
            # [600, 400] raises ValueError, and the elements are not checked
            # to be integers.
            if not isinstance(self.size, tuple) or len(self.size) != 2:
                raise ValueError("img_size must be a tuple of two integers (w,h)")
            self.w, self.h = self.size
        ...
        # i, j are computed in the omitted code above (presumably crop offsets — confirm).
        return i, j
2.3 防止灰度图导致的 shape 不匹配（读取时统一 convert('RGB')）
class lowlight_loader(data.Dataset):
    # Excerpt: only the image-loading lines of __getitem__ are shown;
    # each "..." marks code omitted from this snippet.
    def __init__(self, images_path, mode='train', normalize=True, img_size_wh=None):
        ...

    def __getitem__(self, index):
        data_lowlight_path = self.data_list[index]
        # .convert('RGB') forces every image to 3 channels. Grayscale ("L"),
        # palette or RGBA files therefore yield the same shape as RGB ones, and
        # low/normal pairs can be stacked into a batch.
        if self.mode == 'train':
            data_lowlight = Image.open(data_lowlight_path).convert('RGB')
            # The paired ground-truth path is derived by renaming low -> normal
            # and Low -> Normal in the low-light path.
            # NOTE(review): str.replace rewrites EVERY occurrence, so any other
            # directory in the path containing "low"/"Low" is renamed too.
            data_highlight = Image.open(data_lowlight_path.replace('low', 'normal').replace('Low', 'Normal')).convert('RGB')
            ...
        elif self.mode == 'test':
            data_lowlight = Image.open(data_lowlight_path).convert('RGB')
            # Same low -> normal path mapping (and caveat) as in train mode.
            data_highlight = Image.open(data_lowlight_path.replace('low', 'normal').replace('Low', 'Normal')).convert('RGB')
            ...