【视觉算法—关键点匹配】LightGlue进阶应用—多图最优模板匹配

项目链接:

code:https://github.com/cvg/lightglue
Paper: https://arxiv.org/pdf/2306.13643.pdf

本文代码介绍:

  1. 模板图:有多张模板图,可能是不同场景的图,也可能是同一物体在3D各视角下拍摄的图
  2. 测试图:有N多张待匹配图像,无法人眼看出哪张模板图才是当前测试图的最佳匹配图
  3. 功能:
    a. 实现最佳模板图寻找策略
    b. 实现模板bbox复用在测试图中
    c. 实现示例级目标匹配
    d. 匹配结果可视化

本文代码流程:

  1. 利用lightglue寻找当前测试图像的最佳模板【既关键点匹配成功最多的图】
  2. 利用lightglue进行关键点匹配
  3. 寻找和bbox的中心点距离最近的关键点坐标
  4. 复用模板图中的bbox
  5. 测试图匹配结果可视化

准备工作:

代码下载

git clone https://github.com/cvg/LightGlue.git && cd LightGlue

环境配置

conda create -n lightglue python=3.11
python -m pip install --upgrade pip
python -m pip install -e . 

# 下面的根据自己需要下载,有就不用下了~
pip3 install opencv-python
pip3 install kornia
#根据自己的cuda下载,torch下载网址:https://pytorch.org/get-started/locally/
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
pip3 install matplotlib
pip3 install kornia_moons

测试数据准备:

#数据集目录分布情况
--dataset_name
----/path/template # 模板图路径,可以有多张模板图
---------1.jpg
---------1.xlm  #通过labelimg工具对1.jpg中想要匹配的目标标注bbox
---------2.jpg
---------2.xlm  
---------等等
----/path/test # 测试图(待匹配的图)路径,可以有多张
---------1.jpg
---------2.jpg
---------3.jpg
---------4.jpg
---------等等

代码准备:

新建test.py,复制以下代码:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3,4,5,6,7'
import torch
use_cuda=torch.cuda.is_available()
device = torch.device('cuda:7'  if use_cuda else 'cpu') 
from pathlib import Path
from lightglue import LightGlue, SuperPoint, DISK
from lightglue.utils import load_image,rbd
from lightglue import viz2d
import random
import time
import cv2
import numpy as np
import xml.etree.ElementTree as ET
import argparse
from tqdm import tqdm
import re

colors=[[0,255,0],[255,0,0],[0,255,255],[255,0,255],[255,255,0],
       [31, 119, 180],[255, 126, 14],[44, 160, 44],[214, 39, 40],[0,0,255],
        [147, 102, 188],[140, 86, 75],[225, 119, 194],[126, 126, 126],
        [188, 188, 34],[23, 188, 206]]


def draw_bboxes_on_image(image, bbox,id,okng,label_name):
    """在图像上绘制多个边界框,每个边界框对应一个掩码。
    参数:
    image: 输入图像。
    bbox: 框。
    """
    if bbox is not None:
        x_min, y_min, x_max, y_max = bbox
        while id>=len(colors):id=id-len(colors)
        cv2.rectangle(image, (x_min, y_min), (x_max, y_max),colors[id] , 2)

        # 定义文本及其位置
        text = label_name
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.7
        text_color = (0, 255, 0) if okng=='OK' else (0, 0, 255) # bgr
        text_thickness = 2

        # 计算文本尺寸以便调整位置
        (text_width, text_height), baseline = cv2.getTextSize(text, font, font_scale, text_thickness)
        text_org = (x_max - text_width, y_min + text_height) #右上角

        # 在图像上添加文本
        cv2.putText(image, text, text_org, font, font_scale, text_color, text_thickness, lineType=cv2.LINE_AA)

    return image

def draw_circle(image, point_position,id):
    # 定义点的颜色和大小
    point_color =colors[id] # BGR格式,红色
    point_radius = 5            # 点的半径
    point_thickness = -1        # 填充圆形

    # 在图像上绘制点
    cv2.circle(image, point_position, point_radius, point_color, point_thickness)
    return image


def numerically_sort(filenames):
    """使用正则表达式从文件名中提取数字,并进行排序"""
    def extract_number(filename):
        # 提取数字
        match = re.match(r"(\d+)", filename)
        return int(match.group(0)) if match else float('inf')
    
    return sorted(filenames, key=extract_number)

def main(args):
    os.makedirs(save_path,exist_ok=True)

    # 设置匹配模式
    # 特征点提取方法选择:DISK,SuperPoint。max_num_keypoints代表胰脏图片中提取的关键点的个数,可以自行设置
    # extractor = DISK(max_num_keypoints=2048).eval().cuda()  # load the extractor
    extractor = SuperPoint(max_num_keypoints=4096).eval().to(device)  # load the extractor
    
    # 匹配方法。features根据上述的模型调整:disk,superpoint
    matcher = LightGlue(features='superpoint').eval().to(device)  # load the matcher

    # 获取模板图
    fix_files = [
        p for p in os.listdir(args.template_path)
        if os.path.splitext(p)[-1] in [".jpg",'.JPG','.png']
    ]
    fix_files.sort()

    # 获取模板图的标注数据,通过labelimg工具,标注模板图中的roi的bbox,每个roi有可以有各自的名称
    xlm_files = [
        p for p in os.listdir(args.template_path)
        if os.path.splitext(p)[-1] in [".xml"]
    ]

    # 记录模板图中所有roi的标注名称
    label_list=[] 
    for x in xlm_files:
        root = ET.parse(os.path.join(args.template_path,x)).getroot()
        for obj in (root.findall('object')):
            name = obj.find('name').text
            if name not in label_list:
                label_list.append(name)
    label_list.sort()

    # 加载模板图片
    fix_dic={}
    for curfix in fix_files:
        fix,fix_cv = load_image(os.path.join(args.template_path,curfix))
        fix=fix.to(device)
        fix_dic[curfix.split('.')[0]]=[fix,fix_cv]

    # 【可选】我的匹配数据名称为阿拉伯数字1,2,3..,此处进行排序
    moving_files = numerically_sort(os.listdir(args.test_path))

    # 开始匹配
    start_time=time.time()
    for moving_name in tqdm(moving_files):

        # 加载待匹配数据
        moving,moving_cv = load_image(os.path.join(args.test_path,moving_name))
        moving=moving.to(device)
        feats0 = extractor.extract(moving)  # auto-resize the image, disable with resize=None

        # 从多张模板图中寻找最相似模板图
        max_kpts_len,root=0,None
        for k,v in fix_dic.items():
            feats1 = extractor.extract(v[0])

            # 两图匹配
            matches01 = matcher({'image0': feats0, 'image1': feats1}) # feats0代表moving,feats1代表fix
            rfeats0, rfeats1, rmatches01 = [rbd(x) for x in [feats0, feats1, matches01]]  # remove batch dimension
            kpts0, kpts1, matches = rfeats0["keypoints"], rfeats1["keypoints"], rmatches01["matches"]
            m_kpts0, m_kpts1 = kpts0[matches[..., 0]], kpts1[matches[..., 1]] # moving,fix

            #如果匹配成功点数<阈值,则说明这两个图不是匹配图像对
            if m_kpts0.shape[0]<args.matchNumThre:continue 

            # 获取最相似模板图的bbox信息
            if max_kpts_len<m_kpts0.shape[0]:
                max_m_kpts0,max_m_kpts1=m_kpts0,m_kpts1
                max_kpts_len=m_kpts0.shape[0]
                root = ET.parse(os.path.join(args.template_path,k+".xml")).getroot()
                fix_bbox=v[1].copy() 
        
        w,h=1280,780 # 测试图像宽高
        # 开始匹配
        if root is not None:
            new_m_kpts0,new_m_kpts1=[],[]
            for obj in (root.findall('object')):
                name = obj.find('name').text

                # 形变关键点,将bbox中(x,y)映射到新图中
                bndbox = obj.find('bndbox')
                xmin = int(bndbox.find('xmin').text)
                ymin = int(bndbox.find('ymin').text)
                xmax = int(bndbox.find('xmax').text)
                ymax = int(bndbox.find('ymax').text)
                bbox_w,bbox_h=(xmax-xmin),(ymax-ymin)
                custom_point = np.array([(xmax+xmin)//2,(ymax+ymin)//2]) #模板图中bbox的中心点

                # 模板图中bbox的中心点与各个关键点的距离
                distances = np.linalg.norm(max_m_kpts1.cpu() - custom_point, axis=1) 

                # 获取距离最短的关键点坐标作为bbox中心点的替代点
                nearest_index = np.argmin(distances) 
                fix_cx,fix_cy = round(max_m_kpts1[nearest_index][0].item()),round(max_m_kpts1[nearest_index][1].item())
                xmin,ymin,xmax,ymax=fix_cx-bbox_w//2,fix_cy-bbox_h//2,fix_cx+bbox_w//2,fix_cy+bbox_h//2

                # 获取待匹配图像中的匹配点
                new_cx,new_cy = round(max_m_kpts0[nearest_index][0].item()),round(max_m_kpts0[nearest_index][1].item()) 
                if new_cx<0 or new_cx >=w or new_cy<0 or new_cy>=h:continueh
                # 获取待匹配图像中的BBOX
                new_xmin,new_ymin,new_xmax,new_ymax=new_cx-bbox_w//2,new_cy-bbox_h//2,new_cx+bbox_w//2,new_cy+bbox_h//2
                
                #  保存结果图
                if args.ifsaveimg:
                    if new_xmin<0:new_xmin=0
                    if new_ymin<0:new_ymin=0
                    if new_xmax>=w:new_xmax=w-1 
                    if new_ymax>=h:new_ymax=h-1
                    new_m_kpts0.append(max_m_kpts0[nearest_index])
                    new_m_kpts1.append(max_m_kpts1[nearest_index])

                    # 绘制bbox中心点,每个roi有各自的颜色
                    fix_bbox=draw_circle(fix_bbox, (fix_cx,fix_cy), id=label_list.index(name))
                    moving_cv=draw_circle(moving_cv,(new_cx,new_cy), id=label_list.index(name))
                    # 绘制bbox
                    fix_bbox=draw_bboxes_on_image(fix_bbox, (xmin,ymin,xmax,ymax),id=label_list.index(name),label_name=name) #目标源 fix image
                    moving_cv=draw_bboxes_on_image(moving_cv, (new_xmin,new_ymin,new_xmax,new_ymax), id=label_list.index(name),label_name=name)
            
            # 显示bbox和中心点
            # viz2d.plot_images([moving_cv,fix_bbox],save_video_path,moving_name.split('.')[-2]+'_'+max_k+'.jpg')#两张图显示,看的不清楚
         
        if args.ifsaveimg:
            cv2.imwrite(os.path.join(save_video_path,moving_name),moving_cv)


    end_time=time.time()
    all_time = end_time-start_time
    print(' all_time 运行时间:',all_time)

    
    if args.ifsaveimg:
        # 释放 VideoWriter 对象
        video.release()
        cv2.destroyAllWindows()
        print(f"视频已保存到 {video_file}")


def Parser():
    parser = argparse.ArgumentParser(description='模板关键点匹配')
    parser.add_argument('--template_path', type=str, default="/path/template",help='模板图路径,可以有多张模板图')
    parser.add_argument('--test_path', type=str, default="/path/test",help='待匹配图文件夹')
    parser.add_argument('--save_path', type=str, default="/path/save_path",help='结果图保存文件夹')
    parser.add_argument('--matchNumThre', type=int, default=500,help='匹配点数阈值')
    parser.add_argument('--ifsaveimg', type=bool, default=True,help='是否保存图像')
    args = parser.parse_args()
    return args

if __name__ == "__main__":
    args = Parser()
    main(args)
    

结果图展示

模板图:每个颜色的bbox对应一个实例
在这里插入图片描述
匹配结果图:和模板图中的目标对应。复用模板图中的bbox
在这里插入图片描述

### 关于 LightGlue 的使用教程和项目地址 #### 项目地址 LightGlue 是一个开源项目,其官方 GitHub 地址可以通过以下命令获取并克隆到本地环境中: ```bash git clone https://github.com/cvg/LightGlue.git cd LightGlue ``` 此操作允许开发者下载项目的最新版本以便进一步研究或开发[^1]。 对于国内用户而言,如果访问 GitHub 存在困难,可以考虑通过镜像站点获取该项目资源。例如,在 GitCode 上提供了相应的镜像支持,具体地址如下: - 镜像地址: [https://gitcode.com/gh_mirrors/li/LightGlue](https://gitcode.com/gh_mirrors/li/LightGlue)[^2] #### 安装与配置指南 为了更好地运行 LightGlue 项目,建议按照以下流程完成初始化设置: ##### 虚拟环境搭建(可选) 虽然并非强制要求,但在独立的 Python 环境下工作有助于减少依赖冲突以及保持系统的稳定性。因此推荐执行以下指令来创建虚拟环境: ```bash python -m venv lightglue_env source lightglue_env/bin/activate # Linux/MacOS lightglue_env\Scripts\activate # Windows ``` 随后依据 `requirements.txt` 文件安装必要的库文件: ```bash pip install -r requirements.txt ``` 以上步骤能够确保所有必需组件被正确加载至当前环境下。 #### 推荐扩展阅读——ONNX 版本实现 除了基础版外,还有基于 ONNX 运行时优化过的 C++ 实现方案可供选择,即 **LightGlue-OnnxRunner** 。该分支专注于提供高性能推理服务,并兼容种特征提取器如 SuperPoint 和 DISK 结合使用的场景分析功能。同样地,我们也可以轻松将其拉取下来进行探索学习: ```bash git clone https://github.com/OroChippw/LightGlue-OnnxRunner.git cd LightGlue-OnnxRunner ``` 或者利用国内镜像链接替代原始 URL 来加速下载过程: - 国内镜像地址:[https://gitcode.com/gh_mirrors/li/LightGlue-OnnxRunner](https://gitcode.com/gh_mirrors/li/LightGlue-OnnxRunner)[^4] 更细节描述请参阅对应文档说明部分以获得最佳实践指导信息。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值