首页 > 代码库 > 深度学习之图像的数据增强

深度学习之图像的数据增强

   在图像的深度学习中,为了丰富图像训练集,更好地提取图像特征,提高模型的泛化能力(防止模型过拟合),一般都会对图像数据进行数据增强.

数据增强常用的方式有:旋转图像,裁剪图像,改变图像色差,扭曲图像特征,改变图像尺寸大小,增加图像噪声(一般使用高斯噪声,椒盐噪声)等.

但是需要注意,不要加入其他图像轮廓的噪音.

  对于常用的图像的数据增强的实现,如下:

技术分享
  # -*- coding:utf-8 -*-
"""Image data augmentation.

Augmentation techniques listed by the original author:
   1. Flip
   2. Random crop
   3. Color jittering
   4. Shift
   5. Scale
   6. Contrast
   7. Noise
   8. Rotation / reflection

Only rotation, cropping, color jittering and Gaussian noise are implemented
in this module.

   author: XiJun.Gong
   date:2016-11-29
"""

import logging
import os
import random
import threading
import time

import numpy as np
from PIL import Image, ImageEnhance, ImageFile, ImageOps

logger = logging.getLogger(__name__)
# Let Pillow decode truncated files instead of raising while loading them.
ImageFile.LOAD_TRUNCATED_IMAGES = True


class DataAugmentation(object):
    """Static helpers that augment PIL images.

    Provides random rotation, a centred crop with a random window size,
    color jittering and additive Gaussian noise, plus thin open/save
    wrappers.
    """

    def __init__(self):
        pass

    @staticmethod
    def openImage(image):
        """Open an image for reading.

        :param image: file name or file object accepted by ``Image.open``
        :return: the opened PIL image
        """
        return Image.open(image, mode="r")

    @staticmethod
    def randomRotation(image, mode=Image.BICUBIC):
        """Rotate the image by a random whole angle between 1 and 359 degrees.

        :param image: PIL image
        :param mode: resampling filter passed to ``Image.rotate`` (nearest,
            bilinear or bicubic; bicubic by default)
        :return: the rotated image
        """
        random_angle = np.random.randint(1, 360)  # upper bound is exclusive
        return image.rotate(random_angle, mode)

    @staticmethod
    def randomCrop(image, min_size=40, max_size=68):
        """Cut a centred square window whose side length is drawn at random.

        Only the window size is random; the window is always centred on the
        image. The default bounds are the values the original code
        hard-coded for its 68x68 inputs.

        :param image: PIL image
        :param min_size: smallest window side length (inclusive)
        :param max_size: upper bound of the window side length (exclusive)
        :return: the cropped image
        """
        image_width, image_height = image.size
        crop_win_size = np.random.randint(min_size, max_size)
        centred_region = (
            (image_width - crop_win_size) // 2,
            (image_height - crop_win_size) // 2,
            (image_width + crop_win_size) // 2,
            (image_height + crop_win_size) // 2,
        )
        return image.crop(centred_region)

    @staticmethod
    def randomColor(image):
        """Jitter saturation, brightness, contrast and sharpness in turn.

        :param image: PIL image
        :return: the color-jittered image
        """
        random_factor = np.random.randint(0, 31) / 10.  # 0.0 .. 3.0
        color_image = ImageEnhance.Color(image).enhance(random_factor)  # saturation
        random_factor = np.random.randint(10, 21) / 10.  # 1.0 .. 2.0
        brightness_image = ImageEnhance.Brightness(color_image).enhance(random_factor)  # brightness
        random_factor = np.random.randint(10, 21) / 10.  # 1.0 .. 2.0
        contrast_image = ImageEnhance.Contrast(brightness_image).enhance(random_factor)  # contrast
        random_factor = np.random.randint(0, 31) / 10.  # 0.0 .. 3.0
        return ImageEnhance.Sharpness(contrast_image).enhance(random_factor)  # sharpness

    @staticmethod
    def randomGaussian(image, mean=0.2, sigma=0.3):
        """Add Gaussian noise to the image.

        Noise is added to a floating-point copy of the pixels and the result
        is rounded and clipped to 0..255, so values saturate instead of
        wrapping around when cast back to ``uint8``. For images with a
        channel axis only the first three channels receive noise; any
        further channel (e.g. alpha) is left untouched.

        Noise is expressed on the 0..255 pixel scale, so the defaults yield
        a very subtle perturbation; pass a larger ``sigma`` for visible
        noise.

        :param image: PIL image
        :param mean: mean (offset) of the noise
        :param sigma: standard deviation of the noise
        :return: a new PIL image with noise applied
        """
        # np.array copies, so the result is always writable; the array
        # returned by np.asarray on a PIL image may be read-only.
        pixels = np.array(image, dtype=np.float64)
        if pixels.ndim == 2:
            # Single-channel image: perturb every pixel.
            pixels += np.random.normal(mean, sigma, pixels.shape)
        else:
            noisy_channels = pixels[:, :, :3]  # view into ``pixels``
            noisy_channels += np.random.normal(mean, sigma, noisy_channels.shape)
        return Image.fromarray(np.clip(np.rint(pixels), 0, 255).astype(np.uint8))

    @staticmethod
    def saveImage(image, path):
        """Save ``image`` to ``path``; the format is chosen by ``Image.save``."""
        image.save(path)


def makeDir(path):
    """Create ``path`` (including missing parents) if it does not exist.

    :param path: directory to create
    :return: 0 if the directory was created, 1 if ``path`` already existed,
        -2 if creation failed (the error is printed)
    """
    if os.path.exists(path):
        return 1
    try:
        os.makedirs(path)
    except OSError as e:
        print(str(e))
        return -2
    return 0


def imageOps(func_name, image, des_path, file_name, times=5):
    """Apply one augmentation ``times`` times and save every result.

    Output files are named ``<func_name><index><file_name>`` inside
    ``des_path``.

    :param func_name: key of the augmentation to run: "randomRotation",
        "randomCrop", "randomColor" or "randomGaussian"
    :param image: source PIL image
    :param des_path: destination directory
    :param file_name: base name of the source image
    :param times: number of augmented copies to produce
    :return: -1 if ``func_name`` is unknown, otherwise None
    """
    funcMap = {"randomRotation": DataAugmentation.randomRotation,
               "randomCrop": DataAugmentation.randomCrop,
               "randomColor": DataAugmentation.randomColor,
               "randomGaussian": DataAugmentation.randomGaussian
               }
    if funcMap.get(func_name) is None:
        logger.error("%s is not exist", func_name)
        return -1

    for _i in range(0, times, 1):
        new_image = funcMap[func_name](image)
        DataAugmentation.saveImage(new_image, os.path.join(des_path, func_name + str(_i) + file_name))


opsList = {"randomRotation", "randomCrop", "randomColor", "randomGaussian"}


def threadOPS(path, new_path):
    """Augment every image under ``path``, one thread per operation.

    Directories are walked recursively and mirrored below ``new_path``.
    Files named ".DS_Store" and files Pillow cannot read are skipped.

    :param path: source image file or directory
    :param new_path: destination directory (created when missing)
    :return: -1 if a destination directory could not be created,
        otherwise None
    """
    # makeDir reports failure with a negative code (-2), so test the sign
    # rather than comparing against one particular value.
    if makeDir(new_path) < 0:
        print("create new dir failure")
        return -1

    if os.path.isdir(path):
        src_dir = path
        img_names = os.listdir(path)
    else:
        # Single file: split it so the joins below rebuild the same path and
        # the output file name is only the base name.
        src_dir, single_name = os.path.split(path)
        img_names = [single_name]

    for img_name in img_names:
        print(img_name)
        tmp_img_name = os.path.join(src_dir, img_name)
        if os.path.isdir(tmp_img_name):
            if threadOPS(tmp_img_name, os.path.join(new_path, img_name)) == -1:
                return -1
            continue
        if img_name == ".DS_Store":
            continue

        try:
            image = DataAugmentation.openImage(tmp_img_name)
            # Image.open is lazy; decode once here so the worker threads
            # share an already-loaded image instead of racing to read it.
            image.load()
        except (IOError, OSError) as e:
            logger.warning("skip unreadable image %s: %s", tmp_img_name, e)
            continue

        threadImage = [
            threading.Thread(target=imageOps,
                             args=(ops_name, image, new_path, img_name,))
            for ops_name in opsList
        ]
        for worker in threadImage:
            worker.start()
        # Wait for this image's workers before moving on; this bounds the
        # number of live threads and replaces the former fixed sleep.
        for worker in threadImage:
            worker.join()


if __name__ == "__main__":
    threadOPS("/home/pic-image/train/12306train",
              "/home/pic-image/train/12306train3")
View Code

 

 

深度学习之图像的数据增强