“不可能学不会的”--DINO讲解+代码

1 研究背景

自监督学习是一种无需标签即可训练深度学习模型的方法,它通过设计一些预训练任务来让模型学习数据的内在结构和表示。这种方法可以充分利用大量无标签的数据,提高模型的泛化能力。在计算机视觉领域,自监督学习逐渐受到重视,因为它能够解决标注数据稀缺的问题,并帮助模型学习更丰富的视觉特征。

​ Vision Transformer(ViT)是一种基于Transformer架构的视觉模型,它借鉴了自然语言处理(NLP)中的成功经验,将序列建模方法应用于图像识别任务。虽然ViT在某些任务上表现出了强大的性能,但它通常需要大量的计算资源和训练数据,且特征没有表现出独特的属性。因此,研究者们开始探索如何通过自监督学习来改进ViT,以克服其计算量大和数据依赖性强的问题。

​ DINO(Distillation with No Labels)是一种针对ViT架构的自监督学习方法,它由Facebook AI Research团队提出。DINO的核心思想是通过自蒸馏的方式,让学生网络模仿教师网络的行为,从而在无标签数据上学习有用的视觉特征。这种方法不仅提高了ViT的泛化能力,还使其能够在多种计算机视觉任务上表现出色。

2 算法流程

算法流程如图所示,一张经过预处理的图片 x x x,经过不同的图片增强得到 x 1 x_1 x1, x 2 x_2 x2 , 分别输入给学生模型 g θ s g\theta_s gθs 和教师模型 g θ t g\theta_t gθt 。两个网络具有相同的架构但参数不同。教师网络的输出通过一个 Batch 的平均值进行 centering 操作,每个网络输出一个 K K K维的特征,并使用 Softmax 进行归一化。然后使用交叉熵损失作为目标函数计算学生模型和教师模型之间的相似度。在教师模型上使用停止梯度算子来阻断梯度的传播,只把梯度传给学生模型使其更新参数。教师模型使用学生模型参数,通过指数移动平均(Exponential Moving Average, EMA) 方法进行更新。

DINO算法流程

3 简易代码实现

该部分代码主要是为读者理解该算法的核心思想,而不是源码复现。

导入所需要的库

import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from torchvision.transforms import Compose,RandomCrop,RandomHorizontalFlip,ToTensor,Normalize
from torch.utils.data import DataLoader
from torchvison.datasets import CIFAR10

timm库是一个用于深度学习的开源库,全称是 “PyTorch Image Models”,它基于Pytorch框架,该库由 Ross Wightman 创建并维护,旨在提供高效且易于使用的图像模型,包括大量预训练的模型和实用工具,我们通过该库调用VIT架构。

我们可以通过以下指令查看timm支持的所有与训练模型。

import timm
# 列出所有可用的模型
model_names = timm.list_models(pretrained=True)
print(model_names)

CIFAR10是一个常用的图像分类数据集,每张图片都是 3×32×32,3通道彩色图片,分辨率为 32×32。它包含了10个不同类别,每个类别有6000张图像,其中5000张用于训练,1000张用于测试。这10个类别分别为:飞机、汽车、鸟类、猫、鹿、狗、青蛙、马、船和卡车。CIFAR-10分类任务是将这些图像正确地分类到它们所属的类别中,本次流程以该代码块举例。

构建VIT的backbone

class VIT(nn.Module):
    def __init__(self,output_dim):
        super(ViT,self).__init__()
        self.vit = timm.create_model('vit_small_patch16_224',pretrained=False,num_classes=output_dim )
    def forward(self,x):
        return self.vit(x)

处理数据集图像数据

def get_dataloader(batch_size):
    #这些变换都是对单张图像进行的操作,它们不会改变数据集中图像的总数。
    transform = Compose(
        [
            #将图像尺寸变换到224x224,默认双线性插值
            Resize(224)
            #这一行添加了一个RandomCrop变换,它会在图像上随机裁剪出一个224x224的区域。padding=4参数意味着在裁剪之前,图像边缘会被填充4个像素的宽度,这有助于在裁剪时保持图像的多样性。
            RandomCrop(224,padding=4)
		   #每次调用都会以一定的概率(默认为 0.5)对图像进行水平翻转。如果翻转发生,图像将沿其垂直轴(从左到右)被镜像;如果翻转不发生,图像将保持不变。
            RandomHorizontalFlip()
            #会将PIL图像或NumPy数组转换为FloatTensor,并且把像素值从[0, 255]缩放到[0.0, 1.0]。
            ToTensor(),
            #Normalize变换用于标准化图像,即减去均值然后除以标准差。这里使用的均值和标准差都是(0.5, 0.5, 0.5),这意味着对于每个通道(红、绿、蓝),都会减去0.5然后除以0.5,将像素值标准化到[-1, 1]的范围内。
            Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
        ]
    )
    dataset = CIFAR10(root = './cifar10',train=True,transform=transform,download =True)
    return DataLoader(dataset,batch_size=batch_size,shuffle=True)   

ema

teacher_param.data = l * teacher_param.data + (1 - l) * student_param.data

损失函数

def H(t,s,C,tps,tpt):
    #冻结教师模型参数
    t = t.detach()
    s = F.softmax(s/tps,dim=1)
    t = F.softmax((t-C)/tpt,dim=1)
    return -(t * torch.log(s)).sum(dim=1).mean()

完整代码

import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from torchvision.transforms import Compose,RandomCrop,RandomHorizontalFlip,ToTensor,Normalize
from torch.utils.data import DataLoader
from torchvison.datasets import CIFAR10

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class VIT(nn.Module):
    def __init__(self,output_dim):
        super(ViT,self).__init__()
        self.vit = timm.create_model('vit_small_patch16_224',pretrained=False,num_classes=output_dim )

    def forward(self,x):
        return self.vit(x)

def get_dataloader(batch_size):
    
    transform = Compose(
        [
            Resize(224)
            RandomCrop(224,padding=4)
            RandomHorizontalFlip()
            ToTensor(),
            Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
        ]
    )
    dataset = CIFAR10(root = './cifar10',train=True,transform=transform,download =True)
    return DataLoader(dataset,batch_size=batch_size,shuffle=True)   

output_dim =128
gs = ViT(output_dim).to(device)
gt = ViT(output_dim).to(device)
gt.load_state_dict(gs.state_dict())
C =torch.zeros(output_dim,device = device)

def H(t,s,C,tps,tpt):
    t = t.detach()
    s = F.softmax(s/tps,dim=1)
    t = F.softmax((t-C)/tpt,dim=1)
    return -(t * torch.log(s)).sum(dim=1).mean()
    
def augement(x):
    return x + 0.1*torch.randn_like(x)

batch_size = 64
#平滑softmax用的参数
tps = 0.1
tpt = 0.07
#l是lamda
l = 0.6
m = 0.5

optimizer = optim.SGD(gs.parameters(),lr =0.03,momentum=0.9,weight_decay=5e-4)
#training
loader = get_dataloader(batch_size)
for x, _ in loader:  # We don't need labels for self-supervised learning
    x = x.to(device) # (B,3,224,224)
    x1, x2 = augment(x), augment(x)  # random views
    s1, s2 = gs(x1), gs(x2)  # student output, (B,128)
    t1, t2 = gt(x1), gt(x2)  # teacher output, (B,128)
    loss = H(t1, s2, C, tps, tpt)/2 + H(t2, s1, C, tps, tpt)/2 # divide by 2 for combined loss
    loss.backward()  # back-propagate
    optimizer.step()  # SGD update for student
    optimizer.zero_grad()  # Clear gradients

    # Teacher and center updates
    with torch.no_grad():
        for teacher_param, student_param in zip(gt.parameters(), gs.parameters()):
            teacher_param.data = l * teacher_param.data + (1 - l) * student_param.data
        C = m * C + (1 - m) * torch.cat([t1, t2]).mean(dim=0)

print("Training complete!")

4 参考资料

论文地址:https//arxiv.org/pdf/2104.14294.pdf

代码地址:https//github.com/facebookresearch/dino

参考资料:https://mp.weixin.qq.com/s/mnDNQYlNcGZEgOphrusrtQ

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值