An Introduction to PIV-LiteFlowNet

When extracting velocity fields from PIV particle images with an optical-flow approach, the one network you need to understand is LiteFlowNet. Its architecture splits into two parts, a feature extractor (NetC) and a flow estimator (NetE), covered in turn below.

1. NetC structure:

NetC is responsible for downsampling the input images and extracting their features, producing a six-level feature pyramid (one full-resolution level followed by five stride-2 downsamplings):

import torch


class NetC(torch.nn.Module):
    def __init__(self, num_channels):
        super(NetC, self).__init__()

        # six-module (levels) feature extractor
        self.module_one = torch.nn.Sequential(
            # 'SAME' padding, full resolution
            torch.nn.Conv2d(in_channels=num_channels,
                            out_channels=32,
                            kernel_size=7,
                            stride=1,
                            padding=3),
            torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
        )

        self.module_two = torch.nn.Sequential(
            # first conv (stride 2, downsamples) + relu
            torch.nn.Conv2d(in_channels=32,
                            out_channels=32,
                            kernel_size=3,
                            stride=2,
                            padding=1),
            torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),

            # second conv + relu
            torch.nn.Conv2d(in_channels=32,
                            out_channels=32,
                            kernel_size=3,
                            stride=1,
                            padding=1),
            torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),

            # third conv + relu
            torch.nn.Conv2d(in_channels=32,
                            out_channels=32,
                            kernel_size=3,
                            stride=1,
                            padding=1),
            torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
        )

        self.module_three = torch.nn.Sequential(
            # first conv (stride 2, downsamples) + relu
            torch.nn.Conv2d(in_channels=32,
                            out_channels=64,
                            kernel_size=3,
                            stride=2,
                            padding=1),
            torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),

            # second conv + relu
            torch.nn.Conv2d(in_channels=64,
                            out_channels=64,
                            kernel_size=3,
                            stride=1,
                            padding=1),
            torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
        )

        self.module_four = torch.nn.Sequential(
            # first conv (stride 2, downsamples) + relu
            torch.nn.Conv2d(in_channels=64,
                            out_channels=96,
                            kernel_size=3,
                            stride=2,
                            padding=1),
            torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),

            # second conv + relu
            torch.nn.Conv2d(in_channels=96,
                            out_channels=96,
                            kernel_size=3,
                            stride=1,
                            padding=1),
            torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
        )

        self.module_five = torch.nn.Sequential(
            # 'SAME' padding, stride 2
            torch.nn.Conv2d(in_channels=96,
                            out_channels=128,
                            kernel_size=3,
                            stride=2,
                            padding=1),
            torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
        )

        self.module_six = torch.nn.Sequential(
            # 'SAME' padding, stride 2
            torch.nn.Conv2d(in_channels=128,
                            out_channels=192,
                            kernel_size=3,
                            stride=2,
                            padding=1),
            torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
        )

    def forward(self, image_tensor):
        # generate six level features
        level_one_feat = self.module_one(image_tensor)
        level_two_feat = self.module_two(level_one_feat)
        level_three_feat = self.module_three(level_two_feat)
        level_four_feat = self.module_four(level_three_feat)
        level_five_feat = self.module_five(level_four_feat)
        level_six_feat = self.module_six(level_five_feat)

        return [level_one_feat,
                level_two_feat,
                level_three_feat,
                level_four_feat,
                level_five_feat,
                level_six_feat]

Since this is the downsampling pipeline commonly used in image processing, it is not analyzed in detail here; interested readers can study the code above.
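As a quick check of the pyramid shapes, here is a minimal usage sketch (the single-channel 256x256 input is illustrative): each stride-2 stage halves the spatial size.

import torch

net_c = NetC(num_channels=1)
feats = net_c(torch.randn(1, 1, 256, 256))
for i, f in enumerate(feats, start=1):
    print(f'level {i}: {tuple(f.shape)}')
# level 1: (1, 32, 256, 256)
# level 2: (1, 32, 128, 128)
# level 3: (1, 64, 64, 64)
# level 4: (1, 96, 32, 32)
# level 5: (1, 128, 16, 16)
# level 6: (1, 192, 8, 8)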

2. NetE structure: this part is cleverly designed and embeds quite a few components, so it is introduced in detail together with a walkthrough of the code.

2.1 Matching (the matching unit)

The main role of this part is to compute the correlation between the two images, on a principle similar to cross-correlation. The first step is backwarp: warping image 2 toward image 1 according to the current flow field. Why is image 2 warped toward image 1 rather than image 1 toward image 2? Forward-warping image 1 cannot guarantee that every target location receives a pixel; if the motion is large, holes can appear and blank regions can show up at the image border. Backward warping instead samples image 2 at a source location for every pixel of image 1, so every output pixel is defined.

The code is as follows:

import torch

# cache one normalized sampling grid per flow-tensor size
backwarp_tenGrid = {}

def backwarp(feat_tensor, flow_tensor):
    if str(flow_tensor.size()) not in backwarp_tenGrid:
        # base grid of normalized coordinates in [-1, 1], as grid_sample expects
        horizontal_tensor = torch.linspace(-1.0, 1.0, flow_tensor.shape[3]).view(
            1, 1, 1, flow_tensor.shape[3]).expand(flow_tensor.shape[0], -1, flow_tensor.shape[2], -1)
        vertical_tensor = torch.linspace(-1.0, 1.0, flow_tensor.shape[2]).view(
            1, 1, flow_tensor.shape[2], 1).expand(flow_tensor.shape[0], -1, -1, flow_tensor.shape[3])

        backwarp_tenGrid[str(flow_tensor.size())] = torch.cat(
            [horizontal_tensor, vertical_tensor], 1).cuda()

    # convert the flow from pixel units to normalized units:
    # a displacement of (W - 1) / 2 pixels corresponds to 1.0 in grid coordinates
    flow_tensor = torch.cat([
        flow_tensor[:, 0:1, :, :] / ((feat_tensor.shape[3] - 1.0) / 2.0),
        flow_tensor[:, 1:2, :, :] / ((feat_tensor.shape[2] - 1.0) / 2.0)
    ], 1)

    # sample feature 2 at (base grid + flow): every output pixel reads from image 2
    return torch.nn.functional.grid_sample(
        input=feat_tensor,
        grid=(backwarp_tenGrid[str(flow_tensor.size())] + flow_tensor).permute(0, 2, 3, 1),
        mode='bilinear', padding_mode='zeros', align_corners=True)

The key points here are building a normalized coordinate grid and converting pixel coordinates to normalized coordinates. The division by (size - 1) / 2 appears because grid_sample (with align_corners=True) maps the coordinate -1 to pixel 0 and +1 to pixel size - 1, so a displacement of (size - 1) / 2 pixels equals 1.0 in normalized units.
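A minimal sanity check of backwarp (this assumes a CUDA device, since the cached grid above is created with .cuda()): a zero flow field should give back the input unchanged.

import torch

feat = torch.arange(16.0).view(1, 1, 4, 4).cuda()  # toy 1x1x4x4 feature map
flow = torch.zeros(1, 2, 4, 4).cuda()              # zero displacement everywhere

out = backwarp(feat, flow)
print(torch.allclose(out, feat))  # True: zero flow is the identity warp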

2.1.1 Computing the correlation

if self.upsample_corr is None:
    # compute the correlation between feature 1 and warped feature 2
    corr_tensor = correlation.FunctionCorrelation(tenFirst=feat_tensor_1, tenSecond=feat_tensor_2, intStride=1)
    corr_tensor = torch.nn.functional.leaky_relu(input=corr_tensor, negative_slope=0.1, inplace=False)
else:
    # compute a stride-2 correlation, then upsample it back to full resolution
    corr_tensor = correlation.FunctionCorrelation(tenFirst=feat_tensor_1, tenSecond=feat_tensor_2, intStride=2)
    corr_tensor = torch.nn.functional.leaky_relu(input=corr_tensor, negative_slope=0.1, inplace=False)
    corr_tensor = self.upsample_corr(corr_tensor)

# feed the correlation into the matching CNN
delta_um = self.matching_cnn(corr_tensor)

The correlation here is computed by the correlation package. Right after that, the result is passed through convolutions (the matching CNN), whose main role is to remove high-frequency error and smooth the correlation response, which helps suppress noise. Operations of this kind recur frequently later on.
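For intuition, here is a minimal, unoptimized plain-PyTorch stand-in for what such a correlation layer computes (the real correlation package is a CUDA kernel; naive_correlation and max_disp are illustrative names):

import torch
import torch.nn.functional as F

def naive_correlation(feat1, feat2, max_disp=3):
    # For every displacement (dy, dx) in a (2*max_disp+1)^2 search window,
    # shift feat2 and take the channel-wise mean of its product with feat1.
    b, c, h, w = feat1.shape
    feat2_pad = F.pad(feat2, [max_disp, max_disp, max_disp, max_disp])
    responses = []
    for dy in range(2 * max_disp + 1):
        for dx in range(2 * max_disp + 1):
            shifted = feat2_pad[:, :, dy:dy + h, dx:dx + w]
            responses.append((feat1 * shifted).mean(dim=1, keepdim=True))
    return torch.cat(responses, dim=1)  # B x 49 x H x W for max_disp=3

cost = naive_correlation(torch.randn(1, 32, 16, 16), torch.randn(1, 32, 16, 16))
print(cost.shape)  # torch.Size([1, 49, 16, 16])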

The flow increment delta_um predicted from the matched correlation is then added to the current flow estimate and passed on as the output.

2.2 Subpixel refinement

When the displacement between two adjacent frames is small, subpixel refinement can be used to improve accuracy.

As before, this unit also involves backward warping (backwarp), which was introduced above and is not repeated here.

2.2.1 Building the cost volume

The cost volume integrates information from different sources, such as image differences, the flow field, and deep convolutional feature maps; a convolution is usually applied after the cost volume as well, as in the code below.

# stack feature 1, (warped) feature 2 and the scaled flow into one volume
volume = torch.cat([feat_tensor_1, feat_tensor_2, flow_tensor_scaled], axis=1)
delta_us = self.subpixel_cnn(volume)

return flow_tensor + delta_us

2.3 Regularization module

def forward(self, flow_tensor, image_tensor_1, image_tensor_2, feat_tensor_1, feat_tensor_2):

    # distance between feature 1 and warped feature 2
    flow_tensor_scaled = flow_tensor * self.flow_scale
    feat_tensor_2 = layers.backwarp(feat_tensor_2, flow_tensor_scaled)
    squared_diff_tensor = torch.pow((feat_tensor_1 - feat_tensor_2), 2)
    # sum the difference over the channel dimension
    squared_diff_tensor = torch.sum(squared_diff_tensor, dim=1, keepdim=True)
    # take the square root
    # diff_tensor = torch.sqrt(squared_diff_tensor)
    diff_tensor = squared_diff_tensor

    # construct the volume: feature distance, mean-subtracted flow, and features
    volume_tensor = torch.cat([
        diff_tensor,
        flow_tensor - flow_tensor.view(flow_tensor.shape[0], 2, -1).mean(2, True).view(flow_tensor.shape[0], 2, 1, 1),
        self.feature_net(feat_tensor_1)
    ], axis=1)
    dist_tensor = self.regularization_cnn(volume_tensor)
    dist_tensor = self.dist_net(dist_tensor)
    # softmax-style weights from negative squared distances
    dist_tensor = dist_tensor.pow(2.0).neg()
    dist_tensor = (dist_tensor - dist_tensor.max(1, True)[0]).exp()

    divisor_tensor = dist_tensor.sum(1, True).reciprocal()

    # unfold gathers, for every pixel, its neighborhood of flow values
    dist_tensor_unfold_x = torch.nn.functional.unfold(input=flow_tensor[:, 0:1, :, :],
                                                      kernel_size=self.unfold,
                                                      stride=1,
                                                      padding=int((self.unfold - 1) / 2)).view_as(dist_tensor)

    dist_tensor_unfold_y = torch.nn.functional.unfold(input=flow_tensor[:, 1:2, :, :],
                                                      kernel_size=self.unfold,
                                                      stride=1,
                                                      padding=int((self.unfold - 1) / 2)).view_as(dist_tensor)

    # normalized weighted average of the neighboring flow values
    scale_x_tensor = divisor_tensor * self.scale_x_net(dist_tensor * dist_tensor_unfold_x)
    scale_y_tensor = divisor_tensor * self.scale_y_net(dist_tensor * dist_tensor_unfold_y)

    return torch.cat([scale_x_tensor, scale_y_tensor], axis=1)

A few points in the regularization module deserve attention:

1. Computing the difference between feature 1 and the backward-warped feature 2.

# distance between feature 1 and warped feature 2
flow_tensor_scaled = flow_tensor * self.flow_scale
feat_tensor_2 = layers.backwarp(feat_tensor_2, flow_tensor_scaled)
squared_diff_tensor = torch.pow((feat_tensor_1 - feat_tensor_2), 2)
# sum the difference over the channel dimension
squared_diff_tensor = torch.sum(squared_diff_tensor, dim=1, keepdim=True)
# take the square root
# diff_tensor = torch.sqrt(squared_diff_tensor)
diff_tensor = squared_diff_tensor

2. Building the cost volume.

volume_tensor = torch.cat([
    diff_tensor,
    flow_tensor - flow_tensor.view(flow_tensor.shape[0], 2, -1).mean(2, True).view(flow_tensor.shape[0], 2, 1, 1),
    self.feature_net(feat_tensor_1)
], axis=1)

3. The use of torch.nn.functional.unfold, which extracts each pixel's neighborhood in sliding-window fashion and helps capture local information (see the small demo after the snippet).

dist_tensor_unfold_x = torch.nn.functional.unfold(input=flow_tensor[:, 0:1, :, :],
                                                  kernel_size=self.unfold,
                                                  stride=1,
                                                  padding=int((self.unfold - 1) / 2)).view_as(dist_tensor)
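To make unfold concrete, here is a tiny demo (shapes are illustrative):

import torch
import torch.nn.functional as F

x = torch.arange(16.0).view(1, 1, 4, 4)            # 1x1x4x4 map
patches = F.unfold(x, kernel_size=3, stride=1, padding=1)
print(patches.shape)  # torch.Size([1, 9, 16]): each column holds one 3x3 neighborhood
# .view_as(dist_tensor) then reshapes this to B x 9 x H x W, so every channel
# holds one neighbor of each pixel, aligned with the 9 distance weights.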

4. Feature fusion.

scale_x_tensor = divisor_tensor * self.scale_x_net(dist_tensor * dist_tensor_unfold_x)

Multiplying the distance-based weights (dist_tensor) with the unfolded neighboring flow values, then normalizing with divisor_tensor, fuses the global difference signal with local flow information: each output flow value becomes a weighted average of its neighbors.
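As a toy illustration of this fusion for a single pixel (the numbers are made up, and the explicit sum stands in for what scale_x_net performs in the real module):

import torch

# weights over a 3x3 neighborhood and the corresponding unfolded u-components
w = torch.tensor([0.05, 0.10, 0.05, 0.10, 0.40, 0.10, 0.05, 0.10, 0.05])
u = torch.tensor([1.0, 1.2, 0.9, 1.1, 1.0, 1.3, 0.8, 1.0, 1.1])

u_out = (w * u).sum() / w.sum()   # normalized weighted average, as in scale_x_tensor
print(float(u_out))               # 1.05: smoothed toward the neighborhood consensus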

2.4 Upsampling

This is also a common operation in image processing; its main purpose here is to restore the spatial resolution of the flow field between pyramid levels.

The code is as follows:

self.upsample_flow = torch.nn.ConvTranspose2d(in_channels=2,
                                              out_channels=2,
                                              kernel_size=4,
                                              stride=2,
                                              padding=1,
                                              bias=False,
                                              groups=2)
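A quick shape check (a standalone sketch of the same layer): kernel size 4 with stride 2 and padding 1 exactly doubles the spatial size, and groups=2 upsamples the u and v channels independently.

import torch

up = torch.nn.ConvTranspose2d(2, 2, kernel_size=4, stride=2, padding=1, bias=False, groups=2)
flow = torch.randn(1, 2, 16, 16)
print(up(flow).shape)  # torch.Size([1, 2, 32, 32])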

That concludes this study of, and my takeaways from, this optical-flow network; if you need the complete code, feel free to contact me.
