Image data augmentation methods fall roughly into the following categories:
- Affine transforms (Random crop, Random flip, Random rotation, Random zoom, Random shear, Random translation…)
- Color distortion (Random gamma, Random brightness, Random hue, Random contrast, Gaussian noise…)
- Information dropping (GridMask, Cutout, Random Erasing, Hide-and-Seek…)
- Multi-image mixing (Mixup, CutMix, FMix…)
- Others (AugMix…)
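As a concrete example of the multi-image mixing category, here is a minimal Mixup sketch in TensorFlow 2. It assumes the same (img_batch, label_batch) signature with one-hot float32 labels that the color augmentation functions in this post use; the names mixup and sample_beta and the alpha=0.2 default are my own illustrative choices. TensorFlow has no tf.random.beta, so the Beta(alpha, alpha) weight is built from two Gamma draws. Here one mixing weight is drawn per sample; drawing a single weight per batch is also common.

```python
import tensorflow as tf

def sample_beta(n, alpha):
    # Beta(alpha, alpha) sample via two Gamma draws: g1 / (g1 + g2)
    g1 = tf.random.gamma([n], alpha)
    g2 = tf.random.gamma([n], alpha)
    return g1 / (g1 + g2)

def mixup(img_batch, label_batch, alpha=0.2):
    # img_batch: [N, H, W, C] float32, label_batch: [N, num_classes] one-hot float32
    n = tf.shape(img_batch)[0]
    lam = sample_beta(n, alpha)                # one mixing weight per sample
    perm = tf.random.shuffle(tf.range(n))      # partner index for each sample
    lam_img = tf.reshape(lam, [-1, 1, 1, 1])   # broadcast over H, W, C
    lam_lbl = tf.reshape(lam, [-1, 1])         # broadcast over classes
    mixed_img = lam_img * img_batch + (1.0 - lam_img) * tf.gather(img_batch, perm)
    mixed_lbl = lam_lbl * label_batch + (1.0 - lam_lbl) * tf.gather(label_batch, perm)
    return mixed_img, mixed_lbl
```

Because the labels are mixed with the same weights as the images, each mixed label still sums to 1 and can be used directly with a categorical cross-entropy loss.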
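Here I have collected most of the currently popular augmentations and implemented them in TensorFlow 2. Augmentations built from TensorFlow ops can run on the GPU/TPU during training, which is usually much faster than a NumPy implementation running on the CPU. (Ops inside a tf.data map() are still executed on the host CPU; they only reach the accelerator when applied inside the model or the training step.)
Note that I have not yet gotten the TensorFlow versions of some augmentations (FMix, AugMix) to work; I will upload them once they run.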
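To illustrate the accelerator point above, here is a minimal sketch of applying an augmentation inside a compiled training step rather than in a tf.data pipeline, so TensorFlow places the augmentation ops on the GPU when one is visible (a TPU additionally needs a distribution strategy). The tiny model, the optimizer settings, and the train_step name are hypothetical placeholders, only there to make the example self-contained.

```python
import tensorflow as tf

# Hypothetical tiny classifier for 28x28x3 inputs and 10 classes
model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28, 3)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, activation="softmax"),
])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
loss_fn = tf.keras.losses.CategoricalCrossentropy()

@tf.function
def train_step(img_batch, label_batch):
    # Augment inside the compiled step, on the same device as the model
    img_batch = tf.image.random_brightness(img_batch, 0.05)
    img_batch = tf.clip_by_value(img_batch, 0.0, 1.0)
    with tf.GradientTape() as tape:
        pred = model(img_batch, training=True)
        loss = loss_fn(label_batch, pred)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss
```

The trade-off is that augmentation in the training step competes with the model for accelerator time, while augmentation in tf.data overlaps with training on the CPU; which is faster depends on the workload.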
Loading a dataset for testing
import numpy as np
import tensorflow as tf
from tensorflow.keras import *
import tensorflow.keras.backend as B
import matplotlib.pyplot as plt
import math
(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
x_train = np.repeat(np.expand_dims(x_train / 255.0 / 2, -1), 3, axis=-1).astype(np.float32)  # scale to [0, 0.5], expand to 3 channels
x_test = np.repeat(np.expand_dims(x_test / 255.0 / 2, -1), 3, axis=-1).astype(np.float32)  # scale to [0, 0.5], expand to 3 channels
y_train = np.eye(10)[np.reshape(y_train, -1)].astype(np.float32)  # one-hot
y_test = np.eye(10)[np.reshape(y_test, -1)].astype(np.float32)  # one-hot
print(x_train.shape, y_train.shape) # (60000, 28, 28, 3) (60000, 10)
print(x_test.shape, y_test.shape) # (10000, 28, 28, 3) (10000, 10)
def xy_visualization(x, y):
    # show one image and print its one-hot label
    plt.imshow(x)
    plt.axis('off')
    plt.show()
    print(y.numpy())
Color augmentation (random contrast, brightness, hue, saturation, and Gaussian noise)
@tf.function
def random_contrast(img_batch, lower=0.9, upper=1.0):
    # (x - mean) * contrast_factor + mean, contrast_factor drawn between lower and upper
    return tf.image.random_contrast(img_batch, lower, upper)

@tf.function
def random_brightness(img_batch, delta=0.05):
    # x + d, d drawn between -delta and delta
    return tf.image.random_brightness(img_batch, delta)

@tf.function
def random_hue(img_batch, delta=0.01):
    # convert RGB to HSV, add a random offset between -delta and delta to the hue channel, convert back to RGB
    return tf.image.random_hue(img_batch, delta)

@tf.function
def random_saturation(img_batch, lower=0.0, upper=0.01):
    # convert RGB to HSV, multiply the saturation channel by a random factor between lower and upper, convert back to RGB
    # note: a factor near 0 removes almost all color; on the grayscale MNIST images used here it has no visible effect
    return tf.image.random_saturation(img_batch, lower, upper)

@tf.function
def random_gaussian_noise(img_batch, noise_scale=0.01):
    # add zero-mean Gaussian noise with standard deviation noise_scale, then clip to [0, 1]
    gaussian_noise = noise_scale * tf.random.normal(tf.shape(img_batch), mean=0.0, stddev=1.0, dtype=tf.float32)
    return tf.clip_by_value(img_batch + gaussian_noise, 0., 1.)

@tf.function
def color_aug(img_batch, label_batch):
    # img_batch: [N, img_h, img_w, img_channels]
    # label_batch: [N, num_classes]
    # note: each tf.image.random_* call draws one random value shared by all N images in the batch
    img_batch = random_contrast(img_batch)
    img_batch = random_brightness(img_batch)
    img_batch = random_hue(img_batch)
    img_batch = random_saturation(img_batch)
    img_batch = random_gaussian_noise(img_batch)
    return img_batch, label_batch
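The formulas in the comments above can be checked against the deterministic ops that the random versions wrap, tf.image.adjust_contrast and tf.image.adjust_brightness. This small sketch uses a toy random batch (my own test data) and compares each op with the hand-written formula; the contrast mean is taken per image and per channel, over height and width.

```python
import tensorflow as tf

x = tf.random.uniform([2, 4, 4, 3])  # toy batch, values in [0, 1)

# Contrast: (x - mean) * factor + mean, mean over H and W for each image and channel
factor = 0.9
mean = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
manual_contrast = (x - mean) * factor + mean
tf_contrast = tf.image.adjust_contrast(x, factor)

# Brightness: x + delta (no clipping for float inputs)
delta = 0.05
manual_brightness = x + delta
tf_brightness = tf.image.adjust_brightness(x, delta)

contrast_err = float(tf.reduce_max(tf.abs(manual_contrast - tf_contrast)))
brightness_err = float(tf.reduce_max(tf.abs(manual_brightness - tf_brightness)))
print(contrast_err, brightness_err)  # both close to 0
```

Since neither op clips, values can leave [0, 1] after contrast and brightness changes; in color_aug the final clip inside random_gaussian_noise brings them back.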
batch_size = 32
batch = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(20).repeat(1).batch(batch_size)
batch = batch.map(color_aug)  # augment whole batches
x, y = next(iter(batch))  # take the first augmented batch
for i in range(5):
    xy_visualization(x[i], y[i])
Result:
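One thing to keep in mind about color_aug: tf.image.random_contrast, random_brightness, random_hue and random_saturation each draw a single random value per call, so when color_aug is mapped over batches, all 32 images in a batch receive the same color change (only the Gaussian noise differs per pixel). If per-image variation is wanted, one option is to map the augmentation before .batch(); another is to draw the random parameters with a batch dimension, as in this sketch. per_image_brightness and per_image_contrast are hypothetical helpers, not TensorFlow APIs.

```python
import tensorflow as tf

def per_image_brightness(img_batch, max_delta=0.05):
    # one delta per image, shape [N, 1, 1, 1], broadcast over H, W, C
    n = tf.shape(img_batch)[0]
    delta = tf.random.uniform([n, 1, 1, 1], -max_delta, max_delta)
    return tf.clip_by_value(img_batch + delta, 0.0, 1.0)

def per_image_contrast(img_batch, lower=0.9, upper=1.0):
    # one contrast factor per image; mean per image and per channel, as in tf.image.adjust_contrast
    n = tf.shape(img_batch)[0]
    factor = tf.random.uniform([n, 1, 1, 1], lower, upper)
    mean = tf.reduce_mean(img_batch, axis=[1, 2], keepdims=True)
    return (img_batch - mean) * factor + mean
```

Drawing the parameters with shape [N, 1, 1, 1] keeps the whole batch vectorized, which is usually cheaper than applying the per-image op element by element with tf.map_fn.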
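Affine transforms (random zoom, random crop, random flip, random rotation, random shear, random shift)
For the principles behind affine transforms, see: https://www.cnblogs.com/happystudyeveryday/p/10547316.html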
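Affine transforms can be written as 3x3 matrices acting on homogeneous pixel coordinates [x, y, 1], and chaining transforms amounts to multiplying their matrices. A minimal NumPy sketch of that composition follows; affine_matrix and its parameters (degrees for rotation and shear, a scale factor for zoom, pixels for shift) are hypothetical and are not the signature of aug_affine below. Real implementations usually rotate and zoom about the image center, i.e. translate(c) @ M @ translate(-c), and warp an image by sampling each output pixel through the inverse matrix; both steps are omitted here.

```python
import numpy as np

def affine_matrix(rotation_deg=0.0, shear_deg=0.0, zoom=1.0, shift_x=0.0, shift_y=0.0):
    # each transform is a 3x3 matrix acting on homogeneous coordinates [x, y, 1]
    r = np.deg2rad(rotation_deg)
    s = np.deg2rad(shear_deg)
    rotate = np.array([[np.cos(r), -np.sin(r), 0.0],
                       [np.sin(r),  np.cos(r), 0.0],
                       [0.0,        0.0,       1.0]])
    shear = np.array([[1.0, np.tan(s), 0.0],
                      [0.0, 1.0,       0.0],
                      [0.0, 0.0,       1.0]])
    scale = np.array([[zoom, 0.0,  0.0],
                      [0.0,  zoom, 0.0],
                      [0.0,  0.0,  1.0]])
    translate = np.array([[1.0, 0.0, shift_x],
                          [0.0, 1.0, shift_y],
                          [0.0, 0.0, 1.0]])
    # applied right to left: rotate, then shear, then scale, then translate
    return translate @ scale @ shear @ rotate

m = affine_matrix(rotation_deg=90.0, shift_x=2.0)
p = m @ np.array([1.0, 0.0, 1.0])  # map the point (1, 0)
print(np.round(p, 6))  # [2. 1. 1.]
```

Rotating (1, 0) by 90 degrees gives (0, 1), and the shift then moves it to (2, 1); because the order of multiplication matters, swapping translate and rotate would give a different result.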
@tf.function
def aug_affine(img_batch, label_batch, rotation=2, shear=0, zoom=0, shift=16, flip=3