Segmenting an object with EfficientSAM and computing its height

EfficientSAM.py

import cv2
import FindHight

from efficient_sam.build_efficient_sam import build_efficient_sam_vitt, build_efficient_sam_vits
# from squeeze_sam.build_squeeze_sam import build_squeeze_sam

from PIL import Image
from torchvision import transforms
import torch
import numpy as np
import zipfile


def SAM_object(image_path):

    models = {}

    # Build the EfficientSAM-S model.
    models['efficientsam_s'] = build_efficient_sam_vits()

    # # Since EfficientSAM-S checkpoint file is >100MB, we store the zip file.
    # with zipfile.ZipFile("weights/efficient_sam_vits.pt.zip", 'r') as zip_ref:
    #     zip_ref.extractall("weights")
    # # Build the EfficientSAM-S model.
    # models['efficientsam_s'] = build_efficient_sam_vits()

    # load an image
    sample_image_np = np.array(Image.open(image_path))
    sample_image_tensor = transforms.ToTensor()(sample_image_np)
    # Feed a few (x,y) points in the mask as input.

    input_points = torch.tensor([[[[580, 350], [650, 350]]]])
    input_labels = torch.tensor([[[1, 1]]])

    # Run inference for each model that was built (here only EfficientSAM-S).

    for model_name, model in models.items():
        print('Running inference using ', model_name)
        predicted_logits, predicted_iou = model(
            sample_image_tensor[None, ...],
            input_points,
            input_labels,
        )
        sorted_ids = torch.argsort(predicted_iou, dim=-1, descending=True)
        predicted_iou = torch.take_along_dim(predicted_iou, sorted_ids, dim=2)
        predicted_logits = torch.take_along_dim(
            predicted_logits, sorted_ids[..., None, None], dim=2
        )
        # The masks are already sorted by their predicted IOUs.
        # The first dimension is the batch size (we have a single image. so it is 1).
        # The second dimension is the number of masks we want to generate (in this case, it is only 1)
        # The third dimension is the number of candidate masks output by the model.
        # For this demo we use the first mask.
        mask = torch.ge(predicted_logits[0, 0, 0, :, :], 0).cpu().detach().numpy()
        masked_image_np = sample_image_np.copy().astype(np.uint8) * mask[:,:,None]

        Image.fromarray(masked_image_np).save("figs/examples/sam_mask.png")
        # Image.fromarray(masked_image_np).show("sam_mask")
        # Convert the masked RGB image to OpenCV's BGR format
        opencv_image = cv2.cvtColor(masked_image_np, cv2.COLOR_RGB2BGR)
        # cv2.imshow("sam", opencv_image)
        # cv2.waitKey()
        # Only the first model's mask is returned (a single model is registered above)
        return opencv_image

if __name__ == '__main__':
    image_path = "figs/examples/part02.jpg"

    highest_point = (0, 0)
    real_highest_y = 300

    sam_image = SAM_object(image_path)

    if sam_image is None:
        print("no sam_segmentaton image")
    else:
        FindHight.cal_pixl_hight_image(sam_image)
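
A note on the model dict above: build_efficient_sam_vitt is imported but never used, and SAM_object returns from inside the loop, so only the first registered model ever runs. If you want to compare the lighter Ti variant with the S variant on the same point prompt, a minimal sketch (assuming the Ti checkpoint is available where its builder expects it) is:

# Sketch: register both variants and keep one mask per model instead of
# returning from inside the loop. Assumes both checkpoints are available locally.
models = {
    'efficientsam_ti': build_efficient_sam_vitt(),  # tiny variant, fastest
    'efficientsam_s': build_efficient_sam_vits(),   # small variant, more accurate
}
masks = {}
for model_name, model in models.items():
    predicted_logits, predicted_iou = model(
        sample_image_tensor[None, ...], input_points, input_labels
    )
    sorted_ids = torch.argsort(predicted_iou, dim=-1, descending=True)
    predicted_logits = torch.take_along_dim(predicted_logits, sorted_ids[..., None, None], dim=2)
    masks[model_name] = torch.ge(predicted_logits[0, 0, 0, :, :], 0).cpu().detach().numpy()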

FindHight.py

import os
import cv2

def read_images(folder_path):
    image_files = [f for f in os.listdir(folder_path) if
                   f.endswith(".jpg") or f.endswith(".png")]  # collect all .jpg and .png files in the folder
    sorted_image_files = sorted(image_files, key=lambda x: int(''.join(filter(str.isdigit, os.path.splitext(x)[0]))))

    images = []  # list to hold the loaded images
    for filename in sorted_image_files:
        # print(filename)
        image_path = os.path.join(folder_path, filename)  # build the full file path
        image = cv2.imread(image_path)  # read the image with OpenCV
        # image = cv2.resize(image, (0, 0), fx=0.5, fy=0.5, interpolation=cv2.INTER_LINEAR)

        images.append(image)  # append the image to the list

    return images

def reprocess_image(img):
    # Convert the color image to grayscale
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    # Optional binarization (thresholding)
    # ret, thresh = cv2.threshold(gray, 50, 255, 0)
    # cv2.imshow("Binary Image", thresh)
    # cv2.waitKey(0)

    contour_image = img.copy()
    # Find contours directly on the grayscale image
    contours, hierarchy = cv2.findContours(gray, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # Alternatively, find contours on the thresholded image
    # contours, hierarchy = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    cv2.drawContours(contour_image, contours, -1, (0, 255, 0), 3)
    cv2.imshow("contour_image", contour_image)
    cv2.waitKey(0)
    return contours, img
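
# Version-agnostic contour extraction: OpenCV 3.x returns (image, contours, hierarchy)
# from cv2.findContours, while 4.x returns (contours, hierarchy) as unpacked above.
# This small helper (a sketch, not part of the original script) works on both
# versions by taking the last two returned values.
def find_contours_compat(gray):
    result = cv2.findContours(gray, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    return result[-2], result[-1]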

def find_highest_y_point(contours, img):
    _highest_point = (0, 0)
    # Track the smallest y seen so far; image y grows downward,
    # so the smallest y corresponds to the highest point
    min_y = img.shape[0]
    for contour in contours:
        for point in contour:
            if point[0][1] < min_y:
                min_y = point[0][1]
                _highest_point = (point[0][0], point[0][1])

    # Mark the highest point on the image
    cv2.circle(img, _highest_point, 5, (0, 0, 255), -1)
    # Show the result
    # cv2.imshow('Highest Point', img)
    # cv2.waitKey(0)
    # cv2.destroyAllWindows()
    return _highest_point


def cal_pixl_hight_folders(folder_path):
    # Read the folder once and reuse the result
    images = read_images(folder_path)
    if images:
        print("Images loaded successfully")
        print(len(images))
        for _image in images:
            contours, img = reprocess_image(_image)
            highest_point = find_highest_y_point(contours, img)
            real_highest_y = img.shape[0] - highest_point[1]
            real_highest_y *= 0.25  # scale from pixels to millimeters (assumed 0.25 mm per pixel)
            # print("Highest point:", highest_point)
            # print("Highest point x:", highest_point[0])
            # print("Highest point y:", highest_point[1])
            print(real_highest_y)
    else:
        print("Failed to load images")

def cal_pixl_hight_image(_image):
    print("Caluating hight point")
    contours, img = reprocess_image(_image)
    highest_point = find_highest_y_point(contours, img)
    real_highest_y = img.shape[0] - highest_point[1]
    real_highest_y *= 0.25
    # print("最高点坐标:", highest_point)
    # print("最高点x坐标:", highest_point[0])
    # print("最高点y坐标:", highest_point[1])
    print(real_highest_y)
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1
    font_color = (0, 255, 255)  # 文本颜色,以BGR格式指定
    line_type = 2
    cv2.putText(img, "hight is " + str(real_highest_y) + "mm", (100, 100), font, font_scale, font_color, lineType=line_type)
    cv2.imshow('Point', img)
    cv2.waitKey()
    cv2.destroyAllWindows()


if __name__ == '__main__':
    # Input: the image segmented by SAM
    image_path = "figs/examples/sam_mask.png"  # replace with your own image path
    highest_point = (0, 0)
    real_highest_y = 300
    image = cv2.imread(image_path)
    cal_pixl_hight_image(image)
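
To process a whole folder of pre-segmented masks in one go, the batch helper defined above can be called instead; a minimal sketch (the folder path is just an example):

import FindHight

# Prints the computed height (in millimeters) for every .jpg/.png image in the folder.
FindHight.cal_pixl_hight_folders("figs/examples")  # example folder path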
