基于向量数据库的分类模型类别扩充方法(以图搜图)

本文不生产技术,只做技术的搬运工!!!

前言

        最近公司甲方提了新需求,要求可以用少量的样例图像进行类别扩充,加快迭代速度,刚好最近在倒腾以图搜图,就以向量数据库的形式进行了测试实验,效果还不错,这里做下记录。

背景描述

        当前分类模型已经有8个类别,且经过了大量数据的训练,加方要求后续再新增类别不需要大批量采集数据,只需要提供一定数量的样例图像即可完成分类。

技术路线

我们项目使用的是yolov8-cls进行的分类任务,这里我们需要对onnx模型进行一些改动,获取其中间输出作为特征向量,然后存入数据库中

源码

模型修改

import onnx
import numpy as np
from onnx import helper, checker

# 加载现有的 ONNX 模型
model = onnx.load("/home/workspace/temp/last.onnx")

# 找到你想作为新输出的节点
target_layer_name = "/model.9/Flatten_output_0"  # 替换为目标层的名字

# 创建新的输出信息
new_output_info = helper.ValueInfoProto()
new_output_info.name = target_layer_name

# 设置 type 字段
tensor_type_proto = new_output_info.type.tensor_type
tensor_type_proto.elem_type = onnx.TensorProto.FLOAT  # 假设输出类型为 float
tensor_shape_proto = tensor_type_proto.shape

# 假设输出形状为 (batch_size, num_features),这里可以具体化 batch_size 和 num_features
# 这里使用 -1 表示动态维度
tensor_shape_proto.dim.add().dim_value = -1  # batch_size
tensor_shape_proto.dim.add().dim_value = -1  # num_features

# 将新的输出信息添加到模型的输出列表中
model.graph.output.append(new_output_info)

# 验证修改后的模型是否有效
checker.check_model(model)

# 保存修改后的模型
onnx.save(model, "/home/workspace/temp/last_add.onnx")

修改后的模型输出头如下图所示

 数据入库

from PIL import Image
import onnxruntime
import torch
import os
import numpy as np
import cv2
from torchvision import transforms
from pymilvus import MilvusClient
if os.path.isfile("val.db"):
    client = MilvusClient("val.db")
else:
    client = MilvusClient("val.db")
    if client.has_collection(collection_name="text_image"):
        client.drop_collection(collection_name="text_image")
    client.create_collection(
        collection_name="text_image",
        dimension=1280,  # The vectors we will use in this demo has 768 dimensions
        metric_type="COSINE"
    )

def read_image(image_path):
    """
    读取图像并应用预处理。

    :param image_path: 图像的路径。
    :return: 经过预处理的图像和原始图像。
    """
    # 从指定路径读取图像,解决中文路径问题
    src = cv2.imdecode(np.fromfile(image_path, dtype=np.uint8), cv2.IMREAD_COLOR)
    # 将图像从 BGR 转换为 RGB 格式
    img = cv2.cvtColor(src, cv2.COLOR_BGR2RGB)

    # 使用 InterpolationMode.BILINEAR 指定双线性插值
    transform = transforms.Compose([
        # 将图像大小调整为 224x224,保持宽高比不变
        transforms.Resize(size=224, interpolation=transforms.InterpolationMode.BILINEAR, max_size=None, antialias=True),
        # 从图像中心裁剪出 224x224 的区域
        transforms.CenterCrop(size=(224, 224)),
        # 将图像转换为 PyTorch 张量
        transforms.ToTensor(),
        # 使用指定的均值和标准差对图像进行标准化
        transforms.Normalize(mean=[0., 0., 0.], std=[1., 1., 1.])
    ])
    # 将图像转换为 PIL 图像并应用变换
    pil_image = Image.fromarray(img)
    normalized_image = transform(pil_image)

    # 添加维度以适应模型输入要求,并返回处理后的图像和原始图像
    return np.expand_dims(normalized_image.numpy(), axis=0), src

def getFileList(dir, Filelist, ext=None):
    """
    获取文件夹及其子文件夹中文件列表
    输入 dir:文件夹根目录
    输入 ext: 扩展名
    返回: 文件路径列表
    """
    newDir = dir
    if os.path.isfile(dir):
        if ext is None:
            Filelist.append(dir)
        else:
            if ext in dir:
                Filelist.append(dir)
    elif os.path.isdir(dir):
        for s in os.listdir(dir):
            newDir = os.path.join(dir, s)
            getFileList(newDir, Filelist, ext)
    return Filelist

def load_onnx_model(model_path):
    providers = ['CUDAExecutionProvider']
    session = onnxruntime.InferenceSession(model_path, providers=providers)
    print("ONNX模型已成功加载。")
    return session

def image_tensor(image_path, session):
    image,_ = read_image(image_path)
    input_name = session.get_inputs()[0].name
    output_name = "/model.9/Flatten_output_0"
    pred = session.run([output_name], {input_name: image})[0]
    pred = np.squeeze(pred)
    return pred

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_path = "/home/project_python/onnx_image-image/last_add.onnx"
    img_dir = r"/home/project_python/onnx_image-image/test_image/add"
    session = load_onnx_model(model_path)
    image_path_list = []
    image_path_list = getFileList(img_dir, image_path_list, '.jpg')
    data = []
    i = 0
    for image_path in image_path_list:
        temp = {}
        image = Image.open(image_path)
        image_features = image_tensor(image_path, session)
        #image_features = image_features / image_features.norm(dim=-1, keepdim=True)  # normalize
        # 将特征向量转换为字符串
        temp['id'] = i
        temp['image_path'] = image_path
        temp['vector'] = image_features
        temp['image_labels'] = image_path.split("/")[-2]
        data.append(temp)
        i = i + 1
        print(i)
    res = client.insert(collection_name="text_image", data=data)

初始化数据库和新增数据都可以使用该代码,需要注意的是,数据库存入的类别信息是以图像所在文件夹的名称决定的,因此需要提前修改好文件夹名称,另外新增的数据最好是和分类模型原始类别同属性的类别,比如原来分类模型是猫狗分类,可以新增鸭子的类别,这样效果会比较好。

数据查询

import onnxruntime
import torch
import numpy as np
import cv2
from torchvision import transforms
from pymilvus import MilvusClient
from PIL import Image, ImageDraw, ImageFont
client = MilvusClient("val.db")

def display_images_in_grid(image_paths, scores, labels, images_per_row=3):
    # 检查输入长度是否一致
    if not (len(image_paths) == len(scores) == len(labels)):
        raise ValueError("The lengths of image_paths, scores, and labels must be the same.")

    # 计算需要的行数
    num_images = len(image_paths)
    num_rows = (num_images + images_per_row - 1) // images_per_row

    # 打开所有图像并调整大小
    images = []
    for path in image_paths:
        with Image.open(path) as img:
            img = img.resize((200, 200))  # 调整图像大小以适应画布
            images.append(img)

    # 创建一个空白画布
    canvas_width = images_per_row * 200
    canvas_height = num_rows * 200
    canvas = Image.new('RGB', (canvas_width, canvas_height), (255, 255, 255))
    draw = ImageDraw.Draw(canvas)

    # 加载字体(这里假设你有一个可用的字体文件)
    try:
        font = ImageFont.truetype("arial.ttf", 40)
    except IOError:
        font = ImageFont.load_default()

    # 将图像粘贴到画布上,并绘制分数和标签
    for idx, img in enumerate(images):
        row = idx // images_per_row
        col = idx % images_per_row
        position = (col * 200, row * 200)
        canvas.paste(img, position)

        # 绘制分数
        score_text = f"Score: {scores[idx]:.2f}"
        text_position_score = (position[0] + 10, position[1] + 10)
        draw.text(text_position_score, score_text, font=font, fill=(255, 0, 0))

        # 绘制标签
        label_text = f"Label: {labels[idx]}"
        text_position_label = (position[0] + 10, position[1] + 20)  # 调整位置以避免重叠
        draw.text(text_position_label, label_text, font=font, fill=(255, 0, 0))

    # 显示画布
    canvas.show()



def read_image(image_path):
    """
    读取图像并应用预处理。

    :param image_path: 图像的路径。
    :return: 经过预处理的图像和原始图像。
    """
    # 从指定路径读取图像,解决中文路径问题
    src = cv2.imdecode(np.fromfile(image_path, dtype=np.uint8), cv2.IMREAD_COLOR)
    # 将图像从 BGR 转换为 RGB 格式
    img = cv2.cvtColor(src, cv2.COLOR_BGR2RGB)

    # 使用 InterpolationMode.BILINEAR 指定双线性插值
    transform = transforms.Compose([
        # 将图像大小调整为 224x224,保持宽高比不变
        transforms.Resize(size=224, interpolation=transforms.InterpolationMode.BILINEAR, max_size=None, antialias=True),
        # 从图像中心裁剪出 224x224 的区域
        transforms.CenterCrop(size=(224, 224)),
        # 将图像转换为 PyTorch 张量
        transforms.ToTensor(),
        # 使用指定的均值和标准差对图像进行标准化
        transforms.Normalize(mean=[0., 0., 0.], std=[1., 1., 1.])
    ])
    # 将图像转换为 PIL 图像并应用变换
    pil_image = Image.fromarray(img)
    normalized_image = transform(pil_image)

    # 添加维度以适应模型输入要求,并返回处理后的图像和原始图像
    return np.expand_dims(normalized_image.numpy(), axis=0), src

def load_onnx_model(model_path):
    providers = ['CUDAExecutionProvider']
    session = onnxruntime.InferenceSession(model_path, providers=providers)
    print("ONNX模型已成功加载。")
    return session

def image_tensor(image_path, session):
    image,_ = read_image(image_path)
    input_name = session.get_inputs()[0].name
    output_name = "/model.9/Flatten_output_0"
    pred = session.run([output_name], {input_name: image})[0]
    #pred = np.squeeze(pred)
    return pred

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_path = "/home/project_python/onnx_image-image/last_add.onnx"
    img_path = r"/home/project_python/onnx_image-image/test_image/box_11.jpg"
    session = load_onnx_model(model_path)
    image_features = image_tensor(img_path, session)
    results = client.search(
        "text_image",
        data=image_features,
        output_fields=["image_path"],
        search_params={"metric_type": "COSINE"},
        limit=7
    )
    image_list = []
    score_list = []
    label_list = []
    for i,result in enumerate(results[0]):
        image_list.append(result["entity"]["image_path"])
        score_list.append(result["distance"])
        label_list.append(result["entity"]["image_path"].split("/")[-2])
    display_images_in_grid(image_list, score_list,label_list,3)
    #print(results)
### FAISS 向量数据库实现以图搜图功能 FAISS 是一种高效的相似度搜索库,特别适合处理大规模向量数据集。以下是关于如何使用 FAISS 实现基于图像的搜索功能的相关说明。 #### 1. 数据预处理 为了实现以图搜图的功能,首先需要将图像转换为数值化的特征向量表示形式。这通常通过深度学习模型完成,例如卷积神经网络 (CNN),用于提取每张图片的关键特征[^2]。这些特征可以作为输入传递给 FAISS 进行索引构建和查询操作。 ```python import numpy as np from PIL import Image from torchvision import transforms, models import torch def extract_features(image_path): model = models.resnet50(pretrained=True).eval() # 使用 ResNet 提取特征 preprocess = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) img = Image.open(image_path) input_tensor = preprocess(img) input_batch = input_tensor.unsqueeze(0) with torch.no_grad(): features = model(input_batch) return features.numpy().flatten() ``` #### 2. 构建 FAISS 索引 一旦获得了所有图像对应的特征向量集合,就可以利用 FAISS 创建一个高效的数据结构来支持快速近似最近邻查找。这里展示了一个简单的例子,其中选择了 FlatL2 和 IVFADC 方法分别代表精确匹配以及更优性能但牺牲一定精度的情况: ```python import faiss vectors = [] # 假设这是之前计算得到的所有图片特征列表 dim = vectors[0].shape[-1] index_flat = faiss.IndexFlatL2(dim) # 平面 L2 距离索引 print(f"Index size before adding elements: {index_flat.ntotal}") # 将所有的向量加入到索引中 for vec in vectors: index_flat.add(np.array([vec])) print(f"Index size after adding all elements: {index_flat.ntotal}") ``` 对于更大规模的应用场景下推荐采用如下方式初始化索引对象: ```python nlist = 100 # 定义聚类中心数量 quantizer = faiss.IndexFlatL2(dim) index_ivfadc = faiss.IndexIVFFlat(quantizer, dim, nlist, faiss.METRIC_L2) assert not index_ivfadc.is_trained index_ivfadc.train(vectors[:]) # 训练阶段仅需部分样本即可 assert index_ivfadc.is_trained index_ivfadc.add(vectors[:]) ``` #### 3. 查询最相似项 最后一步就是针对新上传的一幅目标图画作同样的特征抽取过程,并调用 `search` 函数找到与其最为接近的结果们。 ```python query_vector = extract_features('path_to_query_image.jpg') k = 4 # 返回前 k 个最佳匹配结果 distances, indices = index_flat.search(query_vector.reshape((1,-1)), k) print("Distances:", distances) print("Indices:", indices) ``` 以上即完成了整个流程概述——从原始素材准备直至最终检索呈现环节逐一解析完毕。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

魔障阿Q

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值