项目链接:
code: https://github.com/cvg/lightglue
Paper: https://arxiv.org/pdf/2306.13643.pdf
本文代码介绍:
- 模板图:有多张模板图,可能是不同场景的图,也可能是同一物体在3D各视角下拍摄的图
- 测试图:有N多张待匹配图像,无法人眼看出哪张模板图才是当前测试图的最佳匹配图
- 功能:
a. 实现最佳模板图寻找策略
b. 实现模板bbox复用在测试图中
c. 实现实例级目标匹配
d. 匹配结果可视化
本文代码流程:
- 利用lightglue寻找当前测试图像的最佳模板【即关键点匹配成功最多的图】
- 利用lightglue进行关键点匹配
- 寻找和bbox的中心点距离最近的关键点坐标
- 复用模板图中的bbox
- 测试图匹配结果可视化
准备工作:
代码下载
git clone https://github.com/cvg/LightGlue.git && cd LightGlue
环境配置
conda create -n lightglue python=3.11
conda activate lightglue
python -m pip install --upgrade pip
python -m pip install -e .
# 下面的根据自己需要下载,有就不用下了~
pip3 install opencv-python
pip3 install kornia
#根据自己的cuda下载,torch下载网址:https://pytorch.org/get-started/locally/
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
pip3 install matplotlib
pip3 install kornia_moons
测试数据准备:
#数据集目录分布情况
--dataset_name
----/path/template # 模板图路径,可以有多张模板图
---------1.jpg
---------1.xml #通过labelimg工具对1.jpg中想要匹配的目标标注bbox(xml文件需与图片同名)
---------2.jpg
---------2.xml
---------等等
----/path/test # 测试图(待匹配的图)路径,可以有多张
---------1.jpg
---------2.jpg
---------3.jpg
---------4.jpg
---------等等
代码准备:
新建test.py,复制以下代码:
import os
# Limit the GPUs torch may see. This has to be set before torch initialises CUDA,
# so it stays ahead of `import torch`. setdefault keeps a CUDA_VISIBLE_DEVICES
# value the user already exported instead of silently overwriting it.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", '0,1,2,3,4,5,6,7')
import torch
use_cuda = torch.cuda.is_available()
# Prefer cuda:7 (the original choice) but clamp to the highest visible index so the
# script also runs on machines / CUDA_VISIBLE_DEVICES settings with fewer GPUs;
# a hard-coded 'cuda:7' fails with "invalid device ordinal" there.
if use_cuda:
    device = torch.device(f'cuda:{min(7, torch.cuda.device_count() - 1)}')
else:
    device = torch.device('cpu')
from pathlib import Path
from lightglue import LightGlue, SuperPoint, DISK
from lightglue.utils import load_image,rbd
from lightglue import viz2d
import random
import time
import cv2
import numpy as np
import xml.etree.ElementTree as ET
import argparse
from tqdm import tqdm
import re
# Drawing palette used by the cv2 helpers below: one colour triple per label
# index, in the BGR channel order cv2 expects. Callers index it by label id.
colors = [
    [0, 255, 0], [255, 0, 0], [0, 255, 255], [255, 0, 255], [255, 255, 0],
    [31, 119, 180], [255, 126, 14], [44, 160, 44], [214, 39, 40], [0, 0, 255],
    [147, 102, 188], [140, 86, 75], [225, 119, 194], [126, 126, 126],
    [188, 188, 34], [23, 188, 206],
]
def draw_bboxes_on_image(image, bbox, id, okng='OK', label_name=''):
    """Draw one bounding box and its label onto *image* (in place) and return it.

    Args:
        image: numpy image that cv2 draws on.
        bbox: (x_min, y_min, x_max, y_max) in pixels, or None to draw nothing.
        id: label index selecting the box colour from the module-level ``colors``
            palette; it wraps around when there are more labels than colours.
        okng: 'OK' draws the label text green, anything else red (BGR). Defaults
            to 'OK' -- the call sites in ``main`` never passed it, which made
            them raise TypeError when it was a required argument.
        label_name: text drawn inside the top-right corner of the box; nothing
            is drawn when it is empty.

    Returns:
        The same image object that was passed in.
    """
    if bbox is None:
        return image
    # cv2 drawing functions require Python ints for point coordinates.
    x_min, y_min, x_max, y_max = (int(v) for v in bbox)
    cv2.rectangle(image, (x_min, y_min), (x_max, y_max), colors[id % len(colors)], 2)
    if label_name:
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.7
        text_thickness = 2
        text_color = (0, 255, 0) if okng == 'OK' else (0, 0, 255)  # BGR
        (text_width, text_height), _baseline = cv2.getTextSize(
            label_name, font, font_scale, text_thickness)
        # Right-align the text against the box's top-right corner.
        text_org = (x_max - text_width, y_min + text_height)
        cv2.putText(image, label_name, text_org, font, font_scale, text_color,
                    text_thickness, lineType=cv2.LINE_AA)
    return image
def draw_circle(image, point_position, id):
    """Draw a filled dot at *point_position* on *image* (in place) and return it.

    Args:
        image: numpy image that cv2 draws on.
        point_position: (x, y) pixel coordinates of the dot's centre.
        id: label index selecting the colour from the module-level ``colors``
            palette; wraps around (like ``draw_bboxes_on_image``) instead of
            raising IndexError when there are more labels than colours.

    Returns:
        The same image object that was passed in.
    """
    point_color = colors[id % len(colors)]
    point_radius = 5       # pixels
    point_thickness = -1   # negative thickness = filled circle in cv2
    # cv2 drawing functions require Python ints for point coordinates.
    center = tuple(int(v) for v in point_position)
    cv2.circle(image, center, point_radius, point_color, point_thickness)
    return image
def numerically_sort(filenames):
    """Return *filenames* sorted by their leading integer ("2.jpg" before "10.jpg").

    Names that do not start with a digit are placed after all numbered names.
    Ties (the same number, or no number at all) are broken by the name itself,
    so the result does not depend on the input order -- ``os.listdir`` returns
    entries in arbitrary order.
    """
    leading_digits = re.compile(r"\d+")

    def sort_key(filename):
        match = leading_digits.match(filename)
        number = int(match.group(0)) if match else float('inf')
        return (number, filename)

    return sorted(filenames, key=sort_key)
# Image extensions accepted in the template and test folders (case-insensitive).
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}


def _list_images(folder):
    """Return the names of the image files directly inside *folder*."""
    return [p for p in os.listdir(folder)
            if os.path.splitext(p)[1].lower() in _IMAGE_EXTS]


def _read_annotations(xml_path):
    """Parse a labelImg XML file into a list of (name, xmin, ymin, xmax, ymax)."""
    objects = []
    for obj in ET.parse(xml_path).getroot().findall('object'):
        bndbox = obj.find('bndbox')
        # int(float(...)) also accepts tools that write "12.0" instead of "12".
        coords = [int(float(bndbox.find(tag).text))
                  for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
        objects.append((obj.find('name').text, *coords))
    return objects


def _load_image_pair(path):
    """Load *path* as (tensor moved to ``device``, numpy image for cv2 drawing/saving).

    Upstream ``lightglue.utils.load_image`` returns a single tensor, whereas the
    version this script was written against apparently returns (tensor, cv_image);
    both forms are accepted. In the tensor-only case the image is re-read with
    ``cv2.imread`` so cv2 can draw on and save it.
    """
    loaded = load_image(path)
    if isinstance(loaded, (tuple, list)):
        tensor, image_cv = loaded[0], loaded[1]
    else:
        tensor, image_cv = loaded, cv2.imread(path)
    return tensor.to(device), image_cv


def _match(matcher, feats0, feats1):
    """Match two feature dicts with LightGlue.

    Returns:
        (m_kpts0, m_kpts1): numpy arrays of matched keypoint coordinates where
        row i of each is one correspondence (feats0 = test image, feats1 = template).
    """
    matches01 = matcher({'image0': feats0, 'image1': feats1})
    # rbd removes the batch dimension.
    rfeats0, rfeats1, rmatches01 = [rbd(x) for x in [feats0, feats1, matches01]]
    matches = rmatches01["matches"]
    m_kpts0 = rfeats0["keypoints"][matches[..., 0]]
    m_kpts1 = rfeats1["keypoints"][matches[..., 1]]
    return m_kpts0.cpu().numpy(), m_kpts1.cpu().numpy()


def _transfer_bboxes(annotations, m_kpts0, m_kpts1, width, height):
    """Re-place each template bbox on the test image.

    For every annotated object, the matched template keypoint nearest to the bbox
    centre serves as anchor. The bbox keeps its template width/height and is
    centred on the anchor's counterpart in the test image, then clipped to the
    image. (Assumes keypoints share the pixel frame of the cv2 test image.)

    Returns:
        A list of (name, (cx, cy), (xmin, ymin, xmax, ymax)) in test-image pixels.
        Objects whose anchor falls outside the width x height image are skipped.
    """
    if len(m_kpts1) == 0:
        return []
    placed = []
    for name, xmin, ymin, xmax, ymax in annotations:
        bbox_w, bbox_h = xmax - xmin, ymax - ymin
        center = np.array([(xmax + xmin) // 2, (ymax + ymin) // 2])
        nearest = int(np.argmin(np.linalg.norm(m_kpts1 - center, axis=1)))
        cx = round(float(m_kpts0[nearest][0]))
        cy = round(float(m_kpts0[nearest][1]))
        if not (0 <= cx < width and 0 <= cy < height):
            continue
        box = (
            max(cx - bbox_w // 2, 0),
            max(cy - bbox_h // 2, 0),
            min(cx + bbox_w // 2, width - 1),
            min(cy + bbox_h // 2, height - 1),
        )
        placed.append((name, (cx, cy), box))
    return placed


def main(args):
    """Find the best-matching template for each test image and transfer its bboxes.

    For every image in ``args.test_path``, the template in ``args.template_path``
    with the most LightGlue matches (and at least ``args.matchNumThre`` of them) is
    selected. Its labelImg bboxes are re-placed on the test image, and, when
    ``args.ifsaveimg`` is set, the annotated test image is written to
    ``args.save_path`` under its original file name.
    """
    os.makedirs(args.save_path, exist_ok=True)
    # Keypoint extractor: SuperPoint. DISK(max_num_keypoints=2048) is an alternative
    # (then use LightGlue(features='disk')). max_num_keypoints caps the number of
    # keypoints extracted per image.
    extractor = SuperPoint(max_num_keypoints=4096).eval().to(device)
    matcher = LightGlue(features='superpoint').eval().to(device)

    # Parse every labelImg XML in the template folder once. The sorted set of all
    # object names fixes each label's colour index.
    annotations = {
        os.path.splitext(p)[0]: _read_annotations(os.path.join(args.template_path, p))
        for p in os.listdir(args.template_path)
        if os.path.splitext(p)[-1] == ".xml"
    }
    label_list = sorted({obj[0] for objs in annotations.values() for obj in objs})

    # Inference only: disable autograd for extraction and matching.
    with torch.no_grad():
        # Template features are extracted once, not once per test image.
        template_feats = {}
        for fname in sorted(_list_images(args.template_path)):
            stem = os.path.splitext(fname)[0]
            if stem not in annotations:
                # Without a same-named XML there is no bbox to transfer.
                print(f"跳过没有标注文件的模板图: {fname}")
                continue
            tensor, _ = _load_image_pair(os.path.join(args.template_path, fname))
            template_feats[stem] = extractor.extract(tensor)

        # Optional: test images named 1, 2, 3... are processed in numeric order.
        moving_files = numerically_sort(_list_images(args.test_path))
        start_time = time.time()
        for moving_name in tqdm(moving_files):
            moving, moving_cv = _load_image_pair(os.path.join(args.test_path, moving_name))
            feats0 = extractor.extract(moving)  # auto-resize the image, disable with resize=None

            # Best template = most matches, provided the count reaches the threshold;
            # on ties the first template in sorted order wins.
            best_stem, best_kpts, best_count = None, None, 0
            for stem, feats1 in template_feats.items():
                m_kpts0, m_kpts1 = _match(matcher, feats0, feats1)
                count = m_kpts0.shape[0]
                if count >= args.matchNumThre and count > best_count:
                    best_stem, best_kpts, best_count = stem, (m_kpts0, m_kpts1), count

            if not args.ifsaveimg:
                continue
            if best_stem is not None:
                height, width = moving_cv.shape[:2]
                placed = _transfer_bboxes(annotations[best_stem], *best_kpts, width, height)
                for name, center, box in placed:
                    color_id = label_list.index(name) % len(colors)
                    moving_cv = draw_circle(moving_cv, center, id=color_id)
                    moving_cv = draw_bboxes_on_image(moving_cv, box, id=color_id,
                                                     okng='OK', label_name=name)
            # Images with no qualifying template are still saved, unannotated.
            cv2.imwrite(os.path.join(args.save_path, moving_name), moving_cv)

    all_time = time.time() - start_time
    print(' all_time 运行时间:', all_time)
    if args.ifsaveimg:
        print(f"结果图已保存到 {args.save_path}")
def Parser(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with template_path, test_path, save_path,
        matchNumThre (int) and ifsaveimg (bool).
    """
    def _str2bool(text):
        # argparse's type=bool is a trap: bool("False") is True, so any value given
        # on the command line used to enable saving. Parse the spellings explicitly.
        lowered = str(text).strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if lowered in ('', '0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {text!r}')

    parser = argparse.ArgumentParser(description='模板关键点匹配')
    parser.add_argument('--template_path', type=str, default="/path/template", help='模板图路径,可以有多张模板图')
    parser.add_argument('--test_path', type=str, default="/path/test", help='待匹配图文件夹')
    parser.add_argument('--save_path', type=str, default="/path/save_path", help='结果图保存文件夹')
    parser.add_argument('--matchNumThre', type=int, default=500, help='匹配点数阈值')
    parser.add_argument('--ifsaveimg', type=_str2bool, default=True, help='是否保存图像')
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: parse the command line, then run the matching pipeline.
    main(Parser())
结果图展示
模板图:每个颜色的bbox对应一个实例
匹配结果图:和模板图中的目标对应。复用模板图中的bbox