StarNet
论文链接:[2403.19967] Rewrite the Stars
github仓库:GitHub - ma-xu/Rewrite-the-Stars: [CVPR 2024] Rewrite the Stars
CVPR2024 Rewrite the Stars论文揭示了star operation
(元素乘法)在无需加宽网络下,将输入映射到高维非线性特征空间的能力。基于此提出了StarNet
,在紧凑的网络结构和较低的能耗下展示了令人印象深刻的性能和低延迟。
优势 (Advantages)
-
高维和非线性特征变换 (High-Dimensional and Non-Linear Feature Transformation)
- StarNet通过星操作(star operation)实现高维和非线性特征空间的映射,而无需增加计算复杂度。与传统的内核技巧(kernel tricks)类似,星操作能够在低维输入中隐式获得高维特征 (ar5iv)。
- 对于YOLO系列网络,这意味着在保持计算效率的同时,能够获得更丰富和表达力更强的特征表示,这对于目标检测任务中的精细特征捕获尤为重要。
-
高效网络设计 (Efficient Network Design)
- StarNet通过星操作实现了高效的特征表示,无需复杂的网络设计和额外的计算开销。其独特的能力在于能够在低维空间中执行计算,但隐式地考虑极高维的特征 (ar5iv)。
- 这使得StarNet可以作为YOLO系列网络的主干,提供高效的计算和更好的特征表示,有助于在资源受限的环境中实现更高的检测性能。
-
多层次隐式特征扩展 (Multi-Layer Implicit Feature Expansion)
- 通过多层星操作,StarNet能够递归地增加隐式特征维度,接近无限维度。对于具有较大宽度和深度的网络,这种特性可以显著增强特征的表达能力 (ar5iv)。
- 对于YOLO系列网络,这意味着可以通过适当的深度和宽度设计,显著提高特征提取的质量,从而提升目标检测的准确性。
解决的问题 (Problems Addressed)
-
计算复杂度与性能的平衡 (Balance Between Computational Complexity and Performance)
- StarNet通过星操作在保持计算复杂度较低的同时,实现了高维特征空间的映射。这解决了传统高效网络设计中计算复杂度与性能之间的权衡问题 (ar5iv)。
- YOLO系列网络需要在实时性和检测精度之间找到平衡,StarNet的高效特性正好契合这一需求。
-
特征表示的丰富性 (Richness of Feature Representation)
- 传统卷积网络在特征表示的高维非线性变换上存在一定局限性,而StarNet通过星操作实现了更丰富的特征表示 (ar5iv)。
- 在目标检测任务中,特别是对于小目标和复杂场景,丰富的特征表示能够显著提升检测效果,使得YOLO系列网络在这些场景中表现更佳。
-
简化网络设计 (Simplified Network Design)
- StarNet通过星操作提供了一种简化网络设计的方法,无需复杂的特征融合和多分支设计就能实现高效的特征表示 (ar5iv)。
- 对于YOLO系列网络,这意味着可以更容易地设计和实现高效的主干网络,降低设计和调试的复杂度。
在MMYOLO中将StarNet替换成yolov5的主干网络
1. 在上文提到的仓库中下载imagenet/starnet.py
2. 修改starnet.py中的forward函数,并且添加out_dices参数使其能够输出不同stage的特征向量
3. 将class StarNet注册并且在__init__()函数中进行修改
4. 修改配置文件,主要是调整YOLOv5 neck和head的输入输出通道数
修改后的starnet.py
"""
Implementation of Prof-of-Concept Network: StarNet.
We make StarNet as simple as possible [to show the key contribution of element-wise multiplication]:
- like NO layer-scale in network design,
- and NO EMA during training,
- which would improve the performance further.
Created by: Xu Ma (Email: [email protected])
Modified Date: Mar/29/2024
"""
import torch
import torch.nn as nn
from timm.models.layers import DropPath, trunc_normal_
from typing import List, Sequence, Union
# from timm.models.registry import register_model
from mmyolo.registry import MODELS
model_urls = {
"starnet_s1": "https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s1.pth.tar",
"starnet_s2": "https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s2.pth.tar",
"starnet_s3": "https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s3.pth.tar",
"starnet_s4": "https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s4.pth.tar",
}
class ConvBN(torch.nn.Sequential):
def __init__(self, in_planes, out_planes, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, with_bn=True):
super().__init__()
self.add_module('conv', torch.nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding,