This post is a code walkthrough of ICELUT. For an interpretation of the paper itself, see the ICELUT paper walkthrough.
1. Paper Overview
ICELUT implements a fully LUT-based lightweight enhancement scheme. The overall pipeline is as follows.
Many classic lightweight-LUT design ideas are visible in this scheme: the input is split into MSB and LSB halves, which are processed by two separate feature extractors and fused by addition; a grouped FC layer then produces the weights for a set of basis LUTs; and applying the weighted LUT to the original image gives the final result.
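To make the MSB/LSB decomposition concrete, here is a minimal sketch of how an 8-bit image is split into two 4-bit halves (mirroring the preprocessing in inference_demo.py discussed later):

import numpy as np

# 8-bit RGB image; every value v in [0, 255] satisfies v = (v // 16) * 16 + (v % 16)
img = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)

msb = (img // 16) / 16.0   # top 4 bits: 16 coarse levels, normalized
lsb = (img % 16) / 16.0    # bottom 4 bits: 16 fine levels, normalized

# Each half has only 16 distinct values per channel, so a network that consumes
# them pixel-wise can later be tabulated exhaustively (16^3 entries per branch).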
2. Code Structure
The overall layout of the code is as follows:
3. Core Code Modules
The model_ICELUT.py file
1. The ICELUT class
This is the overall network implementation; it wires together TrilinearInterpolation, the backbone, the split FC heads (SplitFC, i.e. classifier1-5), and CLUT. The forward pass proceeds as follows: the image first goes through the backbone to produce intermediate features; the features are split into 5 groups, each of which yields its own weights; the five weight vectors are summed into the merged weights, which are fed into CLUT to produce the final 3D LUT.
class ICELUT(nn.Module):
    def __init__(self, nsw, dim=33, backbone='PointWise', *args, **kwargs):
        super().__init__()
        self.TrilinearInterpolation = TrilinearInterpolation()
        self.pre = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.backbone = eval(backbone)()
        last_channel = self.backbone.last_channel
        # Five identical heads, one per 2-channel feature group (last_channel // 5 = 2).
        self.classifier1 = nn.Sequential(
            nn.Conv2d(last_channel//5, 512, 1, 1),
            nn.Hardswish(inplace=True),
            nn.Conv2d(512, int(nsw[:2]), 1, 1),
        )
        self.classifier2 = nn.Sequential(
            nn.Conv2d(last_channel//5, 512, 1, 1),
            nn.Hardswish(inplace=True),
            nn.Conv2d(512, int(nsw[:2]), 1, 1),
        )
        self.classifier3 = nn.Sequential(
            nn.Conv2d(last_channel//5, 512, 1, 1),
            nn.Hardswish(inplace=True),
            nn.Conv2d(512, int(nsw[:2]), 1, 1),
        )
        self.classifier4 = nn.Sequential(
            nn.Conv2d(last_channel//5, 512, 1, 1),
            nn.Hardswish(inplace=True),
            nn.Conv2d(512, int(nsw[:2]), 1, 1),
        )
        self.classifier5 = nn.Sequential(
            nn.Conv2d(last_channel//5, 512, 1, 1),
            nn.Hardswish(inplace=True),
            nn.Conv2d(512, int(nsw[:2]), 1, 1),
        )
        nsw = nsw.split("+")
        num, s, w = int(nsw[0]), int(nsw[1]), int(nsw[2])
        self.CLUTs = CLUT(num, dim, s, w)
    def fuse_basis_to_one(self, img, img_s, TVMN=None):
        mid_results = self.backbone(img, img_s)
        mid1 = mid_results[:, :2, :, :]
        mid2 = mid_results[:, 2:4, :, :]
        mid3 = mid_results[:, 4:6, :, :]
        mid4 = mid_results[:, 6:8, :, :]
        mid5 = mid_results[:, 8:, :, :]
        weights1 = self.classifier1(mid1)[:, :, 0, 0]  # n, num
        weights2 = self.classifier2(mid2)[:, :, 0, 0]  # n, num
        weights3 = self.classifier3(mid3)[:, :, 0, 0]  # n, num
        weights4 = self.classifier4(mid4)[:, :, 0, 0]  # n, num
        weights5 = self.classifier5(mid5)[:, :, 0, 0]  # n, num
        weights = weights1 + weights2 + weights3 + weights4 + weights5
        D3LUT, tvmn_loss = self.CLUTs(weights, TVMN)
        return D3LUT, tvmn_loss

    def forward(self, img_msb, img_lsb, img_org, TVMN=None):
        D3LUT, tvmn_loss = self.fuse_basis_to_one(img_msb, img_lsb, TVMN)
        img_res = self.TrilinearInterpolation(D3LUT, img_org)
        return {
            "fakes": img_res + img_org,   # the LUT output is a residual on the input
            "3DLUT": D3LUT,
            "tvmn_loss": tvmn_loss,
        }
    def fuse_basis_to_one_test(self, img, img_s, TVMN=None):
        mid_results = self.backbone.forward_test(img, img_s)
        mid1 = mid_results[:, :2, :, :]
        mid2 = mid_results[:, 2:4, :, :]
        mid3 = mid_results[:, 4:6, :, :]
        mid4 = mid_results[:, 6:8, :, :]
        mid5 = mid_results[:, 8:, :, :]
        # Quantize each head's weights to 0.25 steps in [-32, 31.75] (int8-friendly).
        weights1 = self.classifier1(mid1)[:, :, 0, 0]  # n, num
        weights1 = torch.clamp(torch.round(weights1 * 4) / 4, -32, 31.75)
        weights2 = self.classifier2(mid2)[:, :, 0, 0]  # n, num
        weights2 = torch.clamp(torch.round(weights2 * 4) / 4, -32, 31.75)
        weights3 = self.classifier3(mid3)[:, :, 0, 0]  # n, num
        weights3 = torch.clamp(torch.round(weights3 * 4) / 4, -32, 31.75)
        weights4 = self.classifier4(mid4)[:, :, 0, 0]  # n, num
        weights4 = torch.clamp(torch.round(weights4 * 4) / 4, -32, 31.75)
        weights5 = self.classifier5(mid5)[:, :, 0, 0]  # n, num
        weights5 = torch.clamp(torch.round(weights5 * 4) / 4, -32, 31.75)
        weights = weights1 + weights2 + weights3 + weights4 + weights5
        D3LUT, tvmn_loss = self.CLUTs(weights, TVMN)
        return D3LUT, tvmn_loss

    def forward_test(self, img_msb, img_lsb, img_org, TVMN=None):
        D3LUT, tvmn_loss = self.fuse_basis_to_one_test(img_msb, img_lsb, TVMN)
        img_res = self.TrilinearInterpolation(D3LUT, img_org)
        return {
            "fakes": img_res + img_org,
            "3DLUT": D3LUT,
            "tvmn_loss": tvmn_loss,
        }
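Two details worth noting. First, nsw is a string such as '10+05+10': int(nsw[:2]) reads the number of basis LUTs (10) as the output width of each classifier head, and nsw.split('+') yields num, s, w for the CLUT constructor. Second, forward returns img_res + img_org, so the predicted 3D LUT encodes a residual on top of the input rather than the output itself. A minimal usage sketch (input shapes are my assumptions; CLUT, TrilinearInterpolation and PointWise come from the repo, and the compiled trilinear extension must be available):

model = ICELUT('10+05+10', dim=33, backbone='PointWise').cuda().eval()

msb = torch.rand(1, 3, 32, 32).cuda()    # thumbnail MSB input
lsb = torch.rand(1, 3, 32, 32).cuda()    # thumbnail LSB input
org = torch.rand(1, 3, 480, 640).cuda()  # full-resolution image to enhance

with torch.no_grad():
    out = model.forward_test(msb, lsb, org)
print(out['fakes'].shape)  # enhanced image, same shape as org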
2. The PointWise class
This implements the pointwise (1×1) convolution backbone. It takes two inputs, MSB and LSB, and outputs their fused result.
class PointWise(torch.nn.Module):
    def __init__(self, upscale=4):
        super(PointWise, self).__init__()
        self.upscale = upscale
        # MSB branch: ten 1x1 convolutions.
        self.conv1 = nn.Conv2d(3, 32, 1, stride=1, padding=0, dilation=1)
        self.conv2 = nn.Conv2d(32, 64, 1, stride=1, padding=0, dilation=1)
        self.conv3 = nn.Conv2d(64, 128, 1, stride=1, padding=0, dilation=1)
        self.conv4 = nn.Conv2d(128, 256, 1, stride=1, padding=0, dilation=1)
        self.conv5 = nn.Conv2d(256, 512, 1, stride=1, padding=0, dilation=1)
        self.conv6 = nn.Conv2d(512, 256, 1, stride=1, padding=0, dilation=1)
        self.conv7 = nn.Conv2d(256, 128, 1, stride=1, padding=0, dilation=1)
        self.conv8 = nn.Conv2d(128, 64, 1, stride=1, padding=0, dilation=1)
        self.conv9 = nn.Conv2d(64, 32, 1, stride=1, padding=0, dilation=1)
        self.conv10 = nn.Conv2d(32, 10, 1, stride=1, padding=0, dilation=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        # LSB branch: identical structure, separate weights.
        self.conv1_s = nn.Conv2d(3, 32, 1, stride=1, padding=0, dilation=1)
        self.conv2_s = nn.Conv2d(32, 64, 1, stride=1, padding=0, dilation=1)
        self.conv3_s = nn.Conv2d(64, 128, 1, stride=1, padding=0, dilation=1)
        self.conv4_s = nn.Conv2d(128, 256, 1, stride=1, padding=0, dilation=1)
        self.conv5_s = nn.Conv2d(256, 512, 1, stride=1, padding=0, dilation=1)
        self.conv6_s = nn.Conv2d(512, 256, 1, stride=1, padding=0, dilation=1)
        self.conv7_s = nn.Conv2d(256, 128, 1, stride=1, padding=0, dilation=1)
        self.conv8_s = nn.Conv2d(128, 64, 1, stride=1, padding=0, dilation=1)
        self.conv9_s = nn.Conv2d(64, 32, 1, stride=1, padding=0, dilation=1)
        self.conv10_s = nn.Conv2d(32, 10, 1, stride=1, padding=0, dilation=1)
        self.last_channel = 10
    def forward(self, x_in, x_s):
        B, C, H, W = x_in.size()
        # MSB branch
        x = self.conv1(x_in)
        x = self.conv2(F.relu(x))
        x = self.conv3(F.relu(x))
        x = self.conv4(F.relu(x))
        x = self.conv5(F.relu(x))
        x = self.conv6(F.relu(x))
        x = self.conv7(F.relu(x))
        x = self.conv8(F.relu(x))
        x = self.conv9(F.relu(x))
        x = self.conv10(F.relu(x))
        out_i = x
        # LSB branch
        x = self.conv1_s(x_s)
        x = self.conv2_s(F.relu(x))
        x = self.conv3_s(F.relu(x))
        x = self.conv4_s(F.relu(x))
        x = self.conv5_s(F.relu(x))
        x = self.conv6_s(F.relu(x))
        x = self.conv7_s(F.relu(x))
        x = self.conv8_s(F.relu(x))
        x = self.conv9_s(F.relu(x))
        x = self.conv10_s(F.relu(x))
        out_s = x
        out = out_s + out_i    # fuse the two branches by addition
        out = self.pool(out)   # global average pool -> [B, 10, 1, 1]
        # (quantization is applied only in forward_test below)
        return out
    def forward_test(self, x_in, x_s):
        B, C, H, W = x_in.size()
        # MSB branch
        x = self.conv1(x_in)
        x = self.conv2(F.relu(x))
        x = self.conv3(F.relu(x))
        x = self.conv4(F.relu(x))
        x = self.conv5(F.relu(x))
        x = self.conv6(F.relu(x))
        x = self.conv7(F.relu(x))
        x = self.conv8(F.relu(x))
        x = self.conv9(F.relu(x))
        x = self.conv10(F.relu(x))
        out_i = x
        # LSB branch
        x = self.conv1_s(x_s)
        x = self.conv2_s(F.relu(x))
        x = self.conv3_s(F.relu(x))
        x = self.conv4_s(F.relu(x))
        x = self.conv5_s(F.relu(x))
        x = self.conv6_s(F.relu(x))
        x = self.conv7_s(F.relu(x))
        x = self.conv8_s(F.relu(x))
        x = self.conv9_s(F.relu(x))
        x = self.conv10_s(F.relu(x))
        out_s = x
        out = out_s + out_i
        out = self.pool(out)
        # Quantize the pooled features to 0.5 steps in [-16, 15.5] (64 discrete levels),
        # so each value can later serve as an index into the weight LUT.
        out = torch.clamp(torch.round(out * 2) / 2, -16, 15.5)
        return out
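The key property here is that every layer is a 1×1 convolution: each output pixel depends only on the three channel values of the corresponding input pixel, so each branch is a fixed function from R^3 to R^10 that can be tabulated exhaustively over the quantized input grid. That is exactly what transfer2LUT.py does below. A minimal sketch of this equivalence on a toy network (names and sizes are illustrative only):

import torch
import torch.nn as nn

# Toy pointwise branch: structurally like PointWise, but tiny.
net = nn.Sequential(nn.Conv2d(3, 8, 1), nn.ReLU(), nn.Conv2d(8, 4, 1)).eval()

# Enumerate every input on a 16-level grid per channel (16^3 = 4096 combinations).
levels = torch.arange(16).float() / 16.0
r, g, b = torch.meshgrid(levels, levels, levels, indexing='ij')
grid = torch.stack([r, g, b], dim=-1).reshape(-1, 3, 1, 1)        # [4096, 3, 1, 1]

with torch.no_grad():
    lut = net(grid).reshape(16, 16, 16, 4)                        # the tabulated network
    # Lookup now replaces computation: quantize a pixel, then index the table.
    pixel = torch.tensor([0.25, 0.5, 0.75])
    i = (pixel * 16).long()
    out_lut = lut[i[0], i[1], i[2]]
    out_net = net(pixel.reshape(1, 3, 1, 1)).reshape(-1)
    assert torch.allclose(out_lut, out_net)                       # identical on grid points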
3. The TrilinearInterpolation class
The trilinear interpolation module. The actual implementation lives in trilinear_cpp/src/trilinear.cpp and trilinear_kernel.cu; the CPU version below is standard trilinear interpolation code.
void TriLinearForwardCpu(const float* lut, const float* image, float* output,
                         const int dim, const int shift, const float binsize,
                         const int width, const int height, const int channels)
{
    const int output_size = height * width;
    for (int index = 0; index < output_size; ++index)
    {
        float r = image[index];
        float g = image[index + width * height];
        float b = image[index + width * height * 2];

        // Locate the lattice cell containing (r, g, b) and the offsets inside it.
        int r_id = floor(r / binsize);
        int g_id = floor(g / binsize);
        int b_id = floor(b / binsize);
        float r_d = fmod(r, binsize) / binsize;
        float g_d = fmod(g, binsize) / binsize;
        float b_d = fmod(b, binsize) / binsize;

        // Flat indices of the cell's 8 corners (LUT layout: r fastest, then g, then b).
        int id000 = r_id + g_id * dim + b_id * dim * dim;
        int id100 = r_id + 1 + g_id * dim + b_id * dim * dim;
        int id010 = r_id + (g_id + 1) * dim + b_id * dim * dim;
        int id110 = r_id + 1 + (g_id + 1) * dim + b_id * dim * dim;
        int id001 = r_id + g_id * dim + (b_id + 1) * dim * dim;
        int id101 = r_id + 1 + g_id * dim + (b_id + 1) * dim * dim;
        int id011 = r_id + (g_id + 1) * dim + (b_id + 1) * dim * dim;
        int id111 = r_id + 1 + (g_id + 1) * dim + (b_id + 1) * dim * dim;

        // Trilinear weights of the 8 corners.
        float w000 = (1 - r_d) * (1 - g_d) * (1 - b_d);
        float w100 = r_d * (1 - g_d) * (1 - b_d);
        float w010 = (1 - r_d) * g_d * (1 - b_d);
        float w110 = r_d * g_d * (1 - b_d);
        float w001 = (1 - r_d) * (1 - g_d) * b_d;
        float w101 = r_d * (1 - g_d) * b_d;
        float w011 = (1 - r_d) * g_d * b_d;
        float w111 = r_d * g_d * b_d;

        // Interpolate each output channel; the three color planes are offset by `shift`.
        output[index] = w000 * lut[id000] + w100 * lut[id100] +
                        w010 * lut[id010] + w110 * lut[id110] +
                        w001 * lut[id001] + w101 * lut[id101] +
                        w011 * lut[id011] + w111 * lut[id111];
        output[index + width * height] = w000 * lut[id000 + shift] + w100 * lut[id100 + shift] +
                        w010 * lut[id010 + shift] + w110 * lut[id110 + shift] +
                        w001 * lut[id001 + shift] + w101 * lut[id101 + shift] +
                        w011 * lut[id011 + shift] + w111 * lut[id111 + shift];
        output[index + width * height * 2] = w000 * lut[id000 + shift * 2] + w100 * lut[id100 + shift * 2] +
                        w010 * lut[id010 + shift * 2] + w110 * lut[id110 + shift * 2] +
                        w001 * lut[id001 + shift * 2] + w101 * lut[id101 + shift * 2] +
                        w011 * lut[id011 + shift * 2] + w111 * lut[id111 + shift * 2];
    }
}
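For readability, here is the same per-pixel logic as a NumPy sketch (for understanding only, not the repo's API), assuming the LUT is stored channel-major with shift = dim^3 and binsize = 1/(dim - 1):

import numpy as np

def trilinear_lookup(lut, rgb, dim=33):
    # lut: flat array of length 3 * dim**3; rgb: three floats in [0, 1).
    binsize = 1.0 / (dim - 1)
    shift = dim ** 3
    rgb = np.asarray(rgb, dtype=np.float64)
    r_id, g_id, b_id = np.floor(rgb / binsize).astype(int)
    r_d, g_d, b_d = (rgb % binsize) / binsize      # fractional offsets inside the cell
    out = np.zeros(3)
    # Accumulate the 8 corners of the surrounding lattice cell.
    for dr in (0, 1):
        for dg in (0, 1):
            for db in (0, 1):
                w = ((r_d if dr else 1 - r_d) *
                     (g_d if dg else 1 - g_d) *
                     (b_d if db else 1 - b_d))
                idx = (r_id + dr) + (g_id + dg) * dim + (b_id + db) * dim * dim
                for c in range(3):
                    out[c] += w * lut[idx + c * shift]
    return out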
4. The CLUT class
This was covered in the CLUT code walkthrough; the authors reuse that code directly here.
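Although the CLUT class is not reproduced here, its parameter shapes are visible in transfer2LUT.py below: s_Layers is (dim, s) = (33, 5), w_Layers is (w, dim*dim) = (10, 1089), and LUTs is (s*num*3, w) = (150, 10). A hedged sketch of how the full basis LUTs are reconstructed from this low-rank factorization, following my reading of the CLUT paper (the exact reshape order is an assumption, not repo code):

import torch

dim, num, s, w = 33, 10, 5, 10
s_Layers = torch.rand(dim, s)            # factor along the red axis
w_Layers = torch.rand(w, dim * dim)      # factor over the green/blue plane
LUTs     = torch.rand(s * num * 3, w)    # compressed core: one (s, w) block per basis and channel

core = LUTs.reshape(num * 3, s, w)                                # [30, 5, 10] (assumed order)
full = torch.einsum('ds,ksw,we->kde', s_Layers, core, w_Layers)   # [30, 33, 1089]
basis_luts = full.reshape(num, 3, dim, dim, dim)                  # num basis 3D LUTs

# The per-image weights (shape [n, num]) predicted above then mix these bases
# into the single 3D LUT that TrilinearInterpolation applies to the image.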
The transfer2LUT.py file
As mentioned in the paper walkthrough, the backbone part is exported in fp32 while the SplitFC part is exported in int8. The conversion code is as follows:
# ========================== Channel LUT ==============================
model_G = PointWise().cuda()
ckpt = model_G.state_dict()
lm = torch.load('{}'.format(MODEL_PATH))
model2_dict = model_G.state_dict()
state_dict = {k.replace('backbone.', ''): v for k, v in lm.items()
              if k.replace('backbone.', '') in model2_dict.keys()}
model2_dict.update(state_dict)
model_G.load_state_dict(model2_dict)

### Extract input-output pairs
with torch.no_grad():
    model_G.eval()
    SAMPLING_INTERVAL = 4
    # 1D input: L = 16 sample values per channel
    base = torch.arange(0, 256, 2**SAMPLING_INTERVAL)
    base[-1] -= 1
    L = base.size(0)
    # 2D input: all L*L (first, second) pairs
    first = base.cuda().unsqueeze(1).repeat(1, L).reshape(-1)  # 0 0 0 ... | 16 16 16 ... | ...
    second = base.cuda().repeat(L)                             # 0 16 32 ... | 0 16 32 ... | ...
    onebytwo = torch.stack([first, second], 1)                 # [L*L, 2]
    # 3D input: all L*L*L (third, first, second) triples; third varies slowest
    third = base.cuda().unsqueeze(1).repeat(1, L*L).reshape(-1)
    onebytwo = onebytwo.repeat(L, 1)
    onebythree = torch.cat([third.unsqueeze(1), onebytwo], 1)  # [L*L*L, 3]
    input_tensor = onebythree.reshape(-1, 3, 1, 1).float() / 256.0
    # Note: this unpacking assumes the forward pass returns the two branch outputs
    # (MSB/LSB features) separately before fusion; the PointWise.forward shown
    # above returns a single fused tensor.
    out_m, out_l = model_G(input_tensor, input_tensor)
    results_m = out_m.cpu().data.numpy().astype(np.float32)
    results_l = out_l.cpu().data.numpy().astype(np.float32)
    np.save("./ICELUT/Model_lsb_fp32", results_l)
    np.save("./ICELUT/Model_msb_fp32", results_m)
## ======================== Weight LUT ===============================
model_G = SplitFC().cuda()
ckpt = model_G.state_dict()
lm = torch.load('{}'.format(MODEL_PATH))
model2_dict = model_G.state_dict()
state_dict = {k: v for k, v in lm.items() if k in model2_dict.keys()}
model2_dict.update(state_dict)
model_G.load_state_dict(model2_dict)

with torch.no_grad():
    model_G.eval()
    base = torch.arange(-32, 32, 1) / 2.0   # the 64 quantized feature levels, -16 ~ 15.5
    L = base.size(0)
    first = base.cuda().unsqueeze(1).repeat(1, L).reshape(-1)  # -16 -16 ... | -15.5 -15.5 ... | ...
    second = base.cuda().repeat(L)                             # -16 -15.5 ... 15.5 | -16 ... | ...
    onebytwo = torch.stack([first, second], 1)                 # [64*64, 2]
    input_tensor = onebytwo.reshape(-1, 2, 1, 1)
    input_tensor = input_tensor.repeat(1, 5, 1, 1)             # tile the pair across the 5 groups
    out = model_G(input_tensor)
    print(out.shape)
    results_m = out.cpu().data.numpy()
    # Encode: w -> 4*w + 32, stored as int8 (decoded at inference as (v - 32) / 4).
    results_m = (results_m * 4 + 32).astype(np.int8)
    np.save("./ICELUT/classifier_int8", results_m)

## ======================== Basis LUT (CLUT) ==========================
ckpt = torch.load('/home/ysd21/ICELUT_camera/FiveK/ICELUT_10+05+10/model0270.pth')
dim, num, s, w = 33, 10, 5, 10
s_Layers = nn.Parameter(torch.rand(dim, s)/5 - 0.1)
w_Layers = nn.Parameter(torch.rand(w, dim*dim)/5 - 0.1)
LUTs = nn.Parameter(torch.zeros(s*num*3, w))
s_Layers.data = ckpt['CLUTs.s_Layers']
w_Layers.data = ckpt['CLUTs.w_Layers']
LUTs.data = ckpt['CLUTs.LUTs']
lut_dict = {'s': s_Layers.data, 'w': w_Layers.data, 'Basis_LUT': LUTs.data}
np.save('./ICELUT/Basis_lut.npy', lut_dict)
Finally, for the CLUT part the authors simply save the raw matrices. In other words, this stage is still computed (a small matrix multiplication) at inference time rather than looked up.
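Tallying the exported artifact sizes from the shapes above (my own arithmetic): each of the two feature LUTs holds 16^3 = 4096 entries × 10 channels in fp32, about 160 KB; the classifier LUT holds 5 groups × 64^2 = 20480 entries × 10 weights in int8, about 200 KB; and the CLUT matrices total 33×5 + 10×1089 + 150×10 = 12555 fp32 values, about 50 KB. The CLUT stage is thus tiny, even though it is computed rather than looked up.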
The inference_demo.py file
This is the inference code; the model used here is the ICELUT model converted to LUTs.
for img in img_ls:
    ## load the parallel MSB and LSB inputs
    input_path = os.path.join(inp_dir, img)
    img_input = cv2.cvtColor(cv2.imread(input_path, -1), cv2.COLOR_BGR2RGB)
    img_input_msb = TF.to_tensor((img_input // 16) / 16)   # top 4 bits
    img_input_lsb = TF.to_tensor((img_input % 16) / 16)    # bottom 4 bits
    img_input = TF.to_tensor(img_input)
    ## the weight predictor only sees a (32, 32) thumbnail
    img_input_resize_msb, img_input_resize_lsb = TF.resize(img_input_msb, (scale, scale)), TF.resize(img_input_lsb, (scale, scale))
    img_input_resize_msb = torch.round(img_input_resize_msb * 16) / 16   # re-quantize after resizing
    img_input_resize_lsb = torch.round(img_input_resize_lsb * 16) / 16
    input_msb = img_input_resize_msb.type(torch.FloatTensor).unsqueeze(0).cuda()
    input_lsb = img_input_resize_lsb.type(torch.FloatTensor).unsqueeze(0).cuda()
    input_org = img_input.type(torch.FloatTensor).unsqueeze(0).cuda()
    ## inference
    output = model(input_msb, input_lsb, input_org)
    ## image saving (note: this rebinding of img_ls shadows the loop list, though
    ## it does not affect the already-created loop iterator)
    img_ls = [input_org.squeeze().data, output['out'].squeeze().data]
    if img_ls[0].shape[0] > 3:
        img_ls = [img.permute(2, 0, 1) for img in img_ls]
    save_image(img_ls, join(save_path, f"{img}.jpg"), nrow=len(img_ls))
The class implementing this LUT-converted model is as follows:
class ICELUT(nn.Module):
    def __init__(self, device, nsw='10+05+10', dim=33, backbone='SRNet'):
        super().__init__()
        self.TrilinearInterpolation = TrilinearInterpolation()
        self.feature = FeatLUT(device)
        self.classifier = torch.from_numpy(np.load('./ICELUT/classifier_int8.npy'))
        self.cluster = torch.split(self.classifier, 64**2, dim=0)   # one 64x64 table per group
        self.lut_cls = []
        for i in self.cluster:
            self.lut_cls.append(i.unsqueeze(0).to(device))
        self.lut_cat = torch.cat(self.lut_cls, 0)                   # [5, 64*64, num]
        nsw = nsw.split("+")
        num, s, w = int(nsw[0]), int(nsw[1]), int(nsw[2])
        self.CLUTs = CLUT(num, device, dim, s, w)
        self.device = device

    def fuse_basis_to_one(self, msb, lsb, TVMN=None):
        mid_results = self.feature(msb, lsb)
        # Map the quantized features (-16 ~ 15.5, step 0.5) to integer indices 0..63.
        mid_results = ((mid_results * 2).long() + 32).reshape(5, 2)
        index = mid_results[:, 0] * 64 + mid_results[:, 1]          # row index into each 64x64 table
        output = self.lut_cat[torch.arange(5), index.squeeze()]
        output = (output - 32) / 4.0                                # undo the int8 encoding
        weights = torch.sum(output, 0).unsqueeze(0).to(self.device)
        D3LUT = self.CLUTs(weights, TVMN)
        return D3LUT

    def forward(self, img_msb, img_lsb, img_org, TVMN=None):
        D3LUT = self.fuse_basis_to_one(img_msb, img_lsb, TVMN)
        img_res = self.TrilinearInterpolation(D3LUT, img_org)
        return {
            "out": img_res + img_org,
            "3DLUT": D3LUT,
        }
The backbone corresponds to FeatLUT; its behavior matches the convolutional backbone exactly, except that table lookups replace the network computation, and its pooled output is quantized to 64 discrete levels (steps of 0.5 in [-16, 15.5], matching the clamp in PointWise.forward_test above). lut_cat corresponds to the original split-FC (SFC) layers; finally, CLUT produces the 3D LUT used in the computation, which enhances the original image.
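FeatLUT itself is not shown in this post, but given how transfer2LUT.py enumerates inputs (channel order third/first/second with L = 16 levels), its lookup plausibly works like the following sketch (a hypothetical reconstruction, not repo code):

import numpy as np
import torch

class FeatLUTSketch:
    # Hypothetical per-pixel feature lookup mirroring the enumeration in transfer2LUT.py.
    def __init__(self):
        self.msb_lut = torch.from_numpy(np.load('./ICELUT/Model_msb_fp32.npy'))  # [16^3, 10, 1, 1]
        self.lsb_lut = torch.from_numpy(np.load('./ICELUT/Model_lsb_fp32.npy'))

    def __call__(self, msb, lsb):
        # msb/lsb: [1, 3, H, W] holding values k/16 with k = 0..15
        def lookup(lut, x):
            q = (x * 16).long()                             # integer codes 0..15
            idx = q[:, 0] * 256 + q[:, 1] * 16 + q[:, 2]    # flat index, 16 levels per channel
            return lut[idx.reshape(-1), :, 0, 0]            # [H*W, 10]

        feat = lookup(self.msb_lut, msb) + lookup(self.lsb_lut, lsb)   # fuse by addition
        feat = feat.mean(dim=0)                                        # global average pool
        return torch.clamp(torch.round(feat * 2) / 2, -16, 15.5)      # 64-level quantization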
4. Summary
That concludes the walkthrough of the core parts of the code. As a lightweight model implemented entirely with LUTs, ICELUT is indeed very convenient to deploy once converted to LUT form.
Thanks for reading; feel free to leave a comment or send a private message to discuss.
If this post helped you, please consider following the author; thanks.