In today's data-driven computer vision landscape, self-supervised learning is driving real change. Traditional supervised learning depends on large amounts of manually labeled data, which is expensive to obtain and severely limiting when data are scarce. Self-supervised learning instead exploits the structure of the data itself to create labels automatically, with no human annotation, greatly reducing the cost of data acquisition; this makes it especially attractive in fields where labeling is difficult, such as medical imaging.
The feature representations learned during self-supervised pre-training are broadly applicable: they markedly improve generalization, so the model performs well on downstream tasks such as image classification and object detection, and they make the model more robust and stable under shifts in the data distribution.
In short, self-supervised learning is pushing computer vision to a new level and offers a fresh way to tackle difficult visual problems. DINOv2 is currently one of the most widely used self-supervised models; it has attracted a great deal of attention in recent years, and work built on it has appeared in top journals such as Nature and Science. Yet tutorials on DINOv2 are scarce, and there are plenty of pitfalls between downloading the source code and getting training and testing to run. After stepping into many of them myself, I finally got it to train and test successfully on my own server. I hope the walkthrough below helps; if you find any problems, please point them out in the comments. Thank you.
1. Environment Setup
(1) Create and activate a Python environment with Anaconda
conda create -n dinov2_cls python=3.11 -y
conda activate dinov2_cls
(2) Download the GPU builds of torch and torchvision (download pages: torch | torchvision)
# The code was run on Linux; on Windows, download the corresponding Windows wheels instead
torch-2.0.0+cu117-cp311-cp311-linux_x86_64.whl
torchvision-0.15.0+cu117-cp311-cp311-linux_x86_64.whl
Note that the torch and torchvision versions must be compatible with each other. The version correspondence is shown below:
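A few commonly used pairs, listed from memory (a partial excerpt only; check the official torchvision release table for the authoritative list):
torch 2.0.x ↔ torchvision 0.15.x
torch 1.13.x ↔ torchvision 0.14.x
torch 1.12.x ↔ torchvision 0.13.x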
Sometimes pip install fails with timeout errors. In that case, switch to a domestic mirror such as the Tsinghua mirror. The commands below set it as the default index, after which pip install runs much faster.
python -m pip install --upgrade pip
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
Install torch, torchvision, and the remaining packages
pip install torch-2.0.0+cu117-cp311-cp311-linux_x86_64.whl
pip install torchvision-0.15.0+cu117-cp311-cp311-linux_x86_64.whl
pip install pandas
pip install seaborn
pip install scikit-learn
If any other package is missing, simply pip install it as the error messages come up. After installing, it is worth checking that the GPU build is actually being used (see the quick check below).
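A quick sanity check that the GPU builds were picked up (on a CUDA machine this should print the +cu117 versions and True):
import torch, torchvision
print(torch.__version__, torchvision.__version__)  # e.g. 2.0.0+cu117 0.15.0+cu117
print(torch.cuda.is_available())                    # should be True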
2. Downloading the Classification Dataset
The dataset used in the code is the Intel Image Dataset from Kaggle (download page: Intel Image Dataset). After downloading it, the data still needs to be preprocessed, i.e. split into train, val, and test sets. You can write your own script to split it randomly (a minimal sketch is given after the download link below); here is a link to the data I have already split.
File shared via Baidu Netdisk: dinov2_format_data.zip
链接: https://pan.baidu.com/s/1oekwnQBs04qqibHKisCoHw 提取码: r8ys
# The archive does not contain an examples folder; you need to prepare that yourself
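If you prefer to split the data yourself, here is a minimal sketch (the source path and split ratios are assumptions; it expects src_dir/<class_name>/*.jpg and writes root/{train,val,test}/<class_name>/):
import random, shutil
from pathlib import Path

def split_dataset(src_dir, dst_dir, ratios=(0.8, 0.1, 0.1), seed=42):
    random.seed(seed)
    src_dir, dst_dir = Path(src_dir), Path(dst_dir)
    # One sub-folder per class in the source directory
    for class_dir in sorted(d for d in src_dir.iterdir() if d.is_dir()):
        images = sorted(class_dir.glob('*.jpg'))
        random.shuffle(images)
        n_train = int(len(images) * ratios[0])
        n_val = int(len(images) * ratios[1])
        splits = {'train': images[:n_train],
                  'val': images[n_train:n_train + n_val],
                  'test': images[n_train + n_val:]}
        for split, files in splits.items():
            out_dir = dst_dir / split / class_dir.name
            out_dir.mkdir(parents=True, exist_ok=True)
            for f in files:
                shutil.copy(f, out_dir / f.name)

split_dataset('seg_train/seg_train', '../dinov2_format_data/classification/root')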
3. Downstream Task: Classification
Import the required packages
import os
import cv2
import torch
import numpy as np
from tqdm import tqdm
from PIL import Image
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import classification_report, confusion_matrix
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets, models, transforms
Load the model (an internet connection is required, otherwise the download will fail; see the offline workaround after the loading code below)
if torch.cuda.is_available():
device = torch.device('cuda')
elif torch.backends.mps.is_available():
device = torch.device('mps')
else:
device = torch.device('cpu')
dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
dinov2_vits14 = dinov2_vits14.to(device)
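If the server cannot reach GitHub, the first torch.hub.load call fails. A workaround sketch, assuming the DINOv2 repo and the dinov2_vits14 checkpoint were downloaded once on a machine with access and copied into torch.hub's cache directory (the paths below are the defaults and may differ on your system):
import os
# Assumption: the cached repo and checkpoint already exist under ~/.cache/torch/hub
local_repo = os.path.expanduser('~/.cache/torch/hub/facebookresearch_dinov2_main')
dinov2_vits14 = torch.hub.load(local_repo, 'dinov2_vits14', source='local')
dinov2_vits14 = dinov2_vits14.to(device)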
Visualization parameter setup
# Visualization
patch_size = dinov2_vits14.patch_size
patch_h = 518 // patch_size  # the CenterCrop(518) input gives 518 / 14 = 37 patches per side
patch_w = 518 // patch_size
feat_dim = 384
transform1 = transforms.Compose([
transforms.Resize(520),
transforms.CenterCrop(518),
transforms.ToTensor(),
transforms.Normalize(mean=0.5, std=0.2)
])
total_features = []
EXAMPLE_PATH = '../dinov2_format_data/classification'
# Prepare an examples folder containing a few images, e.g. dogs and cats
folder_path = f'{EXAMPLE_PATH}/examples/'
Extract features from the example images and visualize them
with torch.no_grad():
list_img_path = os.listdir(folder_path)[0:4]
for img_path in list_img_path:
img_path = os.path.join(folder_path, img_path)
img = Image.open(img_path).convert('RGB')
img_t = transform1(img).to(device)
features_dict = dinov2_vits14.forward_features(img_t.unsqueeze(0))
features = features_dict['x_norm_patchtokens']
total_features.append(features)
total_features = torch.cat(total_features, dim=0)
# Use PCA to separate the background
total_features = total_features.reshape(4 * patch_h * patch_w, feat_dim)
total_features = total_features.cpu()
pca = PCA(n_components=3)
scaler = MinMaxScaler(clip=True)
pca.fit(total_features)
pca_features = pca.transform(total_features)
plt.subplot(2, 2, 1)
plt.hist(pca_features[:, 0])
plt.subplot(2, 2, 2)
plt.hist(pca_features[:, 1])
plt.subplot(2, 2, 3)
plt.hist(pca_features[:, 2])
plt.show()
plt.close()
# Min-max scaling
pca_features[:, 0] = (pca_features[:, 0] - pca_features[:, 0].min()) / \
(pca_features[:, 0].max() - pca_features[:, 0].min())
# pca_features = sklearn.preprocessing.minmax_scale(pca_features)
for i in range(4):
plt.subplot(2, 2, i+1)
plt.imshow(pca_features[i*patch_h*patch_w : (i+1)*patch_h*patch_w, 0].reshape(patch_h, patch_w))
plt.show()
# Separate foreground from background
pca_features_bg = pca_features[:, 0] < 0.5
pca_features_fg = ~pca_features_bg
for i in range(4):
plt.subplot(2, 2, i+1)
plt.imshow(pca_features_bg[i * patch_h * patch_w: (i+1) * patch_h * patch_w].reshape(patch_h, patch_w))
plt.show()
pca.fit(total_features[pca_features_fg])
pca_features_left = pca.transform(total_features[pca_features_fg])
for i in range(3):
# min_max scaling
pca_features_left[:, i] = (pca_features_left[:, i] - pca_features_left[:, i].min()) / (pca_features_left[:, i].max() - pca_features_left[:, i].min())
pca_features_rgb = pca_features.copy()
pca_features_rgb[pca_features_bg] = 0
pca_features_rgb[pca_features_fg] = pca_features_left
pca_features_rgb = pca_features_rgb.reshape(4, patch_h, patch_w, 3)
for i in range(4):
plt.subplot(2, 2, i+1)
plt.imshow(pca_features_rgb[i])
plt.show()
Visualization results:
(1) The images in the examples folder
(2) Histograms of the principal components, showing the feature distribution
(3) Grayscale maps of the first principal component, which visualize the most salient semantic dimension of the feature space and localize the salient regions (object vs. background). As you can see, DINOv2 extracts the key features effectively.
(4) Binary background masks, an unsupervised foreground-background separation that confirms the model's understanding of scene structure.
(5) Pseudo-color foreground maps, where the color coding reveals the semantic structure within the foreground object and shows how well different parts are distinguished.
Fine-tuning DINOv2 on a custom dataset
# Prepare the data
data_transforms = {
'train': transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'val': transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
}
ROOT_PATH = '../dinov2_format_data/classification/root/'
image_datasets = {
x: datasets.ImageFolder(os.path.join(ROOT_PATH, x), data_transforms[x])
for x in ['train', 'val']
}
batch_size = 64
num_workers = 4
data_loaders = {x: DataLoader(image_datasets[x], shuffle=True, batch_size=batch_size, num_workers=num_workers)
for x in ['train', 'val']
}
class_names = image_datasets['train'].classes
print(class_names)
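# With this folder layout, ImageFolder infers the class names from the sub-folder names;
# for the Intel dataset the print above should show something like
# ['buildings', 'forest', 'glacier', 'mountain', 'sea', 'street']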
# Define the model with a classification head
class DinoVisionTransformerClassifier(nn.Module):
def __init__(self):
super(DinoVisionTransformerClassifier, self).__init__()
self.transformer = dinov2_vits14
self.classifier = nn.Sequential(
nn.Linear(384, 256),
nn.ReLU(),
nn.Linear(256, len(class_names))
)
def forward(self, x):
x = self.transformer(x)
x = self.transformer.norm(x)
x = self.classifier(x)
return x
model = DinoVisionTransformerClassifier()
model = model.to(device)
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.000001)
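# Optional alternative (not part of the original tutorial): with lr=1e-6 the backbone barely moves,
# so you can instead freeze DINOv2 and train only the classification head, which saves memory
# and time per epoch, e.g.:
# for p in model.transformer.parameters():
#     p.requires_grad = False
# optimizer = optim.Adam(model.classifier.parameters(), lr=1e-4)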
# Train
num_epoch = 3
for epoch in range(num_epoch):
train_acc = 0
train_loss = 0
loop = tqdm(data_loaders['train'])
for idx, (features, labels) in enumerate(loop):
features = features.to(device)
labels = labels.to(device)
optimizer.zero_grad()
outputs = model(features)
loss = criterion(outputs, labels)
predictions = outputs.argmax(dim=1)
correct = (predictions == labels).sum().item()
accuracy = correct / labels.size(0)  # use the actual batch size; the last batch may be smaller
loss.backward()
optimizer.step()
loop.set_description(f"Epoch [{epoch}/{num_epoch}]")
loop.set_postfix(loss=loss.item(), acc=accuracy)
correct = 0
total = 0
val_predicted = []
val_labels = []
with torch.no_grad():
for features, labels in data_loaders["val"]:
features = features.to(device)
labels = labels.to(device)
outputs = model(features)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted.to(device) == labels).sum().item()
val_labels += (labels.cpu().numpy().tolist())
val_predicted += (predicted.cpu().numpy().tolist())
print(f'Accuracy of the network on the {total} val images: {100 * correct / total:.2f} %')
print(classification_report(val_labels, val_predicted, target_names=class_names))
cm = confusion_matrix(val_labels, val_predicted)
df_cm = pd.DataFrame(
cm,
index = class_names,
columns = class_names
)
def show_confusion_matrix(confusion_matrix):
hmap = sns.heatmap(confusion_matrix, annot=True, fmt="d", cmap="Blues")
plt.ylabel("Surface Ground Truth")
plt.xlabel("Predicted Surface")
plt.show()
show_confusion_matrix(df_cm)
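The script above never writes the fine-tuned weights to disk. A minimal save/reload sketch (the file name is an assumption):
torch.save(model.state_dict(), 'dinov2_cls_intel.pth')
# Later, to reuse the fine-tuned classifier:
model = DinoVisionTransformerClassifier().to(device)
model.load_state_dict(torch.load('dinov2_cls_intel.pth', map_location=device))
model.eval()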
Training results: after only 3 epochs of fine-tuning, the accuracy on the validation set already reaches 93%.
Confusion matrix
The code above trains and tests on a single machine with a single GPU. The single-machine, 8-GPU version is given below:
import os
import cv2
import torch
import numpy as np
from tqdm import tqdm
from PIL import Image
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import classification_report, confusion_matrix
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import torch.multiprocessing as mp
import torch.distributed as dist
import torchvision
from torchvision import datasets, models, transforms
# Set the environment variables for the rendezvous
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12345'
# Enable verbose distributed debugging information
os.environ['TORCH_DISTRIBUTED_DEBUG'] = 'INFO'
def setup(rank, world_size):
torch.cuda.set_device(rank)
dist.init_process_group(
"nccl",
rank=rank,
world_size=world_size
)
torch.distributed.barrier()
def cleanup():
dist.destroy_process_group()
def main_worker(rank, world_size, args):
# Bind this process to its GPU and join the process group
setup(rank, world_size)
# Set random seeds so all ranks stay consistent
torch.manual_seed(args.seed)
np.random.seed(args.seed)
# Data loading - print info only on the main process
if rank == 0:
print(f"Training with {world_size} GPUs")
# Data transforms
data_transforms = {
'train': transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'val': transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
}
ROOT_PATH = args.data_path
image_datasets = {
x: datasets.ImageFolder(os.path.join(ROOT_PATH, x), data_transforms[x])
for x in ['train', 'val']
}
if rank == 0:
print(f"Training set size: {len(image_datasets['train'])}")
print(f"Validation set size: {len(image_datasets['val'])}")
# Create distributed samplers
train_sampler = DistributedSampler(
image_datasets['train'],
num_replicas=world_size,
rank=rank,
shuffle=True,
seed=args.seed
)
val_sampler = DistributedSampler(
image_datasets['val'],
num_replicas=world_size,
rank=rank,
shuffle=False
)
# Convert the global batch size to a per-GPU batch size
per_gpu_batch_size = args.batch_size // world_size
data_loaders = {
'train': DataLoader(
image_datasets['train'],
batch_size=per_gpu_batch_size,
num_workers=args.num_workers,
pin_memory=True,
sampler=train_sampler
),
'val': DataLoader(
image_datasets['val'],
batch_size=per_gpu_batch_size,
num_workers=args.num_workers,
pin_memory=True,
sampler=val_sampler
)
}
class_names = image_datasets['train'].classes
if rank == 0:
print(f"Class names: {class_names}")
# Load the pretrained backbone
dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
# Define the classification model
class DinoVisionTransformerClassifier(nn.Module):
def __init__(self):
super(DinoVisionTransformerClassifier, self).__init__()
self.transformer = dinov2_vits14
self.classifier = nn.Sequential(
nn.Linear(384, 256),
nn.ReLU(),
nn.Linear(256, len(class_names))
)
def forward(self, x):
x = self.transformer(x)
# Make sure only the class token is used
if isinstance(x, dict):
x = x['x_norm_clstoken']
x = self.classifier(x)
return x
model = DinoVisionTransformerClassifier().to(rank)
# Wrap the model for multi-GPU training; find_unused_parameters=True is required here
model = nn.parallel.DistributedDataParallel(
model,
device_ids=[rank],
find_unused_parameters=True
)
# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=args.lr)
# Learning-rate scheduler
scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
# Training loop
best_acc = 0.0
for epoch in range(args.epochs):
# Set the sampler epoch so that shuffling differs between epochs
train_sampler.set_epoch(epoch)
# Training phase
model.train()
train_loss = 0.0
train_correct = 0
train_total = 0
if rank == 0:
loop = tqdm(data_loaders['train'], desc=f"Epoch [{epoch}/{args.epochs}]")
else:
loop = data_loaders['train']
for idx, (features, labels) in enumerate(loop):
features = features.to(rank, non_blocking=True)
labels = labels.to(rank, non_blocking=True)
optimizer.zero_grad()
outputs = model(features)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
train_loss += loss.item()
_, predicted = outputs.max(1)
train_total += labels.size(0)
train_correct += predicted.eq(labels).sum().item()
if rank == 0:
loop.set_postfix(loss=train_loss/(idx+1), acc=100.*train_correct/train_total)
# Validation phase
model.eval()
val_loss = 0.0
val_correct = 0
val_total = 0
val_predicted = []
val_labels_list = []
with torch.no_grad():
for features, labels in data_loaders['val']:
features = features.to(rank, non_blocking=True)
labels = labels.to(rank, non_blocking=True)
outputs = model(features)
loss = criterion(outputs, labels)
val_loss += loss.item()
_, predicted = outputs.max(1)
val_total += labels.size(0)
val_correct += predicted.eq(labels).sum().item()
# Collect predictions for later analysis
val_predicted.append(predicted.cpu())
val_labels_list.append(labels.cpu())
# Aggregate results from all GPUs
val_loss = torch.tensor(val_loss).to(rank)
val_correct = torch.tensor(val_correct).to(rank)
val_total = torch.tensor(val_total).to(rank)
dist.all_reduce(val_loss, op=dist.ReduceOp.SUM)
dist.all_reduce(val_correct, op=dist.ReduceOp.SUM)
dist.all_reduce(val_total, op=dist.ReduceOp.SUM)
val_loss = val_loss.item() / (len(data_loaders['val']) * world_size)  # average over batches from all ranks
val_acc = 100. * val_correct.item() / val_total.item()
# Print results only on the main process
if rank == 0:
print(f"Epoch [{epoch}/{args.epochs}] val loss: {val_loss:.4f}, val accuracy: {val_acc:.2f}%")
# Save the best model
if val_acc > best_acc:
best_acc = val_acc
torch.save(model.module.state_dict(), 'best_model.pth')
print(f"Model saved, accuracy: {best_acc:.2f}%")
# Update the learning rate
scheduler.step()
# Final evaluation on the main process
if rank == 0:
print(f"Best validation accuracy: {best_acc:.2f}%")
# Load the best model
model.module.load_state_dict(torch.load('best_model.pth'))
model.eval()
# Collect all predictions
all_predicted = []
all_labels = []
with torch.no_grad():
for features, labels in data_loaders['val']:
features = features.to(rank)
labels = labels.to(rank)
outputs = model(features)
_, predicted = outputs.max(1)
all_predicted.append(predicted.cpu())
all_labels.append(labels.cpu())
all_predicted = torch.cat(all_predicted).numpy()
all_labels = torch.cat(all_labels).numpy()
# Print the classification report
print(classification_report(all_labels, all_predicted, target_names=class_names))
# Plot the confusion matrix
cm = confusion_matrix(all_labels, all_predicted)
df_cm = pd.DataFrame(
cm,
index=class_names,
columns=class_names
)
def show_confusion_matrix(confusion_matrix):
hmap = sns.heatmap(confusion_matrix, annot=True, fmt="d", cmap="Blues")
plt.ylabel("Ground truth")
plt.xlabel("Predicted")
plt.savefig('confusion_matrix.png')
show_confusion_matrix(df_cm)
# Clean up the process group
cleanup()
def main():
import argparse
parser = argparse.ArgumentParser(description='DINOv2 multi-GPU training')
parser.add_argument('--data-path', default='../dinov2_format_data/classification/root/', help='dataset path')
parser.add_argument('--batch-size', default=64, type=int, help='total batch size across all GPUs')
parser.add_argument('--epochs', default=3, type=int, help='number of training epochs')
parser.add_argument('--lr', default=0.00001, type=float, help='learning rate')
parser.add_argument('--num-workers', default=4, type=int, help='number of data-loading workers')
parser.add_argument('--seed', default=42, type=int, help='random seed')
args = parser.parse_args()
# Determine the number of GPUs
world_size = torch.cuda.device_count()
print(f"Found {world_size} GPUs")
# Spawn one training process per GPU
mp.spawn(
main_worker,
args=(world_size, args),
nprocs=world_size,
join=True
)
if __name__ == "__main__":
main()
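Because the script spawns one worker per GPU itself via mp.spawn, it is launched as an ordinary Python program rather than with torchrun (the file name below is an assumption):
python train_cls_ddp.py --data-path ../dinov2_format_data/classification/root/ --batch-size 64 --epochs 3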
GPU status during training
4. Downstream Task: Semantic Segmentation
This program uses the VOC2012 dataset for training, validation, and testing; you can download it yourself (download page: voc2012). Create data and output folders in the current directory and put the downloaded VOC2012 dataset into the data folder; the output folder is where the final results are saved by default. The final data layout is as follows:
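A rough sketch of the expected layout (an assumption based on the standard VOCdevkit structure; check the repository's dataloader for the exact paths it reads):
.
├── data
│   └── VOCdevkit
│       └── VOC2012
│           ├── JPEGImages
│           ├── SegmentationClass
│           └── ImageSets
│               └── Segmentation
└── output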
Training code
import json
import logging
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from dino_finetune import (
DINOV2EncoderLoRA,
get_dataloader,
visualize_overlay,
compute_iou_metric,
)
def validate_epoch(
dino_lora: nn.Module,
val_loader: DataLoader,
criterion: nn.CrossEntropyLoss,
metrics: dict,
) -> None:
val_loss = 0.0
val_iou = 0.0
dino_lora.eval()
with torch.no_grad():
for images, masks in val_loader:
images = images.float().cuda()
masks = masks.long().cuda()
logits = dino_lora(images)
loss = criterion(logits, masks)
val_loss += loss.item()
y_hat = torch.sigmoid(logits)
iou_score = compute_iou_metric(y_hat, masks, ignore_index=255)
val_iou += iou_score.item()
metrics["val_loss"].append(val_loss / len(val_loader))
metrics["val_iou"].append(val_iou / len(val_loader))
def finetune_dino(config: argparse.Namespace, encoder: nn.Module):
dino_lora = DINOV2EncoderLoRA(
encoder=encoder,
r=config.r,
emb_dim=config.emb_dim,
img_dim=config.img_dim,
n_classes=config.n_classes,
use_lora=config.use_lora,
use_fpn=config.use_fpn,
).cuda()
if config.lora_weights:
dino_lora.load_parameters(config.lora_weights)
train_loader, val_loader = get_dataloader(
config.dataset, img_dim=config.img_dim, batch_size=config.batch_size
)
# Finetuning for segmentation
criterion = nn.CrossEntropyLoss(ignore_index=255).cuda()
optimizer = optim.AdamW(dino_lora.parameters(), lr=config.lr)
# Log training and validation metrics
metrics = {
"train_loss": [],
"val_loss": [],
"val_iou": [],
}
for epoch in range(config.epochs):
dino_lora.train()
for images, masks in train_loader:
images = images.float().cuda()
masks = masks.long().cuda()
optimizer.zero_grad()
logits = dino_lora(images)
loss = criterion(logits, masks)
loss.backward()
optimizer.step()
if epoch % 5 == 0:
y_hat = torch.sigmoid(logits)
validate_epoch(dino_lora, val_loader, criterion, metrics)
dino_lora.save_parameters(f"output/{config.exp_name}.pt")
if config.debug:
# Visualize some of the batch and write to files when debugging
visualize_overlay(
images, y_hat, config.n_classes, filename=f"viz_{epoch}"
)
logging.info(
f"Epoch: {epoch} - val IoU: {metrics['val_iou'][-1]} "
f"- val loss {metrics['val_loss'][-1]}"
)
# Log the final metric values and save the model
# Saves only the LoRA parameters and the classifier
dino_lora.save_parameters(f"output/{config.exp_name}.pt")
with open(f"output/{config.exp_name}_metrics.json", "w") as f:
json.dump(metrics, f)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Experiment Configuration")
parser.add_argument(
"--exp_name",
type=str,
default="small_lora",
help="Experiment name",
)
parser.add_argument(
"--debug",
action="store_true",
help="Debug by visualizing some of the outputs to file for a sanity check",
)
parser.add_argument(
"--r",
type=int,
default=3,
help="LoRA rank parameter r",
)
parser.add_argument(
"--size",
type=str,
default="small",
help="DINOv2 backbone parameter [small, base, large, giant]",
)
parser.add_argument(
"--use_lora",
action="store_true",
help="Use Low-Rank Adaptation (LoRA) to finetune",
)
parser.add_argument(
"--use_fpn",
action="store_true",
help="Use the FPN decoder for finetuning",
)
parser.add_argument(
"--img_dim",
type=int,
nargs=2,
default=(224, 224),
help="Image dimensions (height width)",
)
parser.add_argument(
"--lora_weights",
type=str,
default=None,
help="Load the LoRA weights from file location",
)
# Training parameters
parser.add_argument(
"--dataset",
type=str,
default="voc",
help="The dataset to finetune on, either `voc` or `ade20k`",
)
parser.add_argument(
"--epochs",
type=int,
default=5,
help="Number of training epochs",
)
parser.add_argument(
"--lr",
type=float,
default=3e-4,
help="Learning rate",
)
parser.add_argument(
"--batch_size",
type=int,
default=32,
help="Finetuning batch size",
)
config = parser.parse_args()
# All backbone sizes and configurations
backbones = {
"small": "vits14_reg",
"base": "vitb14_reg",
"large": "vitl14_reg",
"giant": "vitg14_reg",
}
embedding_dims = {
"small": 384,
"base": 768,
"large": 1024,
"giant": 1536,
}
config.emb_dim = embedding_dims[config.size]
# Dataset
dataset_classes = {
"voc": 21,
"ade20k": 150,
}
config.n_classes = dataset_classes[config.dataset]
encoder = torch.hub.load(
repo_or_dir="facebookresearch/dinov2",
model=f"dinov2_{backbones[config.size]}",
).cuda()
finetune_dino(config, encoder)
The training command is shown below; you can set the parameters as you like:
python main.py --exp_name base_voc --dataset voc --size base --use_lora --img_dim 308 308 --epochs 50 --use_fpn
After training finishes, a .pt file and a .json file are written to the output folder. The .pt file can be loaded for testing, while the .json file stores how the metrics evolved during training.
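A minimal sketch for loading the saved weights afterwards (it mirrors the model construction in the training script for the base_voc run above and reuses that script's load_parameters method; everything else, such as file names, is an assumption):
import json
import torch
from dino_finetune import DINOV2EncoderLoRA

# Rebuild the encoder + LoRA wrapper exactly as during training, then load the saved weights
encoder = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14_reg").cuda()
dino_lora = DINOV2EncoderLoRA(encoder=encoder, r=3, emb_dim=768, img_dim=(308, 308),
                              n_classes=21, use_lora=True, use_fpn=True).cuda()
dino_lora.load_parameters("output/base_voc.pt")
dino_lora.eval()  # predictions then come from dino_lora(images), as in validate_epoch

# The metrics JSON holds the validation curves recorded during training
with open("output/base_voc_metrics.json") as f:
    metrics = json.load(f)
print(metrics["val_iou"][-1])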
The original author only provides single-machine, single-GPU code. I modified it so that it can train on a single machine with eight GPUs; the code and the launch command are as follows:
import json
import logging
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from dino_finetune import (
DINOV2EncoderLoRA,
get_dataloader,
visualize_overlay,
compute_iou_metric,
)
def validate_epoch(
dino_lora: nn.Module,
val_loader: DataLoader,
criterion: nn.CrossEntropyLoss,
metrics: dict,
rank: int,
world_size: int
) -> None:
val_loss = 0.0
val_iou = 0.0
total_batches = 0
dino_lora.eval()
with torch.no_grad():
for images, masks in val_loader:
images = images.float().to(rank)
masks = masks.long().to(rank)
logits = dino_lora(images)
loss = criterion(logits, masks)
val_loss += loss.item()
y_hat = torch.sigmoid(logits)
iou_score = compute_iou_metric(y_hat, masks, ignore_index=255)
val_iou += iou_score.item()
total_batches += 1
# Aggregate the metrics across all processes
total_batches_tensor = torch.tensor(total_batches, device=rank)
dist.all_reduce(total_batches_tensor, op=dist.ReduceOp.SUM)
total_batches_all = total_batches_tensor.item()
val_loss_tensor = torch.tensor(val_loss, device=rank)
dist.all_reduce(val_loss_tensor, op=dist.ReduceOp.SUM)
val_loss_all = val_loss_tensor.item()
val_iou_tensor = torch.tensor(val_iou, device=rank)
dist.all_reduce(val_iou_tensor, op=dist.ReduceOp.SUM)
val_iou_all = val_iou_tensor.item()
avg_val_loss = val_loss_all / total_batches_all
avg_val_iou = val_iou_all / total_batches_all
metrics["val_loss"].append(avg_val_loss)
metrics["val_iou"].append(avg_val_iou)
def finetune_dino(config: argparse.Namespace, encoder: nn.Module, rank: int, world_size: int):
dino_lora = DINOV2EncoderLoRA(
encoder=encoder,
r=config.r,
emb_dim=config.emb_dim,
img_dim=config.img_dim,
n_classes=config.n_classes,
use_lora=config.use_lora,
use_fpn=config.use_fpn,
).to(rank)
if config.lora_weights:
dino_lora.load_parameters(config.lora_weights)
# Wrap the model with DDP
dino_lora = DDP(dino_lora, device_ids=[rank])
train_loader_orig, val_loader_orig = get_dataloader(
config.dataset, img_dim=config.img_dim, batch_size=config.batch_size
)
# Create distributed samplers
train_sampler = DistributedSampler(
train_loader_orig.dataset,
num_replicas=world_size,
rank=rank,
shuffle=True
)
val_sampler = DistributedSampler(
val_loader_orig.dataset,
num_replicas=world_size,
rank=rank,
shuffle=False
)
# Create distributed data loaders
train_loader = DataLoader(
train_loader_orig.dataset,
batch_size=config.batch_size,
sampler=train_sampler,
num_workers=train_loader_orig.num_workers,
pin_memory=train_loader_orig.pin_memory,
drop_last=True
)
val_loader = DataLoader(
val_loader_orig.dataset,
batch_size=config.batch_size,
sampler=val_sampler,
num_workers=val_loader_orig.num_workers,
pin_memory=val_loader_orig.pin_memory,
drop_last=False
)
# Loss function and optimizer
criterion = nn.CrossEntropyLoss(ignore_index=255).to(rank)
optimizer = optim.AdamW(dino_lora.parameters(), lr=config.lr)
metrics = {
"train_loss": [],
"val_loss": [],
"val_iou": [],
}
for epoch in range(config.epochs):
# Set the epoch so that shuffling takes effect
train_sampler.set_epoch(epoch)
dino_lora.train()
for images, masks in train_loader:
images = images.float().to(rank)
masks = masks.long().to(rank)
optimizer.zero_grad()
logits = dino_lora(images)
loss = criterion(logits, masks)
loss.backward()
optimizer.step()
if epoch % 2 == 0:
# Validate and aggregate the metrics across ranks
validate_epoch(dino_lora, val_loader, criterion, metrics, rank, world_size)
# Save the model and log only on the main process
if rank == 0:
# Save the underlying model (not the DDP wrapper)
dino_lora.module.save_parameters(f"output/{config.exp_name}.pt")
if config.debug:
y_hat = torch.sigmoid(logits)
visualize_overlay(
images, y_hat, config.n_classes, filename=f"viz_{epoch}"
)
logging.info(
f"Epoch: {epoch} - val IoU: {metrics['val_iou'][-1]} "
f"- val loss {metrics['val_loss'][-1]}"
)
# Final save (main process only)
if rank == 0:
dino_lora.module.save_parameters(f"output/{config.exp_name}.pt")
with open(f"output/{config.exp_name}_metrics.json", "w") as f:
json.dump(metrics, f)
if __name__ == "__main__":
# Initialize the distributed environment
dist.init_process_group(backend="nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
parser = argparse.ArgumentParser(description="Experiment Configuration")
parser.add_argument(
"--exp_name",
type=str,
default="lora",
help="Experiment name",
)
parser.add_argument(
"--debug",
action="store_true",
help="Debug by visualizing some of the outputs to file for a sanity check",
)
parser.add_argument(
"--r",
type=int,
default=3,
help="LoRA rank parameter r",
)
parser.add_argument(
"--size",
type=str,
default="base",
help="DINOv2 backbone parameter [small, base, large, giant]",
)
parser.add_argument(
"--use_lora",
action="store_true",
help="Use Low-Rank Adaptation (LoRA) to finetune",
)
parser.add_argument(
"--use_fpn",
action="store_true",
help="Use the FPN decoder for finetuning",
)
parser.add_argument(
"--img_dim",
type=int,
nargs=2,
default=(490, 490),
help="Image dimensions (height width)",
)
parser.add_argument(
"--lora_weights",
type=str,
default=None,
help="Load the LoRA weights from file location",
)
# Training parameters
parser.add_argument(
"--dataset",
type=str,
default="voc",
help="The dataset to finetune on, either `voc` or `ade20k`",
)
parser.add_argument(
"--epochs",
type=int,
default=20,
help="Number of training epochs",
)
parser.add_argument(
"--lr",
type=float,
default=3e-4,
help="Learning rate",
)
parser.add_argument(
"--batch_size",
type=int,
default=96,
help="Finetuning batch size",
)
config = parser.parse_args()
# Set up logging (main process only)
if rank == 0:
logging.basicConfig(level=logging.INFO)
# Model configuration
backbones = {
"small": "vits14_reg",
"base": "vitb14_reg",
"large": "vitl14_reg",
"giant": "vitg14_reg",
}
embedding_dims = {
"small": 384,
"base": 768,
"large": 1024,
"giant": 1536,
}
config.emb_dim = embedding_dims[config.size]
dataset_classes = {
"voc": 21,
"ade20k": 150,
}
config.n_classes = dataset_classes[config.dataset]
# Load the backbone (on every process)
encoder = torch.hub.load(
repo_or_dir="facebookresearch/dinov2",
model=f"dinov2_{backbones[config.size]}",
).to(rank)
# Start training
finetune_dino(config, encoder, rank, world_size)
# Clean up the distributed environment
dist.destroy_process_group()
torchrun --nproc_per_node=8 main_ddp.py \
--exp_name base_voc \
--dataset voc \
--size base \
--use_lora \
--img_dim 308 308 \
--epochs 50 \
--use_fpn
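Note that in this DDP version config.batch_size is passed directly to each rank's DataLoader, so it is a per-GPU batch size and the effective global batch is nproc_per_node × batch_size. With 8 GPUs and the default of 96 that is 768 images per step; reduce --batch_size (e.g. to 12) if you want to keep the global batch close to the single-GPU default or if GPU memory is tight.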
Download link for all the code
File shared via Baidu Netdisk: dinov2_seg.zip
链接: https://pan.baidu.com/s/1x0I1IT3H5VCyvmN9ZX2HXg 提取码: a8uk
References
Finetuning DINOv2 with LoRA for Image Segmentation
Closing Remarks
(1) Please credit this post if you repost it.
(2) If you have questions, leave them in the comments and I will reply promptly.
(3) If this helped, please give it a like, a favorite, and a follow. Many thanks!