# Multi-class semantic segmentation data augmentation; keep the original label colors.
import os
import random
import glob
import numpy as np
import imgaug as ia
import imgaug.augmenters as iaa
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
from PIL import Image
import random
import datetime
def get_current_time():
    """Return the current local time as a fixed-width string.

    The format is ``YYYYMMDDHHMMSS-ffffff`` (microsecond precision). Every
    field is zero-padded, so two different instants can never produce the
    same string (unpadded concatenation would map e.g. Jan 11 and Nov 1 to
    the same digits) and the strings sort chronologically.

    :return: timestamp string such as ``"20240105093007-004215"``.
    """
    return datetime.datetime.now().strftime("%Y%m%d%H%M%S-%f")
class ImageAugmentation(object):
    """Augment image / segmentation-mask pairs with imgaug and save them.

    For every input pair the unmodified image and mask are first written to
    the output directories under a zero-padded running number. Then
    ``image_num`` randomly chosen augmentations of that pair are written
    under timestamp-based names.
    """

    def __init__(self, image_aug_dir, segmentationClass_aug_dir, image_num):
        """Store the output locations and create them if necessary.

        :param image_aug_dir: directory that receives the output images.
        :param segmentationClass_aug_dir: directory that receives the output masks.
        :param image_num: number of augmented samples generated per input image.
        """
        self.image_aug_dir = image_aug_dir
        self.segmentationClass_aug_dir = segmentationClass_aug_dir
        self.image_start_num = 0  # running count of samples written so far
        self.seed_set()
        # makedirs (unlike mkdir) also creates missing parent directories and
        # does not fail when the directory already exists.
        os.makedirs(self.image_aug_dir, exist_ok=True)
        os.makedirs(self.segmentationClass_aug_dir, exist_ok=True)
        self.img_num = image_num

    def seed_set(self, seed=1):
        """Seed numpy, the ``random`` module and imgaug with the same value."""
        np.random.seed(seed)
        random.seed(seed)
        ia.seed(seed)

    @staticmethod
    def _mask_to_rgb_array(alpha_channel):
        """Build an (h, w, 3) uint8 array holding the mask in channel 0.

        Pixels equal to 1 become 128; every other value is copied unchanged.
        The input array is not modified.
        """
        h, w = alpha_channel.shape
        image_arr = np.zeros((h, w, 3), dtype=np.uint8)
        image_arr[:, :, 0] = np.where(alpha_channel == 1, 128, alpha_channel)
        return image_arr

    def array2p_mode(self, alpha_channel):
        """Convert a 2-D binary mask into a palette-mode ("P") PIL image.

        :param alpha_channel: 2-D array; pixels equal to 1 are treated as
            foreground and written as 128 into the first channel of an
            otherwise black 3-channel image before the conversion to "P".
        :return: the converted ``PIL.Image`` in mode "P".
        """
        img = Image.fromarray(self._mask_to_rgb_array(alpha_channel))
        return img.convert("P")

    def augmentor(self, image):
        """Return the list of candidate augmentation pipelines for ``image``.

        :param image: (height, width, channels) array; only its height and
            width are used, to compute the half-size resize target.
        :return: list of ``iaa.Sequential`` augmenters.
        """
        height, width, _ = image.shape
        # Downscale to half the original size.
        resize = iaa.Sequential([
            iaa.Resize({"height": height // 2, "width": width // 2}),
        ])
        # Horizontal followed by vertical flip (imgaug default probability).
        fliplr_flipud = iaa.Sequential([
            iaa.Fliplr(),
            iaa.Flipud(),
        ])
        # Rotation with an angle range of (-70, 70).
        rotate = iaa.Sequential([
            iaa.Affine(rotate=(-70, 70)),
        ])
        # Translation with a percent range of (0.2, 0.35).
        translate = iaa.Sequential([
            iaa.Affine(translate_percent=(0.2, 0.35)),
        ])
        # Cropping (negative percent); the output size is not restored.
        crop_and_pad = iaa.Sequential([
            iaa.CropAndPad(percent=(-0.25, 0), keep_size=False),
        ])
        # Fixed 45 degree rotation followed by cropping.
        rotate_and_crop = iaa.Sequential([
            iaa.Affine(rotate=45),
            iaa.CropAndPad(percent=(-0.25, 0), keep_size=False),
        ])
        # Gaussian blur (smoothing, not additive noise).
        guassian_blur = iaa.Sequential([
            iaa.GaussianBlur(sigma=(2, 3)),
        ])
        # Combined affine transform: scale + translation + rotation,
        # with constant (0) fill.
        affin_ = iaa.Sequential([
            iaa.Affine(scale=(0.9, 1), translate_percent=(0, 0.1),
                       rotate=(-40, 40), cval=0, mode='constant'),
        ])
        return [
            resize, fliplr_flipud, rotate, translate,
            crop_and_pad, rotate_and_crop, guassian_blur, affin_,
        ]

    def augment_img(self, image_name, segmap_name):
        """Copy one image/mask pair to the output dirs and augment it.

        :param image_name: path of the input image (saved as JPEG).
        :param segmap_name: path of the input segmentation mask (saved as PNG).
        """
        # 1. Load the pair and store the untouched originals under the
        #    running number. The context manager closes the file handles.
        name = f"{self.image_start_num:04d}"
        with Image.open(image_name) as image_pil, \
                Image.open(segmap_name) as segmap_pil:
            image_pil.save(os.path.join(self.image_aug_dir, name + ".jpg"))
            segmap_pil.save(
                os.path.join(self.segmentationClass_aug_dir, name + ".png"))
            image = np.array(image_pil)
            segmap_arr = np.array(segmap_pil)
        self.image_start_num += 1
        segmap = SegmentationMapsOnImage(segmap_arr, shape=image.shape)

        # 2. Build the candidate augmenters for this image size.
        ops = self.augmentor(image)

        # 3. Write ``img_num`` randomly chosen augmentations.
        for _ in range(self.img_num):
            op = random.choice(ops)
            name = get_current_time()
            print(f"当前增强了{self.image_start_num:04d}张数据...")
            image_aug, segmap_aug = op(image=image, segmentation_maps=segmap)
            Image.fromarray(image_aug).save(
                os.path.join(self.image_aug_dir, name + ".jpg"))
            # NOTE(review): every non-zero label is collapsed to a single
            # foreground value here, so augmented masks do not keep the
            # per-class values of the input mask - confirm this is intended
            # for multi-class data.
            labels = segmap_aug.get_arr()
            binary_mask = np.where(labels > 0, 1, labels)
            self.array2p_mode(binary_mask).save(
                os.path.join(self.segmentationClass_aug_dir, name + ".png"))
            self.image_start_num += 1

    def augment_images(self, image_dir, segmap_dir):
        """Augment every image/mask pair found in the two directories.

        Files are paired by position after sorting both listings by name.

        :param image_dir: directory containing the input images.
        :param segmap_dir: directory containing the input masks.
        """
        # os.path.join makes the listing independent of a trailing separator.
        image_names = sorted(glob.glob(os.path.join(image_dir, "*")))
        segmap_names = sorted(glob.glob(os.path.join(segmap_dir, "*")))
        image_num = len(image_names)
        if image_num != len(segmap_names):
            # zip() below stops at the shorter list; make the mismatch
            # visible because position-based pairing may then be wrong.
            print(f"WARNING: found {image_num} images but "
                  f"{len(segmap_names)} segmentation maps; "
                  "surplus files are ignored and pairs may be misaligned.")
        pairs = zip(image_names, segmap_names)
        for count, (image_name, segmap_name) in enumerate(pairs, start=1):
            print("*"*30, f"正在增强第【{count:04d}/{image_num:04d}】张图片...", "*"*30)
            self.augment_img(image_name, segmap_name)
if __name__ == "__main__":
    # Input folders: source images (must be RGB) and their segmentation
    # masks (must be palette-mode "P" images).
    image_dir = "E:/11/36000/tra1/JPEGImages/"
    segmap_dir = "E:/11/36000/tra1/SegmentationClass/"

    # Output folders for the augmented images and masks.
    image_aug_dir = "E:/11/36000/tra12/JPEG_AUG/"
    segmentationClass_aug_dir = "E:/11/36000/tra12/SegmentationClass_AUG/"
    for output_dir in (image_aug_dir, segmentationClass_aug_dir):
        os.makedirs(output_dir, exist_ok=True)

    # Third constructor argument (``image_num``).
    samples_per_image = 12
    image_augmentation = ImageAugmentation(
        image_aug_dir,
        segmentationClass_aug_dir,
        samples_per_image,
    )
    image_augmentation.augment_images(image_dir, segmap_dir)