图像修复-CVPR2023-Comprehensive and Delicate: An Efficient Transformer for Image Restoration
文章目录
transformer改进,该论文提出了一种高效的图像修复
Transformer
,通过捕捉超像素级的全局依赖性并将其传递到像素级,来提高计算效率。核心方法包括两个神经模块:凝聚注意力神经模块(CA)
通过特征聚合和注意力计算高效地捕捉超像素级的全局依赖,而双适应神经模块(DA)
则通过双重结构将超像素的全局信息适应性地传递到每个像素。实验表明,该方法在性能上与SwinIR
相当,但计算量仅为其约6%。
论文链接:Comprehensive and Delicate: An Efficient Transformer for Image
主要创新点
-
提出了通过超像素级依赖计算和转换来获取像素级全局依赖的方法,克服了现有高效 Transformer 仅捕捉像素级局部依赖的问题。
-
提出了图像修复 Transformer(CODE),由两个神经模块组成:
-
凝聚注意力神经模块(Condensed Attention Neural Block):通过特征聚合、注意力计算和特征恢复高效地捕捉超像素级的全局依赖。
-
双适应神经模块(Dual Adaptive Neural Block):采用创新的双重结构和动态加权方式,将超像素的全局信息分配到每个像素,实现像素级的全局依赖。
-
模型架构图
模型总体由一下部分组成:
- 整体的结构是一个类似于U-net的结构,先是进行下采样,通道数增加,宽高减少,然后进行上采样, C通道数减少,宽高逐渐增加的一个过程。
- 凝聚注意力神经模块(Condensed Attention Neural Block),为捕捉超像素级的全局依赖性,作者提出了一个由特征聚合、注意力计算和特征恢复组成的三步方法,并通过凝聚注意力模块(CA)实现。
- 双适应神经模块(Dual Adaptive Neural Block),凝聚注意力模块(CA)通过三步方法捕获了超像素级的全局依赖,使每个超像素特征包含丰富的全局信息。然而,为实现最终目标——像素级全局依赖性,还需要进一步处理。为此,提出了双适应神经模块(DA),通过创新的双路结构和动态加权方式,将超像素的全局信息转化为像素级全局依赖性。
下面将从源码层面剖析上面两个主要模块
源码剖析
凝聚注意力神经模块
也就是头和尾的过程,具体来说,给定输入特征
F∈H×W×C
,其中H×W
表示空间分辨率,C
为通道数,CA
首先在通道维度上对特征进行聚合,从而获得通道压缩后的特征eF∈H×W×C′
,其中C′
是压缩后的通道数。通过这一过程,冗余特征被减少,信息性较强的特征被保留,从而提高了计算效率。在后续步骤中,CA
会恢复聚合后的特征分布,使得输出的超像素特征与输入的像素特征在分辨率和通道数上保持一致。
源码
从源码中的
nn.Sequential
定义的ch_sp_squeeze
和sp_ch_unsqueeze
就是分别对应Feature Aggregation and Recovery
class CondensedAttentionNeuralBlock(nn.Module):
def __init__(self, embed_dim, squeezes, shuffle, expan_att_chans):
super(CondensedAttentionNeuralBlock, self).__init__()
self.embed_dim = embed_dim
sque_ch_dim = embed_dim // squeezes[0]
shuf_sp_dim = int(sque_ch_dim * (shuffle ** 2))
sque_sp_dim = shuf_sp_dim // squeezes[1]
self.sque_ch_dim = sque_ch_dim
self.shuffle = shuffle
self.shuf_sp_dim = shuf_sp_dim
self.sque_sp_dim = sque_sp_dim
# Feature Aggregation
self.ch_sp_squeeze = nn.Sequential(
nn.Conv2d(embed_dim, sque_ch_dim, 1),
nn.Conv2d(sque_ch_dim, sque_sp_dim, shuffle, shuffle, groups=sque_ch_dim)
)
self.channel_attention = ChannelAttention(sque_sp_dim, sque_ch_dim, expan_att_chans)
self.spatial_attention = SpatialAttention(sque_sp_dim, sque_ch_dim, expan_att_chans)
# Feature Recovery
self.sp_ch_unsqueeze = nn.Sequential(
nn.Conv2d(sque_sp_dim, shuf_sp_dim, 1, groups=sque_ch_dim),
nn.PixelShuffle(shuffle),
nn.Conv2d(sque_ch_dim, embed_dim, 1)
)
- 通道注意力:首先,沿通道维度计算注意力,捕捉通道级的依赖关系。
- 空间注意力:然后,沿空间维度计算注意力,捕捉空间级的依赖关系。
通过顺序执行通道和空间注意力计算,能够在通道和空间域中完全捕获全局依赖性。
为了提高计算效率,论文引入了一种新的通道切片与合并机制,该机制在多头注意力计算前后进行处理,主要步骤如下:
- 切片操作:首先,使用投影操作 Φproj(⋅)将每个通道的特征切片成多个子通道 Xji,扩展通道维度。
- 注意力计算:对切片后的通道进行多头注意力计算,首先将切片的通道划分为查询(Q)、键(K)和值(V),然后通过多头注意力机制进行计算。
- 合并操作:计算完注意力后,使用合并操作 Ψfuse(⋅) 将切片合并回原通道维度。
无论是通道注意力机制,还是空间注意力机制,基于原来都是在计算注意力机制图之后,再对整个特征图计算通道权重和空间位置权重,最终加权到注意力图上,完成注意力机制的计算,可以看到就是下图中到S的过程。
源码
import torch
import torch.nn as nn
import torch.nn.functional as F
# 通道注意力模块
class ChannelAttention(nn.Module):
def __init__(self, embed_dim, num_chans, expan_att_chans):
super(ChannelAttention, self).__init__()
# 扩展注意力通道数
self.expan_att_chans = expan_att_chans
# 计算头数
self.num_heads = int(num_chans * expan_att_chans)
# t是一个可学习的参数,用于调整注意力
self.t = nn.Parameter(torch.ones(1, self.num_heads, 1, 1))
# 使用卷积来分别生成q、k、v
self.group_qkv = nn.Conv2d(embed_dim, embed_dim * expan_att_chans * 3, 1, groups=embed_dim)
# 使用卷积来融合注意力后生成最终的输出
self.group_fus = nn.Conv2d(embed_dim * expan_att_chans, embed_dim, 1, groups=embed_dim)
def forward(self, x):
B, C, H, W = x.size()
# 将输入通过卷积生成q、k、v,并将其分开
q, k, v = self.group_qkv(x).view(B, C, self.expan_att_chans * 3, H, W).transpose(1, 2).contiguous().chunk(3, dim=1)
C_exp = self.expan_att_chans * C
# 将q、k、v进行reshape,方便后续的计算
q = q.view(B, self.num_heads, C_exp // self.num_heads, H * W)
k = k.view(B, self.num_heads, C_exp // self.num_heads, H * W)
v = v.view(B, self.num_heads, C_exp // self.num_heads, H * W)
# 对q和k进行归一化
q, k = F.normalize(q, dim=-1), F.normalize(k, dim=-1)
# 计算注意力得分(q和k的点积)
attn = q @ k.transpose(-2, -1) * self.t
# 对注意力得分进行softmax,得到最终的加权值
x_ = attn.softmax(dim=-1) @ v
# 将结果进行reshape和转置,准备进行最后的卷积
x_ = x_.view(B, self.expan_att_chans, C, H, W).transpose(1, 2).flatten(1, 2).contiguous()
# 通过卷积融合最终的输出
x_ = self.group_fus(x_)
return x_
# 空间注意力模块
class SpatialAttention(nn.Module):
def __init__(self, embed_dim, num_chans, expan_att_chans):
super(SpatialAttention, self).__init__()
# 扩展注意力通道数
self.expan_att_chans = expan_att_chans
# 计算头数
self.num_heads = int(num_chans * expan_att_chans)
# t是一个可学习的参数,用于调整注意力
self.t = nn.Parameter(torch.ones(1, self.num_heads, 1, 1))
# 使用卷积来分别生成q、k、v
self.group_qkv = nn.Conv2d(embed_dim, embed_dim * expan_att_chans * 3, 1, groups=embed_dim)
# 使用卷积来融合注意力后生成最终的输出
self.group_fus = nn.Conv2d(embed_dim * expan_att_chans, embed_dim, 1, groups=embed_dim)
def forward(self, x):
B, C, H, W = x.size()
# 将输入通过卷积生成q、k、v,并将其分开
q, k, v = self.group_qkv(x).view(B, C, self.expan_att_chans * 3, H, W).transpose(1, 2).contiguous().chunk(3, dim=1)
C_exp = self.expan_att_chans * C
# 将q、k、v进行reshape,方便后续的计算
q = q.view(B, self.num_heads, C_exp // self.num_heads, H * W)
k = k.view(B, self.num_heads, C_exp // self.num_heads, H * W)
v = v.view(B, self.num_heads, C_exp // self.num_heads, H * W)
# 对q和k进行归一化
q, k = F.normalize(q, dim=-2), F.normalize(k, dim=-2)
# 计算注意力得分(q和k的点积)
attn = q.transpose(-2, -1) @ k * self.t
# 对注意力得分进行softmax,得到最终的加权值
x_ = attn.softmax(dim=-1) @ v.transpose(-2, -1)
x_ = x_.transpose(-2, -1).contiguous()
# 将结果进行reshape和转置,准备进行最后的卷积
x_ = x_.view(B, self.expan_att_chans, C, H, W).transpose(1, 2).flatten(1, 2).contiguous()
# 通过卷积融合最终的输出
x_ = self.group_fus(x_)
return x_
双适应神经模块
主要的工作流程如下:
- 特征混合:通过点卷积
Pmix(⋅)
将超像素特征进行混合,获取混合后的特征 ̃F。- 双重结构:使用组卷积提取两个子特征切片
X1i
和X2i
,分别捕获短程依赖和像素级特征。- 动态加权:通过
Sigmoid
和GELU
激活函数进行动态加权,将超像素的全局信息融合到像素级别的特征中,生成加权后的特征Yi
。- 特征精炼:使用点卷积
Pfus(⋅)
对加权后的特征进行精炼,得到最终的像素级特征Z
。
源码
输入特征 通过一个 1x1 卷积 和一个 7x7 group convolution 进行特征混合,生成一个
embed_dim * 2
通道的特征图。将混合后的特征图分割为两部分 (
x0
和x1
)。对
x0
应用 GELU 激活,对x1
应用 Sigmoid 激活,并按元素相乘,得到加权后的特征。最后通过 1x1 卷积 对加权后的特征进行整合,得到最终输出。
class DualAdaptiveNeuralBlock(nn.Module):
def __init__(self, embed_dim):
super(DualAdaptiveNeuralBlock, self).__init__()
self.embed_dim = embed_dim
# 定义卷积层
self.group_conv = nn.Sequential(
# 第一层是 1x1 卷积,保持通道数不变
nn.Conv2d(embed_dim, embed_dim, 1),
# 第二层是 7x7 卷积,采用 group convolution,使得每个通道独立卷积
nn.Conv2d(embed_dim, embed_dim * 2, 7, 1, 3, groups=embed_dim) # groups=embed_dim 使得每个通道独立卷积
)
# 后处理卷积,整合输出特征,通道数恢复为 embed_dim
self.post_conv = nn.Conv2d(embed_dim, embed_dim, 1)
def forward(self, x):
# 获取输入 x 的形状,B 是批大小,C 是通道数,H 和 W 是高度和宽度
B, C, H, W = x.size()
# 通过 group_conv 得到两部分输出,x0 和 x1
# group_conv(x) 输出的形状是 (B, embed_dim * 2, H, W),
# 然后将其重塑为 (B, C, 2, H, W),然后通过 chunk(2, dim=2) 分割为两部分
x0, x1 = self.group_conv(x).view(B, C, 2, H, W).chunk(2, dim=2)
# 对 x0 应用 GELU 激活函数,对 x1 应用 Sigmoid 激活函数
# 然后按元素相乘(Hadamard product),得到加权后的特征 x_
x_ = F.gelu(x0.squeeze(2)) * torch.sigmoid(x1.squeeze(2))
# 使用 1x1 卷积对加权后的特征进行整合,输出形状为 (B, embed_dim, H, W)
x_ = self.post_conv(x_)
# 返回最终的特征
return x_