[LUT Technical Series] ICELUT Code Walkthrough

This post walks through the ICELUT code. For an interpretation of the paper itself, see the ICELUT paper walkthrough.

1. Paper Overview

ICELUT implements a fully LUT-based lightweight enhancement scheme. The overall pipeline is shown below.
(Figure: overall ICELUT pipeline)

This design borrows many ideas from lightweight LUT research. The input is split into MSB and LSB parts, which are processed by two separate feature extractors and fused by addition; a grouped FC layer then produces the weights for the basis LUTs, and applying the weighted LUT to the original image yields the final result.
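To make the split concrete, here is a minimal sketch of the 4-bit MSB/LSB decomposition (it mirrors the img_input//16 and img_input%16 operations in inference_demo.py further down; the array img here is just a stand-in):

import numpy as np

img = np.random.randint(0, 256, (4, 4), dtype=np.uint8)  # 8-bit input image
msb = img // 16   # top 4 bits, values in 0..15
lsb = img % 16    # bottom 4 bits, values in 0..15
assert np.array_equal(msb * 16 + lsb, img)  # the split is lossless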

2. Code Structure

The overall structure of the code is as follows:
(Figure: repository structure)

3. Core Code Modules

The model_ICELUT.py file

1. The ICELUT class

This is the overall network implementation. It defines TrilinearInterpolation, the backbone, the split FC classifiers, and CLUT. The forward pass works as follows: the image first goes through the backbone to produce intermediate features; the features are split into 5 groups, each of which predicts its own weights; the five weight vectors are summed, and the merged weights are fed into CLUT to obtain the final 3D LUT.

class ICELUT(nn.Module):
    def __init__(self, nsw, dim=33, backbone='PointWise', *args, **kwargs):
        super().__init__()
        self.TrilinearInterpolation = TrilinearInterpolation()
        self.pre = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.backbone = eval(backbone)()   # instantiate the backbone class by name
        last_channel = self.backbone.last_channel
        # Five parallel heads; each consumes a 2-channel slice of the backbone
        # output and predicts `num` basis-LUT weights (num = int(nsw[:2])).
        self.classifier1 = nn.Sequential(
                nn.Conv2d(last_channel//5, 512, 1, 1),
                nn.Hardswish(inplace=True),
                nn.Conv2d(512, int(nsw[:2]), 1, 1),
        )
        self.classifier2 = nn.Sequential(
                nn.Conv2d(last_channel//5, 512, 1, 1),
                nn.Hardswish(inplace=True),
                nn.Conv2d(512, int(nsw[:2]), 1, 1),
        )
        self.classifier3 = nn.Sequential(
                nn.Conv2d(last_channel//5, 512, 1, 1),
                nn.Hardswish(inplace=True),
                nn.Conv2d(512, int(nsw[:2]), 1, 1),
        )
        self.classifier4 = nn.Sequential(
                nn.Conv2d(last_channel//5, 512, 1, 1),
                nn.Hardswish(inplace=True),
                nn.Conv2d(512, int(nsw[:2]), 1, 1),
        )
        self.classifier5 = nn.Sequential(
                nn.Conv2d(last_channel//5, 512, 1, 1),
                nn.Hardswish(inplace=True),
                nn.Conv2d(512, int(nsw[:2]), 1, 1),
        )
        # nsw is "num+s+w", e.g. '10+05+10': number of basis LUTs plus the two
        # low-rank factors used by CLUT.
        nsw = nsw.split("+")
        num, s, w = int(nsw[0]), int(nsw[1]), int(nsw[2])
        self.CLUTs = CLUT(num, dim, s, w)

    def fuse_basis_to_one(self, img, img_s, TVMN=None):
        mid_results = self.backbone(img, img_s)
        # Split the 10-channel backbone output into five 2-channel groups.
        mid1 = mid_results[:, :2, :, :]
        mid2 = mid_results[:, 2:4, :, :]
        mid3 = mid_results[:, 4:6, :, :]
        mid4 = mid_results[:, 6:8, :, :]
        mid5 = mid_results[:, 8:, :, :]

        weights1 = self.classifier1(mid1)[:, :, 0, 0]  # n, num
        weights2 = self.classifier2(mid2)[:, :, 0, 0]  # n, num
        weights3 = self.classifier3(mid3)[:, :, 0, 0]  # n, num
        weights4 = self.classifier4(mid4)[:, :, 0, 0]  # n, num
        weights5 = self.classifier5(mid5)[:, :, 0, 0]  # n, num
        weights = weights1 + weights2 + weights3 + weights4 + weights5
        D3LUT, tvmn_loss = self.CLUTs(weights, TVMN)
        return D3LUT, tvmn_loss

    def forward(self, img_msb, img_lsb, img_org, TVMN=None):
        D3LUT, tvmn_loss = self.fuse_basis_to_one(img_msb, img_lsb, TVMN)
        img_res = self.TrilinearInterpolation(D3LUT, img_org)
        return {
            "fakes": img_res + img_org,  # the LUT predicts a residual on top of the input
            "3DLUT": D3LUT,
            "tvmn_loss": tvmn_loss,
        }
    
    def fuse_basis_to_one_test(self, img, img_s, TVMN=None):
        mid_results = self.backbone.forward_test(img, img_s)
        mid1 = mid_results[:, :2, :, :]
        mid2 = mid_results[:, 2:4, :, :]
        mid3 = mid_results[:, 4:6, :, :]
        mid4 = mid_results[:, 6:8, :, :]
        mid5 = mid_results[:, 8:, :, :]

        # At test time each head's output is quantized to steps of 0.25 in
        # [-32, 31.75], i.e. an 8-bit signed integer after scaling by 4.
        weights1 = self.classifier1(mid1)[:, :, 0, 0]  # n, num
        weights1 = torch.clamp(torch.round(weights1*4)/4, -32, 31.75)
        weights2 = self.classifier2(mid2)[:, :, 0, 0]
        weights2 = torch.clamp(torch.round(weights2*4)/4, -32, 31.75)
        weights3 = self.classifier3(mid3)[:, :, 0, 0]
        weights3 = torch.clamp(torch.round(weights3*4)/4, -32, 31.75)
        weights4 = self.classifier4(mid4)[:, :, 0, 0]
        weights4 = torch.clamp(torch.round(weights4*4)/4, -32, 31.75)
        weights5 = self.classifier5(mid5)[:, :, 0, 0]
        weights5 = torch.clamp(torch.round(weights5*4)/4, -32, 31.75)
        weights = weights1 + weights2 + weights3 + weights4 + weights5
        D3LUT, tvmn_loss = self.CLUTs(weights, TVMN)
        return D3LUT, tvmn_loss

    def forward_test(self, img_msb, img_lsb, img_org, TVMN=None):
        D3LUT, tvmn_loss = self.fuse_basis_to_one_test(img_msb, img_lsb, TVMN)
        img_res = self.TrilinearInterpolation(D3LUT, img_org)
        return {
            "fakes": img_res + img_org,
            "3DLUT": D3LUT,
            "tvmn_loss": tvmn_loss,
        }
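The clamp-and-round pattern in the test path quantizes each weight to steps of 0.25 in [-32, 31.75], which is exactly what makes the weights storable as int8 after scaling by 4. A quick standalone check (plain PyTorch, independent of the ICELUT code):

import torch

w = torch.randn(4) * 10
w_q = torch.clamp(torch.round(w * 4) / 4, -32, 31.75)  # same op as forward_test
w_int8 = (w_q * 4).to(torch.int8)                      # fits into [-128, 127]
assert torch.allclose(w_q, w_int8.float() / 4)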
2. The PointWise class

This is the pointwise (1x1) convolution backbone. It takes two inputs, the MSB image and the LSB image, and outputs their fused features.

class PointWise(torch.nn.Module):
    def __init__(self, upscale=4):
        super(PointWise, self).__init__()

        self.upscale = upscale

        # MSB branch: ten 1x1 convolutions, widening 3 -> 512 then narrowing to 10.
        self.conv1 = nn.Conv2d(3, 32, 1, stride=1, padding=0, dilation=1)
        self.conv2 = nn.Conv2d(32, 64, 1, stride=1, padding=0, dilation=1)
        self.conv3 = nn.Conv2d(64, 128, 1, stride=1, padding=0, dilation=1)
        self.conv4 = nn.Conv2d(128, 256, 1, stride=1, padding=0, dilation=1)
        self.conv5 = nn.Conv2d(256, 512, 1, stride=1, padding=0, dilation=1)
        self.conv6 = nn.Conv2d(512, 256, 1, stride=1, padding=0, dilation=1)
        self.conv7 = nn.Conv2d(256, 128, 1, stride=1, padding=0, dilation=1)
        self.conv8 = nn.Conv2d(128, 64, 1, stride=1, padding=0, dilation=1)
        self.conv9 = nn.Conv2d(64, 32, 1, stride=1, padding=0, dilation=1)
        self.conv10 = nn.Conv2d(32, 10, 1, stride=1, padding=0, dilation=1)
        self.pool = nn.AdaptiveAvgPool2d(1)

        # LSB branch (suffix _s): identical architecture, separate weights.
        self.conv1_s = nn.Conv2d(3, 32, 1, stride=1, padding=0, dilation=1)
        self.conv2_s = nn.Conv2d(32, 64, 1, stride=1, padding=0, dilation=1)
        self.conv3_s = nn.Conv2d(64, 128, 1, stride=1, padding=0, dilation=1)
        self.conv4_s = nn.Conv2d(128, 256, 1, stride=1, padding=0, dilation=1)
        self.conv5_s = nn.Conv2d(256, 512, 1, stride=1, padding=0, dilation=1)
        self.conv6_s = nn.Conv2d(512, 256, 1, stride=1, padding=0, dilation=1)
        self.conv7_s = nn.Conv2d(256, 128, 1, stride=1, padding=0, dilation=1)
        self.conv8_s = nn.Conv2d(128, 64, 1, stride=1, padding=0, dilation=1)
        self.conv9_s = nn.Conv2d(64, 32, 1, stride=1, padding=0, dilation=1)
        self.conv10_s = nn.Conv2d(32, 10, 1, stride=1, padding=0, dilation=1)
        self.last_channel = 10


    def forward(self, x_in, x_s):
        B, C, H, W = x_in.size()

        x = self.conv1(x_in)
        x = self.conv2(F.relu(x))
        x = self.conv3(F.relu(x))
        x = self.conv4(F.relu(x))
        x = self.conv5(F.relu(x))
        x = self.conv6(F.relu(x))
        x = self.conv7(F.relu(x))
        x = self.conv8(F.relu(x))
        x = self.conv9(F.relu(x))
        x = self.conv10(F.relu(x))
        out_i = x

        x = self.conv1_s(x_s)
        x = self.conv2_s(F.relu(x))
        x = self.conv3_s(F.relu(x))
        x = self.conv4_s(F.relu(x))
        x = self.conv5_s(F.relu(x))
        x = self.conv6_s(F.relu(x))
        x = self.conv7_s(F.relu(x))
        x = self.conv8_s(F.relu(x))
        x = self.conv9_s(F.relu(x))
        x = self.conv10_s(F.relu(x))
        out_s = x
        out = out_s + out_i   # additive fusion of the MSB and LSB branch features
        out = self.pool(out)  # global average pool -> [B, 10, 1, 1]

        return out
    
    def forward_test(self, x_in, x_s):
        B, C, H, W = x_in.size()

        x = self.conv1(x_in)
        x = self.conv2(F.relu(x))
        x = self.conv3(F.relu(x))
        x = self.conv4(F.relu(x))
        x = self.conv5(F.relu(x))
        x = self.conv6(F.relu(x))
        x = self.conv7(F.relu(x))
        x = self.conv8(F.relu(x))
        x = self.conv9(F.relu(x))
        x = self.conv10(F.relu(x))
        out_i = x

        x = self.conv1_s(x_s)
        x = self.conv2_s(F.relu(x))
        x = self.conv3_s(F.relu(x))
        x = self.conv4_s(F.relu(x))
        x = self.conv5_s(F.relu(x))
        x = self.conv6_s(F.relu(x))
        x = self.conv7_s(F.relu(x))
        x = self.conv8_s(F.relu(x))
        x = self.conv9_s(F.relu(x))
        x = self.conv10_s(F.relu(x))
        out_s = x
        out = out_s + out_i
        out = self.pool(out)
        # Quantize the pooled features to 0.5 steps in [-16, 15.5]: 64 levels,
        # which will later index the 64x64 weight LUTs.
        out = torch.clamp(torch.round(out*2)/2, -16, 15.5)

        return out
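The reason each branch can later be replaced by a table is that a stack of pure 1x1 convolutions is a per-pixel function: the 10-dimensional output at any location depends only on that pixel's own (r, g, b) value. A minimal sketch of caching such a branch over a 4-bit-per-channel grid (toy_branch is a stand-in network, not the trained PointWise weights):

import torch
import torch.nn as nn

toy_branch = nn.Sequential(nn.Conv2d(3, 8, 1), nn.ReLU(), nn.Conv2d(8, 10, 1)).eval()

# Enumerate all 16^3 quantized pixel values once and cache the outputs.
vals = torch.arange(16).float() / 16.0  # 4-bit levels, as in ICELUT
grid = torch.stack(torch.meshgrid(vals, vals, vals, indexing='ij'), -1)
pixels = grid.reshape(-1, 3)[:, :, None, None]            # [4096, 3, 1, 1]
with torch.no_grad():
    table = toy_branch(pixels).reshape(16, 16, 16, 10)    # the "channel LUT"

# Table lookup now reproduces the network exactly on quantized inputs.
idx = torch.randint(0, 16, (3,))
with torch.no_grad():
    direct = toy_branch(vals[idx].reshape(1, 3, 1, 1)).reshape(10)
assert torch.allclose(direct, table[idx[0], idx[1], idx[2]], atol=1e-5)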
3. The TrilinearInterpolation class

The trilinear interpolation module. The actual implementation lives in trilinear_cpp/src/trilinear.cpp and trilinear_kernel.cu; the CPU version is standard trilinear interpolation code.

void TriLinearForwardCpu(const float* lut, const float* image, float* output, const int dim, const int shift, const float binsize, const int width, const int height, const int channels)
{
    const int output_size = height * width;

    int index = 0;
    for (index = 0; index < output_size; ++index)
    {
        // Fetch the RGB value of this pixel (planar layout).
        float r = image[index];
        float g = image[index + width * height];
        float b = image[index + width * height * 2];

        // Integer cell coordinates and fractional offsets within the cell.
        int r_id = floor(r / binsize);
        int g_id = floor(g / binsize);
        int b_id = floor(b / binsize);

        float r_d = fmod(r, binsize) / binsize;
        float g_d = fmod(g, binsize) / binsize;
        float b_d = fmod(b, binsize) / binsize;

        // Flat indices of the 8 lattice corners surrounding the pixel value.
        int id000 = r_id + g_id * dim + b_id * dim * dim;
        int id100 = r_id + 1 + g_id * dim + b_id * dim * dim;
        int id010 = r_id + (g_id + 1) * dim + b_id * dim * dim;
        int id110 = r_id + 1 + (g_id + 1) * dim + b_id * dim * dim;
        int id001 = r_id + g_id * dim + (b_id + 1) * dim * dim;
        int id101 = r_id + 1 + g_id * dim + (b_id + 1) * dim * dim;
        int id011 = r_id + (g_id + 1) * dim + (b_id + 1) * dim * dim;
        int id111 = r_id + 1 + (g_id + 1) * dim + (b_id + 1) * dim * dim;

        // Trilinear weights: products of the fractional distances.
        float w000 = (1 - r_d) * (1 - g_d) * (1 - b_d);
        float w100 = r_d * (1 - g_d) * (1 - b_d);
        float w010 = (1 - r_d) * g_d * (1 - b_d);
        float w110 = r_d * g_d * (1 - b_d);
        float w001 = (1 - r_d) * (1 - g_d) * b_d;
        float w101 = r_d * (1 - g_d) * b_d;
        float w011 = (1 - r_d) * g_d * b_d;
        float w111 = r_d * g_d * b_d;

        // Weighted sum over the 8 corners, once per output channel
        // (`shift` steps to the next channel's sub-LUT).
        output[index] = w000 * lut[id000] + w100 * lut[id100] +
                        w010 * lut[id010] + w110 * lut[id110] +
                        w001 * lut[id001] + w101 * lut[id101] +
                        w011 * lut[id011] + w111 * lut[id111];

        output[index + width * height] = w000 * lut[id000 + shift] + w100 * lut[id100 + shift] +
                                         w010 * lut[id010 + shift] + w110 * lut[id110 + shift] +
                                         w001 * lut[id001 + shift] + w101 * lut[id101 + shift] +
                                         w011 * lut[id011 + shift] + w111 * lut[id111 + shift];

        output[index + width * height * 2] = w000 * lut[id000 + shift * 2] + w100 * lut[id100 + shift * 2] +
                                             w010 * lut[id010 + shift * 2] + w110 * lut[id110 + shift * 2] +
                                             w001 * lut[id001 + shift * 2] + w101 * lut[id101 + shift * 2] +
                                             w011 * lut[id011 + shift * 2] + w111 * lut[id111 + shift * 2];
    }
}
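For reference, here is a single-pixel Python transcription of the loop body above (a sketch assuming pixel values in [0, 1] and binsize = 1/(dim-1); like the C code, it omits boundary handling at exactly 1.0):

import numpy as np

def trilinear_one_pixel(lut, r, g, b, dim=33):
    # lut: [3, dim, dim, dim], laid out as lut[channel, b_id, g_id, r_id],
    # matching the flat indexing id = r_id + g_id*dim + b_id*dim*dim above.
    binsize = 1.0 / (dim - 1)
    r_id, g_id, b_id = int(r / binsize), int(g / binsize), int(b / binsize)
    r_d = (r % binsize) / binsize
    g_d = (g % binsize) / binsize
    b_d = (b % binsize) / binsize
    out = np.zeros(3)
    for dr in (0, 1):
        for dg in (0, 1):
            for db in (0, 1):
                w = ((r_d if dr else 1 - r_d) *
                     (g_d if dg else 1 - g_d) *
                     (b_d if db else 1 - b_d))
                out += w * lut[:, b_id + db, g_id + dg, r_id + dr]
    return out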
4. The CLUT class

This was covered in the CLUT code walkthrough; the author reuses that code directly here.

The transfer2LUT.py file

As mentioned in the paper walkthrough, the backbone part is saved in fp32, while the split FC part is saved in int8. The conversion code is as follows:


# ========================== Channel LUT ==============================
# Load the trained backbone weights (checkpoint keys carry a 'backbone.' prefix).
model_G = PointWise().cuda()
lm = torch.load('{}'.format(MODEL_PATH))
model2_dict = model_G.state_dict()
state_dict = {k.replace('backbone.', ''): v for k, v in lm.items() if k.replace('backbone.', '') in model2_dict.keys()}
model2_dict.update(state_dict)
model_G.load_state_dict(model2_dict)

### Extract input-output pairs
with torch.no_grad():
    model_G.eval()
    SAMPLING_INTERVAL = 4
    # 1D input: 16 sample levels per channel (0, 16, ..., 240; last reduced to 239)
    base = torch.arange(0, 256, 2**SAMPLING_INTERVAL)
    base[-1] -= 1
    L = base.size(0)

    # 2D input: all L*L pairs
    first = base.cuda().unsqueeze(1).repeat(1, L).reshape(-1)  # 0 0 0...|16 16 16...|...
    second = base.cuda().repeat(L)                             # 0 16 32...|0 16 32...|...
    onebytwo = torch.stack([first, second], 1)  # [L*L, 2]

    # 3D input: all L^3 RGB combinations (see the small check below)
    third = base.cuda().unsqueeze(1).repeat(1, L*L).reshape(-1)
    onebytwo = onebytwo.repeat(L, 1)
    onebythree = torch.cat([third.unsqueeze(1), onebytwo], 1)    # [L^3, 3]
    input_tensor = onebythree.reshape(-1, 3, 1, 1).float() / 256.0
    # NB: PointWise here is assumed to return the MSB- and LSB-branch outputs
    # separately, so each branch can be cached as its own table.
    out_m, out_l = model_G(input_tensor, input_tensor)
    results_m = out_m.cpu().data.numpy().astype(np.float32)
    results_l = out_l.cpu().data.numpy().astype(np.float32)
    np.save("./ICELUT/Model_lsb_fp32", results_l)
    np.save("./ICELUT/Model_msb_fp32", results_m)



## ======================== Weight LUT ===============================
# Load the trained split FC classifier heads.
model_G = SplitFC().cuda()
lm = torch.load('{}'.format(MODEL_PATH))
model2_dict = model_G.state_dict()
state_dict = {k: v for k, v in lm.items() if k in model2_dict.keys()}
model2_dict.update(state_dict)
model_G.load_state_dict(model2_dict)

with torch.no_grad():
    model_G.eval()
    # The 64 quantized feature levels produced by PointWise.forward_test
    base = torch.arange(-32, 32, 1) / 2.0   # -16 ~ 15.5, step 0.5
    L = base.size(0)
    first = base.cuda().unsqueeze(1).repeat(1, L).reshape(-1)
    second = base.cuda().repeat(L)
    onebytwo = torch.stack([first, second], 1)      # [64*64, 2]
    input_tensor = onebytwo.reshape(-1, 2, 1, 1)
    input_tensor = input_tensor.repeat(1, 5, 1, 1)  # same pair fed to all 5 groups

    out = model_G(input_tensor)
    print(out.shape)
    results_m = out.cpu().data.numpy()
    # Encode: v = 4*w + 32, stored as int8; inference decodes with (v - 32)/4.
    results_m = (results_m * 4 + 32).astype(np.int8)
    np.save("./ICELUT/classifier_int8", results_m)

# ======================== Basis LUT (CLUT) ===========================
# The CLUT parameters are exported as raw matrices, not lookup tables.
ckpt = torch.load('/home/ysd21/ICELUT_camera/FiveK/ICELUT_10+05+10/model0270.pth')
dim, num, s, w = 33, 10, 5, 10
s_Layers = nn.Parameter(torch.rand(dim, s)/5 - 0.1)
w_Layers = nn.Parameter(torch.rand(w, dim*dim)/5 - 0.1)
LUTs = nn.Parameter(torch.zeros(s*num*3, w))
s_Layers.data = ckpt['CLUTs.s_Layers']
w_Layers.data = ckpt['CLUTs.w_Layers']
LUTs.data = ckpt['CLUTs.LUTs']
lut_dict = {'s': s_Layers.data, 'w': w_Layers.data, 'Basis_LUT': LUTs.data}
np.save('./ICELUT/Basis_lut.npy', lut_dict)

Note that for the CLUT part the author saves the raw matrix values. In other words, this stage still requires matrix computation at inference time; it is not a pure table lookup.
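For a sense of scale, here is a shape-level sketch of how those matrices multiply out to the full set of basis LUTs (an assumption on my part: the reconstruction follows the low-rank scheme from the CLUT paper, and the exact axis ordering inside CLUT may differ):

import torch

dim, num, s, w = 33, 10, 5, 10
s_Layers = torch.rand(dim, s)           # ckpt['CLUTs.s_Layers']
w_Layers = torch.rand(w, dim * dim)     # ckpt['CLUTs.w_Layers']
LUTs = torch.rand(s * num * 3, w)       # ckpt['CLUTs.LUTs']

cube = LUTs.mm(w_Layers)                       # [s*num*3, dim*dim]
cube = cube.reshape(num * 3, s, dim * dim)     # expose the rank-s factor
cube = torch.einsum('ds,nsk->ndk', s_Layers, cube)
cube = cube.reshape(num, 3, dim, dim, dim)     # num basis 3D LUTs, 3 channels
print(cube.shape)  # torch.Size([10, 3, 33, 33, 33])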

The inference_demo.py file

The inference code; the model used here is the ICELUT model after conversion to LUTs.

for img in img_ls:

    ## load the image and build the parallel MSB and LSB inputs
    input_path = os.path.join(inp_dir, img)
    img_input = cv2.cvtColor(cv2.imread(input_path, -1), cv2.COLOR_BGR2RGB)
    img_input_msb = TF.to_tensor((img_input//16)/16)  # top 4 bits, scaled to [0, 1)
    img_input_lsb = TF.to_tensor((img_input%16)/16)   # bottom 4 bits, scaled to [0, 1)
    img_input = TF.to_tensor(img_input)

    ## the weight branch runs on a (scale, scale) = (32, 32) downscaled input
    img_input_resize_msb, img_input_resize_lsb = TF.resize(img_input_msb, (scale, scale)), TF.resize(img_input_lsb, (scale, scale))
    # re-quantize after resizing so the values stay on the 16-level grid
    img_input_resize_msb = torch.round(img_input_resize_msb*16) / 16
    img_input_resize_lsb = torch.round(img_input_resize_lsb*16) / 16
    input_msb = img_input_resize_msb.type(torch.FloatTensor).unsqueeze(0).cuda()
    input_lsb = img_input_resize_lsb.type(torch.FloatTensor).unsqueeze(0).cuda()
    input_org = img_input.type(torch.FloatTensor).unsqueeze(0).cuda()

    ## inference
    output = model(input_msb, input_lsb, input_org)

    ## image saving (the original code reuses the name img_ls here; renamed for clarity)
    save_ls = [input_org.squeeze().data, output['out'].squeeze().data]
    if save_ls[0].shape[0] > 3:
        save_ls = [t.permute(2, 0, 1) for t in save_ls]
    save_image(save_ls, join(save_path, f"{img}.jpg"), nrow=len(save_ls))

The class implementing this LUT-converted model looks like this:

class ICELUT(nn.Module): 
    def __init__(self, device, nsw='10+05+10', dim=33, backbone='SRNet'):
        super().__init__()
        self.TrilinearInterpolation = TrilinearInterpolation()
        self.feature = FeatLUT(device)
        # The five 64x64 weight tables exported by transfer2LUT.py, one per group.
        self.classifier = torch.from_numpy(np.load('./ICELUT/classifier_int8.npy'))
        self.cluster = torch.split(self.classifier, 64**2, dim=0)
        self.lut_cls = []
        for i in self.cluster:
            self.lut_cls.append(i.unsqueeze(0).to(device))
        self.lut_cat = torch.cat(self.lut_cls, 0)
            

        nsw = nsw.split("+")
        num, s, w = int(nsw[0]), int(nsw[1]), int(nsw[2])
        self.CLUTs = CLUT(num, device, dim, s, w)
        self.device = device

    def fuse_basis_to_one(self, msb, lsb, TVMN=None):
        mid_results = self.feature(msb, lsb)

        # Map each quantized feature in [-16, 15.5] (step 0.5) to an integer
        # index in 0..63, then address each group's 64x64 weight table.
        mid_results = ((mid_results*2).long()+32).reshape(5, 2)
        index = mid_results[:, 0]*64 + mid_results[:, 1]
        output = self.lut_cat[torch.arange(5), index.squeeze()]
        output = (output - 32) / 4.0   # decode the int8-encoded weights
        weights = torch.sum(output, 0).unsqueeze(0).to(self.device)
        D3LUT = self.CLUTs(weights, TVMN)
        return D3LUT

    def forward(self, img_msb, img_lsb, img_org, TVMN=None):
        D3LUT = self.fuse_basis_to_one(img_msb, img_lsb, TVMN)
        img_res = self.TrilinearInterpolation(D3LUT, img_org)
        return {
            "out": img_res + img_org,
            "3DLUT": D3LUT,
        }

The backbone corresponds to FeatLUT. Functionally it matches the convolutional version; the network computation has simply been replaced by table lookups, with the outputs constrained to a small quantized range (0.5 steps in [-16, 15.5], i.e. 64 levels). lut_cat corresponds to the original split FC (SFC) layers. Finally, CLUT produces the 3D LUT needed for the computation, which is then used to enhance the original image.
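A standalone illustration of the feature-to-index mapping used in fuse_basis_to_one above (plain PyTorch):

import torch

# quantized features from FeatLUT: step 0.5 in [-16, 15.5] -> 64 levels
f = torch.tensor([-16.0, -0.5, 0.0, 15.5])
idx = (f * 2).long() + 32          # integer indices 0..63
print(idx.tolist())                # [0, 31, 32, 63]

# a pair (a, b) of such indices addresses one entry of a 64x64 table
a, b = idx[0], idx[3]
flat = a * 64 + b                  # same computation as in fuse_basis_to_one
print(flat.item())                 # 63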

4. Summary

That wraps up the walkthrough of the core parts of the implementation. As a lightweight model realized entirely with LUTs, ICELUT is indeed very convenient to use once converted to LUT form.


Thanks for reading. Feel free to leave a comment or send me a message to discuss.
If this post helped you, please consider following the blog.
