YOLO-seg跑透明物体TROSD数据集,并出miou指标

1.TROSD数据集介绍

在这里插入图片描述

数据集官方地址:
Google云盘下载 链接:

对于目标对象,TROSD考虑不同尺寸、形状和颜色的常见类型(主要包括玻璃制品、镜子、塑料和金属)。对于场景,TROSD数据集包含14个独特的场景(例如客厅、浴室、办公室等)。最后,官方使用Labelme手动注释这些图像。

1.TROSD数据集包含11060幅RGB-D图像(7421幅用于训练,3639幅用于测试)。

2.由于我们实现的RGB-D相机的工作机制,深度图像中存在一些噪声。深度保存为.png图像。

3.在语义掩码中,0、1、2分别表示背景、透明对象和反射对象。

2.语义分割mask掩码转化为labelme格式(json文件)

由于官方只提供了RGB、mask和depth 的图像,没有分割的json文件。因此,我们要跑通Yolo-seg,需要自己转换处理。

# 导入包
import os
import io
import json
import numpy as np
from pycococreatortools import pycococreatortools
from PIL import Image
import base64

def img_tobyte(img_pil):
    '''
    该函数用于将图像转化为base64字符类型
    :param img_pil: Image类型
    :return base64_string: 字符串
    '''
    ENCODING = 'utf-8'
    img_byte = io.BytesIO()
    img_pil.save(img_byte, format='PNG')
    binary_str2 = img_byte.getvalue()
    imageData = base64.b64encode(binary_str2)
    base64_string = imageData.decode(ENCODING)
    return base64_string

# 定义路径
ROOT_DIR = '' # 请输入你文件的根目录
Image_DIR = os.path.join(ROOT_DIR, "Image") # 目录底下包含图片
Label_DIR = os.path.join(ROOT_DIR, "GT") # 目录底下包含label文件
# 读取路径下的掩码
Label_files = os.listdir(Label_DIR)
# 指定png中index中对应的label
class_names = ['_background_', 'basketball', 'person'] # 分别表示label标注图中1对应basketball,2对应person。
for Label_filename in Label_files:
    # 创建一个json文件
    Json_output = {
        "version": "3.16.7",
        "flags": {},
        "fillColor": [255, 0, 0, 128],
        "lineColor": [0, 255, 0, 128],
        "imagePath": {},
        "shapes": [],
        "imageData": {}}
    print(Label_filename)
    name = Label_filename.split('.', 3)[0]
    name1 = name + '.jpg'
    Json_output["imagePath"] = name1
    # 打开原图并将其转化为labelme json格式
    image = Image.open(Image_DIR + '/' + name1)
    imageData = img_tobyte(image)
    Json_output["imageData"] = imageData
    # 获得注释的掩码
    binary_mask = np.asarray(np.array(Image.open(Label_DIR + '/' + Label_filename))
                             ).astype(np.uint8)
    # 分别对掩码中的label结果绘制边界点
    for i in np.unique(binary_mask):
        if i != 0:
            temp_mask = np.where(binary_mask == i, 1, 0)
            segmentation = pycococreatortools.binary_mask_to_polygon(temp_mask, tolerance=2) # tolerancec参数控制无误差
            for item in segmentation:
                if (len(item) > 10):
                    list1 = []
                    for j in range(0, len(item), 2):
                        list1.append([item[j], item[j + 1]])
                    label = class_names[i]  #
                    seg_info = {'points': list1, "fill_color": None, "line_color": None, "label": label,
                                "shape_type": "polygon", "flags": {}}
                    Json_output["shapes"].append(seg_info)
    Json_output["imageHeight"] = binary_mask.shape[0]
    Json_output["imageWidth"] = binary_mask.shape[1]
    # 保存在根目录下的json文件中
    full_path = '{}/json/' + name + '.json'
    with open(full_path.format(ROOT_DIR), 'w') as output_json_file:
        json.dump(Json_output, output_json_file)

还需要下载pycococreatortools包,Git链接。将红框内容放到自己的代码根路径下。
在这里插入图片描述
验证转的对不对,可以用如下代码(根据指定的json文件将label可视化在image上):

# 根据指定的json文件将label可视化在image上
import json
import cv2
import numpy as np

# 读取JSON文件
json_file_path = 'your_json_file.json'  # 替换为你的JSON文件路径
with open(json_file_path, 'r') as f:
    data = json.load(f)

# 读取图像
image_path = data['imagePath']
image = cv2.imread(image_path)

# 获取标签的形状和位置
shapes = data['shapes']

# 可视化标签
for shape in shapes:
    points = np.array(shape['points'], dtype=np.int32)
    label = shape['label']
    
    # 绘制多边形
    cv2.polylines(image, [points], isClosed=True, color=(0, 255, 0), thickness=2)
    
    # 在标签区域绘制标签文本
    x, y = points[0]
    cv2.putText(image, label, (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 0), 2)

# 显示图像
cv2.imshow('Image with Labels', image)

# 等待按键并关闭窗口
cv2.waitKey(0)
cv2.destroyAllWindows()

3.json文件转txt格式

# -*- coding: utf-8 -*-
import mask2json
import os
import argparse
from tqdm import tqdm


def convert_label_json(json_dir, save_dir, class_list):
    json_paths = os.listdir(json_dir)
    classes = class_list
    for json_path in tqdm(json_paths):
        # for json_path in json_paths:
        path = os.path.join(json_dir, json_path)
        with open(path, 'r', encoding='UTF-8') as load_f:
            json_dict = json.load(load_f)
        h, w = json_dict['imageHeight'], json_dict['imageWidth']

        # save json path
        txt_path = os.path.join(save_dir, json_path.replace('json', 'json'))
        txt_file = open(txt_path, 'w', encoding='UTF-8')

        for shape_dict in json_dict['shapes']:
            label = shape_dict['label']
            label_index = classes.index(label)
            points = shape_dict['points']

            points_nor_list = []

            for point in points:
                points_nor_list.append(point[0] / w)
                points_nor_list.append(point[1] / h)

            points_nor_list = list(map(lambda x: str(x), points_nor_list))
            points_nor_str = ' '.join(points_nor_list)

            label_str = str(label_index) + ' ' + points_nor_str + '\n'
            txt_file.writelines(label_str)


if __name__ == "__main__":
    """
    python json2txt_nomalize.py --json-dir my_datasets/color_rings/jsons --save-dir my_datasets/color_rings/txts --classes "cat,dogs"
    """
    parser = argparse.ArgumentParser(description='json convert to json params')
    parser.add_argument('--json-dir', type=str, default=r'D:\bishe\TROSD\train\json', help='json path dir')
    parser.add_argument('--save-dir', type=str, default=r'D:\bishe\ultralytics-main\trosd\labels\train', help='json save dir')
    parser.add_argument('--class-list', type=str, default=['transparent', 'reflective'], help='classes')
    args = parser.parse_args()
    json_dir = args.json_dir
    save_dir = args.save_dir
    class_list = args.class_list
    convert_label_json(json_dir, save_dir, class_list)

4.Yolo-seg分割预测可视化

'''可视化'''
from ultralytics import YOLO
import cv2
import numpy as np
import matplotlib.pyplot as plt

# Load the trained YOLO segmentation model
model = YOLO(r'D:\bishe\ultralytics-main\runs\segment\train\weights\best.pt')

# Path to the input image
image_path = r"D:\bishe\ultralytics-main\trosd\images\val\cg_real_test_d415_000000000_1.jpg"
output_path = r"output_image.jpg"
# Class names dictionary
yolo_train_ids = {
    'transparent': 0, 'reflective': 1,
}
id_to_class = {v: k for k, v in yolo_train_ids.items()}

# Assign a fixed color to each class
np.random.seed(42)
# class_colors = {cls_id: tuple(np.random.randint(0, 255, 3).tolist()) for cls_id in id_to_class}
class_colors = {0: (0, 255, 0), 1: (0, 0, 255)}

# Perform inference
results = model(image_path)

# Extract predictions
predictions = results[0]
masks = predictions.masks
boxes = predictions.boxes
image = cv2.imread(image_path)

# Create a blank canvas to overlay masks
mask_canvas = np.zeros_like(image, dtype=np.uint8)

# Iterate over masks and draw them on the canvas
for mask, box in zip(masks.data, boxes):
    cls = int(box.cls[0])
    class_name = id_to_class.get(cls, "Unknown")
    color = class_colors[cls]

    # Convert mask from float to binary and resize to the image dimensions
    binary_mask = mask.cpu().numpy().astype(np.uint8)
    binary_mask_resized = cv2.resize(binary_mask, (image.shape[1], image.shape[0]))

    # Apply color to the mask and overlay on the canvas
    colored_mask = cv2.merge((binary_mask_resized * color[0],
                              binary_mask_resized * color[1],
                              binary_mask_resized * color[2]))
    mask_canvas = cv2.addWeighted(mask_canvas, 1, colored_mask, 0.5, 0)

    # Add class name text to the image
    x1, y1, x2, y2 = [int(coord) for coord in box.xyxy[0]]
    text_x = max(x1, 0)
    text_y = max(y1 - 10, 0)
    cv2.putText(image, class_name, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2)

# Overlay the segmentation mask on the original image
segmented_image = cv2.addWeighted(image, 0.5, mask_canvas, 0.8, 0)

# Save and display the output image
# cv2.imwrite(output_path, segmented_image)
plt.imshow(cv2.cvtColor(segmented_image, cv2.COLOR_BGR2RGB))
plt.axis("off")
plt.show()

5.保存预测结果为mask

from ultralytics import YOLO
import cv2
import torch
import numpy as np
import os
if __name__ == '__main__':
    model = YOLO(r'D:\bishe\ultralytics-main\runs\segment\train\weights\best.pt')
    img_dir = r'D:\bishe\TROSD\test\rgb'
    os.makedirs(img_dir, exist_ok=True)
    pred_image_dir = r'D:\bishe\TROSD\test\predict_mask'

    # Class names dictionary
    yolo_train_ids = {
        'transparent': 0, 'reflective': 1,
    }
    id_to_class = {v: k for k, v in yolo_train_ids.items()}


    for img_path in os.listdir(img_dir):
        curr_img_path = os.path.join(img_dir, img_path)

        results = model(curr_img_path, save=False, save_crop=False,
                        # imgsz=640, classes=[0], max_det=1, show_labels=False, show_conf=False,
                        imgsz=640, show_labels=False, show_conf=False,
                        show_boxes=False)

        result = results[0]
        # Extract predictions
        predictions = results[0]
        masks = predictions.masks
        boxes = predictions.boxes

        if result.masks is not None and len(result.masks) > 0:
            # 创建一个全零的掩码,尺寸与原始预测掩码相同
            mask_height, mask_width = masks.shape[1:3]
            final_mask = np.zeros((mask_height, mask_width), dtype=np.uint8)
            for mask, box in zip(masks.data, boxes):
                cls = int(box.cls[0])
                class_name = id_to_class.get(cls, "Unknown")  # 可能是transparent 可能是reflective

                # 根据类别为透明或反射,更新掩码
                mask_np = mask.cpu().numpy()  # 转换为numpy数组

                if class_name == 'transparent':
                    final_mask[mask_np > 0] = 1  # 透明区域设置为1
                elif class_name == 'reflective':
                    final_mask[mask_np > 0] = 2  # 反射区域设置为2

        # 保存结果掩码图像
        pred_image_path = os.path.join(pred_image_dir, img_path.replace('jpg', 'png'))
        cv2.imwrite(pred_image_path, final_mask)
        # masks_data = result.masks.data
        # for index, mask in enumerate(masks_data):
        #     mask = mask.cpu().numpy() * 255
        #     cv2.imwrite(pred_image_path, mask)

6.计算miou

由于Yolo官方没有miou的分割指标,我们这里基于上述5.生成的本地的mask和标签的mask计算miou。

import csv
import os
from os.path import join

import matplotlib.pyplot as plt
import numpy as np

from PIL import Image


# 设标签宽W,长H
def fast_hist(a, b, n):
    # --------------------------------------------------------------------------------#
    #   a是转化成一维数组的标签,形状(H×W,);b是转化成一维数组的预测结果,形状(H×W,)
    # --------------------------------------------------------------------------------#
    k = (a >= 0) & (a < n)
    # --------------------------------------------------------------------------------#
    #   np.bincount计算了从0到n**2-1这n**2个数中每个数出现的次数,返回值形状(n, n)
    #   返回中,写对角线上的为分类正确的像素点
    # --------------------------------------------------------------------------------#
    return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)


def per_class_iu(hist):
    return np.diag(hist) / np.maximum((hist.sum(1) + hist.sum(0) - np.diag(hist)), 1)


def per_class_PA_Recall(hist):
    return np.diag(hist) / np.maximum(hist.sum(1), 1)


def per_class_Precision(hist):
    return np.diag(hist) / np.maximum(hist.sum(0), 1)


def per_Accuracy(hist):
    return np.sum(np.diag(hist)) / np.maximum(np.sum(hist), 1)


def compute_mIoU(gt_dir, pred_dir, png_name_list, num_classes, name_classes=None):
    print('Num classes', num_classes)
    # -----------------------------------------#
    #   创建一个全是0的矩阵,是一个混淆矩阵
    # -----------------------------------------#
    hist = np.zeros((num_classes, num_classes))

    # ------------------------------------------------#
    #   获得验证集标签路径列表,方便直接读取
    #   获得验证集图像分割结果路径列表,方便直接读取
    # ------------------------------------------------#
    gt_imgs = [join(gt_dir, x + ".png") for x in png_name_list]
    pred_imgs = [join(pred_dir, x + ".png") for x in png_name_list]

    # ------------------------------------------------#
    #   读取每一个(图片-标签)对
    # ------------------------------------------------#
    for ind in range(len(gt_imgs)):
        # ------------------------------------------------#
        #   读取一张图像分割结果,转化成numpy数组
        # ------------------------------------------------#
        pred = np.array(Image.open(pred_imgs[ind]))
        # ------------------------------------------------#
        #   读取一张对应的标签,转化成numpy数组
        # ------------------------------------------------#
        label = np.array(Image.open(gt_imgs[ind]))

        # 如果图像分割结果与标签的大小不一样,这张图片就不计算
        if len(label.flatten()) != len(pred.flatten()):
            print(
                'Skipping: len(gt) = {:d}, len(pred) = {:d}, {:s}, {:s}'.format(
                    len(label.flatten()), len(pred.flatten()), gt_imgs[ind],
                    pred_imgs[ind]))
            continue

        # ------------------------------------------------#
        #   对一张图片计算21×21的hist矩阵,并累加
        # ------------------------------------------------#
        hist += fast_hist(label.flatten(), pred.flatten(), num_classes)
        # 每计算10张就输出一下目前已计算的图片中所有类别平均的mIoU值
        if name_classes is not None and ind > 0 and ind % 10 == 0:
            print('{:d} / {:d}: mIou-{:0.2f}%; mPA-{:0.2f}%; Accuracy-{:0.2f}%'.format(
                ind,
                len(gt_imgs),
                100 * np.nanmean(per_class_iu(hist)),
                100 * np.nanmean(per_class_PA_Recall(hist)),
                100 * per_Accuracy(hist)
            )
            )
    # ------------------------------------------------#
    #   计算所有验证集图片的逐类别mIoU值
    # ------------------------------------------------#
    IoUs = per_class_iu(hist)
    PA_Recall = per_class_PA_Recall(hist)
    Precision = per_class_Precision(hist)
    # ------------------------------------------------#
    #   逐类别输出一下mIoU值
    # ------------------------------------------------#
    if name_classes is not None:
        for ind_class in range(num_classes):
            print('===>' + name_classes[ind_class] + ':\tIou-' + str(round(IoUs[ind_class] * 100, 2)) \
                  + '; Recall (equal to the PA)-' + str(round(PA_Recall[ind_class] * 100, 2)) + '; Precision-' + str(
                round(Precision[ind_class] * 100, 2)))

    # -----------------------------------------------------------------#
    #   在所有验证集图像上求所有类别平均的mIoU值,计算时忽略NaN值
    # -----------------------------------------------------------------#
    print('===> mIoU: ' + str(round(np.nanmean(IoUs) * 100, 2)) + '; mPA: ' + str(
        round(np.nanmean(PA_Recall) * 100, 2)) + '; Accuracy: ' + str(round(per_Accuracy(hist) * 100, 2)))
    return np.array(hist, np.int_), IoUs, PA_Recall, Precision


def adjust_axes(r, t, fig, axes):
    bb = t.get_window_extent(renderer=r)
    text_width_inches = bb.width / fig.dpi
    current_fig_width = fig.get_figwidth()
    new_fig_width = current_fig_width + text_width_inches
    propotion = new_fig_width / current_fig_width
    x_lim = axes.get_xlim()
    axes.set_xlim([x_lim[0], x_lim[1] * propotion])


def draw_plot_func(values, name_classes, plot_title, x_label, output_path, tick_font_size=12, plt_show=True):
    fig = plt.gcf()
    axes = plt.gca()
    plt.barh(range(len(values)), values, color='royalblue')
    plt.title(plot_title, fontsize=tick_font_size + 2)
    plt.xlabel(x_label, fontsize=tick_font_size)
    plt.yticks(range(len(values)), name_classes, fontsize=tick_font_size)
    r = fig.canvas.get_renderer()
    for i, val in enumerate(values):
        str_val = " " + str(val)
        if val < 1.0:
            str_val = " {0:.2f}".format(val)
        t = plt.text(val, i, str_val, color='royalblue', va='center', fontweight='bold')
        if i == (len(values) - 1):
            adjust_axes(r, t, fig, axes)

    fig.tight_layout()
    fig.savefig(output_path)
    if plt_show:
        plt.show()
    plt.close()


def show_results(miou_out_path, hist, IoUs, PA_Recall, Precision, name_classes, tick_font_size=12):
    draw_plot_func(IoUs, name_classes, "mIoU = {0:.2f}%".format(np.nanmean(IoUs) * 100), "Intersection over Union", \
                   os.path.join(miou_out_path, "mIoU.png"), tick_font_size=tick_font_size, plt_show=True)
    print("Save mIoU out to " + os.path.join(miou_out_path, "mIoU.png"))

    draw_plot_func(PA_Recall, name_classes, "mPA = {0:.2f}%".format(np.nanmean(PA_Recall) * 100), "Pixel Accuracy", \
                   os.path.join(miou_out_path, "mPA.png"), tick_font_size=tick_font_size, plt_show=False)
    print("Save mPA out to " + os.path.join(miou_out_path, "mPA.png"))

    draw_plot_func(PA_Recall, name_classes, "mRecall = {0:.2f}%".format(np.nanmean(PA_Recall) * 100), "Recall", \
                   os.path.join(miou_out_path, "Recall.png"), tick_font_size=tick_font_size, plt_show=False)
    print("Save Recall out to " + os.path.join(miou_out_path, "Recall.png"))

    draw_plot_func(Precision, name_classes, "mPrecision = {0:.2f}%".format(np.nanmean(Precision) * 100), "Precision", \
                   os.path.join(miou_out_path, "Precision.png"), tick_font_size=tick_font_size, plt_show=False)
    print("Save Precision out to " + os.path.join(miou_out_path, "Precision.png"))

    with open(os.path.join(miou_out_path, "confusion_matrix.csv"), 'w', newline='') as f:
        writer = csv.writer(f)
        writer_list = []
        writer_list.append([' '] + [str(c) for c in name_classes])
        for i in range(len(hist)):
            writer_list.append([name_classes[i]] + [str(x) for x in hist[i]])
        writer.writerows(writer_list)
    print("Save confusion_matrix out to " + os.path.join(miou_out_path, "confusion_matrix.csv"))


if __name__ == "__main__":
    # --------------------------------------------#
    #   区分的种类,和json_to_dataset里面的一样
    # --------------------------------------------#
    name_classes = ["transparent", "reflective"]

    name_classes = ["background"] + name_classes
    num_classes = len(name_classes)
    # -------------------------------------------------------#
    #   指向VOC数据集所在的文件夹
    #   默认指向根目录下的VOC数据集
    # -------------------------------------------------------#
    gt_dir = r"D:\bishe\TROSD\test\mask"
    image_ids = [os.path.splitext(file)[0] for file in os.listdir(gt_dir)]
    miou_out_path = "miou_out"
    os.makedirs(os.path.join(os.path.curdir, miou_out_path), exist_ok=True)
    pred_dir = r"D:\bishe\TROSD\test\predict_mask"
    print("Get miou.")
    hist, IoUs, PA_Recall, Precision = compute_mIoU(gt_dir, pred_dir, image_ids, num_classes,
                                                    name_classes)  # 执行计算mIoU的函数
    print("Get miou done.")
    show_results(miou_out_path, hist, IoUs, PA_Recall, Precision, name_classes)

在这里插入图片描述

6.相关功能代码扩充

6.1 json文件转mask的.PNG

import os
import numpy as np
from PIL import Image, ImageDraw
# 图像和文本文件所在的目录
data_dir = r"D:\split_data\split_data\spike"

# 创建保存掩码图像的目录
# 保存二值mask
mask_dir = os.path.join(data_dir, "masks")
#保存填充的mask
mask_dir1 = os.path.join(data_dir, "maskss")
os.makedirs(mask_dir, exist_ok=True)
os.makedirs(mask_dir1, exist_ok=True)
# 读取所有图像文件和对应的文本文件
image_files = sorted([f for f in os.listdir(os.path.join(data_dir, "images")) if f.endswith(".jpg")])
text_files = sorted([f for f in os.listdir(os.path.join(data_dir, "txt")) if f.endswith(".txt")])

# 遍历每个图像和对应的文本文件
for image_file, text_file in zip(image_files, text_files):
    # 读取图像和文本文件
    image_path = os.path.join(data_dir, "images", image_file)
    text_path = os.path.join(data_dir, "txt", text_file)
    image = Image.open(image_path)

    # 创建与原始图像大小相同的掩码图像
    mask = Image.new('L', (image.width, image.height), 0)
    draw = ImageDraw.Draw(mask)

    # 打开文本文件并遍历每一行
    with open(text_path, "r") as f:
        for line in f:
            # 解析文本文件中的类别和多点坐标信息
            parts = line.strip().split()
            category = int(parts[0])  # 类别信息
            points = list(map(float, parts[1:]))  # 坐标信息

            # 将多点坐标信息转换为图像上的绝对坐标
            xy = [(int(points[i] * image.width), int(points[i+1] * image.height)) for i in range(0,  len(points), 2)]

            # 绘制多边形区域,使用类别信息作为填充色
            draw.polygon(xy, outline='white', fill='white')

            # # 计算矩形的左上角和右下角坐标
            # x1, y1 = int((points[0]-points[2]/2) * image.width), int((points[1]-points[3]/2) * image.height)  # 左上角坐标
            # x2, y2 = int((points[0]+points[2]/2) * image.width), int((points[1]+points[3]/2) * image.height)  # 右下角坐标
            #
            # # 绘制矩形区域,使用类别信息作为填充色
            # draw.rectangle([x1, y1, x2, y2], outline='white', fill='white')

    # 保存掩码图像
    mask_path = os.path.join(mask_dir, os.path.splitext(image_file)[0] + ".png")
    mask.save(mask_path)
    mask = Image.open(mask_path)

    # 读取要填充的图像
    fill_image_path = image_path
    fill_image = Image.open(fill_image_path)

    # 在多边形区域内填充图像
    filled_image = Image.new('RGB', (image.width, image.height))
    filled_image.paste(image, mask=mask)  # 将原始图像放置在掩码图像内
    filled_image.paste(fill_image, mask=mask)  # 在掩码图像内填充指定的图像

    # 保存填充后的图像
    filled_image_path = os.path.join(mask_dir1, os.path.splitext(image_file)[0] + ".png")
    filled_image.save(filled_image_path)

这个是原始RGB图:
在这里插入图片描述

这个就是mask:
在这里插入图片描述
这个是填充了的mask:
在这里插入图片描述

6.2 实例分割数据增强

import Augmentor

# 确定原始图像存储路径以及掩码文件存储路径,需要把“\”改成“/”
p = Augmentor.Pipeline(r"D:\jiedan\split_data\split_data\exc_solder\png")
p.ground_truth(r"D:\jiedan\split_data\split_data\exc_solder\masks")

# 图像旋转: 按照概率0.8执行,范围在0-25之间
p.rotate(probability=0.8, max_left_rotation=25, max_right_rotation=25)

# 图像左右互换: 按照概率0.5执行
p.flip_left_right(probability=0.5)
p.flip_top_bottom(probability=0.5)

# 图像放大缩小: 按照概率0.8执行,面积为原始图0.85倍
p.zoom_random(probability=0.3, percentage_area=0.85)

# scale_factor表示缩放比例,只能大于1,且为等比放大。
p.scale(probability=1, scale_factor=1.3)

# 小块变形
p.random_distortion(probability=0.8, grid_width=10, grid_height=10, magnitude=20)

# 随机亮度增强/减弱,min_factor, max_factor为变化因子,决定亮度变化的程度,可根据效果指定
p.random_brightness(probability=1, min_factor=0.7, max_factor=1.2)

# 随机颜色/对比度增强/减弱
# p.random_color(probability=1, min_factor=0.0, max_factor=1)
p.random_contrast(probability=1, min_factor=0.7, max_factor=1.2)

# 随机剪切(shear)  max_shear_left,max_shear_right为剪切变换角度  范围0-25
p.shear(probability=1, max_shear_left=10, max_shear_right=10)

# 随机裁剪(random_crop)
p.crop_random(probability=1, percentage_area=0.8, randomise_percentage_area=True)

# 随机翻转(flip_random)
p.flip_random(probability=1)

# 最终扩充的数据样本数可以更换为1001000等
p.sample(1900)
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Andrew_Xzw

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值