Pyramid Vision Transformer (PVT)

Pyramid Vision Transformer (PVT) is a Transformer architecture designed for dense prediction tasks without a convolutional backbone. It takes fine-grained 4×4-pixel image patches as input to learn high-resolution representations, uses a progressive shrinking pyramid to reduce the sequence length (and with it the computational cost) as the network deepens, and adds a spatial-reduction attention (SRA) layer to further cut resource consumption. Each stage lowers the resolution with a strided convolution in its patch embedding, while the SRA layer keeps the attention computation efficient.


Paper:
[2102.12122] Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions (arxiv.org)

Key innovations (quoted from the paper)

 our PVT overcomes the difficulties of the conventional Transformer by (1) taking fine-grained image patches (i.e., 4×4 pixels per patch) as input to learn high-resolution representation, which is essential for dense prediction tasks; (2) introducing a progressive shrinking pyramid to reduce the sequence length of Transformer as the network deepens, significantly reducing the computational cost, and (3) adopting a spatial-reduction attention (SRA) layer to further reduce the resource consumption when learning high-resolution features.

PVT has four stages, and each stage consists of a patch embedding module followed by a Transformer encoder; a quick sketch of the per-stage resolutions follows below.
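For orientation, here is a minimal sketch (my own, not from the official repo) of how the spatial size and sequence length evolve across the four stages, assuming a 224×224 input and the output strides of 4, 8, 16 and 32 described in the paper:

```python
# Per-stage feature-map size and token count for a 224x224 input,
# given PVT's output strides of 4, 8, 16, 32.
H = W = 224
for stage, stride in enumerate([4, 8, 16, 32], start=1):
    h, w = H // stride, W // stride
    print(f"stage {stage}: feature map {h}x{w}, sequence length {h * w}")
# stage 1: feature map 56x56, sequence length 3136
# stage 2: feature map 28x28, sequence length 784
# stage 3: feature map 14x14, sequence length 196
# stage 4: feature map 7x7, sequence length 49
```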

1) How is the resolution reduced?

The architecture figure in the paper shows that the input resolution drops after each stage, and this is the effect of the patch embedding. So how exactly is the resolution reduced?

```python
import torch.nn as nn
from timm.models.layers import to_2tuple  # turns an int into an (h, w) tuple


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)

        self.img_size = img_size
        self.patch_size = patch_size
        # assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \
        #     f"img_size {img_size} should be divided by patch_size {patch_size}."
        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        self.num_patches = self.H * self.W
        # Strided convolution: each patch_size x patch_size block is projected to a single
        # embed_dim vector, so the spatial resolution is divided by patch_size.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        B, C, H, W = x.shape

        # (B, C, H, W) -> (B, embed_dim, H/p, W/p) -> (B, H/p * W/p, embed_dim)
        x = self.proj(x).flatten(2).transpose(1, 2)
        x = self.norm(x)
        H, W = H // self.patch_size[0], W // self.patch_size[1]

        return x, (H, W)
```

Key line:

```python
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
```

This is just a plain strided convolution. For example, to go from H×W to H/4×W/4, a stride of 4 is all that is needed.
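As a quick sanity check, here is a minimal usage sketch (my own, not from the repo) that runs the PatchEmbed defined above with patch_size=4 on a 224×224 input and prints the resulting token shape and spatial size:

```python
import torch

# Hypothetical usage of the PatchEmbed class shown above.
patch_embed = PatchEmbed(img_size=224, patch_size=4, in_chans=3, embed_dim=64)
x = torch.randn(1, 3, 224, 224)   # (B, C, H, W)
tokens, (H, W) = patch_embed(x)
print(tokens.shape, H, W)         # torch.Size([1, 3136, 64]) 56 56
```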

Explanation of point (1) from the paper: "taking fine-grained image patches (i.e., 4×4 pixels per patch) as input to learn high-resolution representation, which is essential for dense prediction tasks".

"Fine-grained patches as input" means that each 4×4-pixel block becomes one token, rather than the 16×16 or 32×32-pixel patches used as tokens in ViT.

Explanation of point (2) from the paper: "introducing a progressive shrinking pyramid to reduce the sequence length of Transformer as the network deepens, significantly reducing the computational cost".

After every stage the spatial resolution drops, so the sequence length HW shrinks. According to the standard self-attention complexity formula, Ω(MSA) = 4HWC² + 2(HW)²C, the quadratic term 2(HW)²C falls off rapidly as HW gets smaller, which is where the computational savings come from.

(For a detailed derivation of this formula, see this very thorough CSDN post: Swin-Transformer中MSA和W-MSA模块计算复杂度推导(非常详细,最新)_小周_的博客-CSDN博客.)
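To make this concrete, here is a small illustrative sketch (my own arithmetic, not from the paper) that evaluates the formula above at PVT-like stage resolutions for a 224×224 input; the channel widths are example values, not the exact PVT configuration:

```python
# Illustrative only: evaluate cost = 4*H*W*C**2 + 2*(H*W)**2*C per stage.
# Resolutions follow PVT's output strides (4, 8, 16, 32) for a 224x224 input;
# the channel widths are example values.
stages = [(56, 56, 64), (28, 28, 128), (14, 14, 256), (7, 7, 512)]
for i, (h, w, c) in enumerate(stages, start=1):
    n = h * w                                  # sequence length HW
    cost = 4 * n * c**2 + 2 * n**2 * c
    print(f"stage {i}: N={n:5d}, C={c:3d}, cost ~ {cost:,}")
```

The quadratic term dominates at stage 1, which is exactly where SRA (next section) brings the biggest savings.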

2) What is the spatial-reduction attention (SRA) layer?

First, we need to be clear about what self-attention in a Transformer operates on: what is the sequence length, and what is the feature dimension?

The input to self-attention has the shape (sequence length, feature dimension), which here is (HW, C). SRA reduces the computation of attention by shrinking the sequence length of the keys and values, which is point (3): "adopting a spatial-reduction attention (SRA) layer to further reduce the resource consumption when learning high-resolution features." The sequence length is simply the spatial size HW of the data entering each stage.

```python
import torch.nn as nn


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            # Spatial reduction: a strided conv that shrinks the key/value feature map
            # by a factor of sr_ratio in each spatial dimension.
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

    def forward(self, x, H, W):
        B, N, C = x.shape  # N = H * W
        # Queries keep the full sequence length N.
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        if self.sr_ratio > 1:
            # Reshape tokens back into a feature map, spatially reduce it, then flatten again,
            # so keys/values have only N / sr_ratio**2 tokens.
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            x_ = self.norm(x_)
            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        else:
            kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        # The attention matrix is N x (N / sr_ratio**2) instead of N x N.
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x
```

Key lines:

```python
if sr_ratio > 1:
    self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
```

The code shows that the difference between SRA and ordinary attention is that the sequence length of k and v is no longer the input length HW; it is shrunk according to the stride of the spatial-reduction convolution (by a factor of sr_ratio² overall). For why a shorter key/value sequence reduces resource consumption, see the complexity derivation linked above.
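A minimal sketch (my own, with illustrative stage-1-like numbers: 56×56 tokens, dim 64, sr_ratio 8) to see the effect on shapes:

```python
import torch

# Hypothetical check of the Attention module defined above.
H = W = 56
x = torch.randn(1, H * W, 64)        # (B, N, C) with N = 3136
attn_sra = Attention(dim=64, num_heads=1, sr_ratio=8)
out = attn_sra(x, H, W)
print(out.shape)                     # torch.Size([1, 3136, 64])
# Internally, keys/values are reduced from 3136 tokens to (56 // 8) ** 2 = 49 tokens,
# so the attention matrix is 3136 x 49 instead of 3136 x 3136.
```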

3) A closing thought

Although the resolution drops after every stage, the token granularity within any single stage stays the same. This is presumably why cross-scale attention designs appeared later.

 

### Pyramid Vision Transformer (PVT) overview

Pyramid Vision Transformer (PVT) is a backbone architecture for computer vision tasks, proposed as an alternative to conventional convolutional neural networks (CNNs). The model relies on multi-scale feature extraction to better handle objects of different sizes.

#### Overall structure

Overall, PVT adopts a pyramid-style hierarchical design in which every stage is built from a patch embedding followed by several Transformer encoder blocks. This hierarchical construction lets the model produce feature maps at multiple resolutions, improving its ability to capture fine details in complex scenes.

```python
class PyramidVisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000,
                 embed_dims=[64, 128, 256, 512], depths=[3, 4, 6, 3],
                 num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4],
                 qkv_bias=True, norm_layer=nn.LayerNorm,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 pretrained=None):
        super().__init__()
        self.num_stages = len(depths)

        # Stochastic depth decay rule: one drop-path rate per block across all stages.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0

        for i in range(self.num_stages):
            patch_embed = PatchEmbed(
                img_size=img_size if isinstance(img_size, int) else img_size[i],
                patch_size=patch_size if isinstance(patch_size, int) else patch_size[i],
                in_chans=in_chans if isinstance(in_chans, int) else in_chans[i],
                embed_dim=embed_dims[i])

            block = nn.ModuleList([Block(
                dim=embed_dims[i], num_heads=num_heads[i], mlp_ratio=mlp_ratios[i],
                qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate,
                drop_path=dpr[cur + j])
                for j in range(depths[i])])

            setattr(self, f"patch_embed{i + 1}", patch_embed)
            setattr(self, f"block{i + 1}", block)
            cur += depths[i]

        self.norm = norm_layer(embed_dims[-1])
```

This snippet shows part of the `PyramidVisionTransformer` initialization logic, defining the embedding dimensions, depths, and other hyperparameters for each stage (a sketch of a matching forward pass is given at the end of this post).

#### Applications

Beyond serving as a base for image classification, PVT is widely used in other dense prediction tasks such as semantic segmentation and object detection. These applications benefit from its strong representation learning ability and flexibility, achieving good results without relying on traditional CNN operations.
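The init above omits the forward pass (and the per-stage positional embeddings used in the official model). A minimal sketch of how the stages could be chained, assuming the `PatchEmbed` interface shown earlier and a `Block` whose forward takes `(x, H, W)` like the SRA `Attention` above:

```python
# Hypothetical forward pass for the class above (my reconstruction, not the official code):
# each stage embeds patches, runs its Transformer blocks, then reshapes the tokens back
# into a feature map so that the next stage's PatchEmbed can downsample it again.
def forward_features(self, x):
    B = x.shape[0]
    for i in range(self.num_stages):
        patch_embed = getattr(self, f"patch_embed{i + 1}")
        blocks = getattr(self, f"block{i + 1}")
        x, (H, W) = patch_embed(x)               # tokens: (B, H*W, C_i)
        for blk in blocks:
            x = blk(x, H, W)                     # assumes Block.forward(x, H, W)
        if i != self.num_stages - 1:
            # back to (B, C_i, H, W) so the next PatchEmbed can process it
            x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
    return self.norm(x)
```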