zoukankan      html  css  js  c++  java
  • Deep Learning -- 数据增强

    数据增强

      在图像的深度学习中,为了丰富图像训练集,更好的提取图像特征,泛化模型(防止模型过拟合),一般都会对数据图像进行数据增强,数据增强,常用的方式,就是旋转图像,剪切图像,改变图像色差,扭曲图像特征,改变图像尺寸大小,增强图像噪音(一般使用高斯噪音)等,但需要注意,不要加入其它图像轮廓的噪音。在不同的任务背景下,我们可以通过图像的几何变换,使用一下一种或者多种组合数据增强变换来增加输入数据的量。
    1. 旋转|反射变换(Rotation/reflection):随机旋转图像一定角度;改变图像的内容朝向;
    2. 翻转变换(flip):沿这水平或者垂直方向翻转图像
    3. 缩放变换(zoom):按照一定的比例放大或者缩小图像
    4. 平移变换(shift):在图像平面上对图像以一定方式进行平移

    数据增强的代码实现

    # -*- coding:utf-8 -*-
    # 数据增强
    # 1.翻转变换flip
    # 2.随机修剪random crop
    # 3.色彩抖动color jittering
    # 4.平移变换shift
    # 5.尺度变换scale
    # 6.对比度变换contrast
    # 7.噪声扰动noise
    # 8.旋转变换/反射变换 Rotation/reflection
    
    from PIL import Image,ImageEnhance,ImageOps,ImageFile
    import numpy as np
    import random
    import threading,os,time
    import logging
    
    logger = logging.getLogger(__name__)
    ImageFile.LOAD_TRUNCATED_IMAGES = True
    
    class DataAugmentation:
        #包含数据增强的八种方式
        def __init__(self):
            pass
    
        @staticmethod
        def openImage(image):
            return Image.open(image,mode="r")
    
        @staticmethod
        def randomRotation(image,mode=Image.BICUBIC):
            # 对图像进行任意0~360度旋转
            # param mode 邻近插值,双线性插值,双三次B样条插值(default)
            # param image PIL的图像image
            # return 旋转之后的图像
            random_angle = np.random.randint(1,360)
            return image.rotate(random_angle,mode)
    
        @staticmethod
        def randomCrop(image):
            #对图像随意剪切,考虑到图像大小范围(68*68),使用一个一个大于(36*36)的窗口进行截图
            #param image:PIL的图像image
            #return:剪切之后的图像
            image_width = image.size[0]
            image_height = image.size[1]
            crop_win_size = np.random.randint(40,68)
            random_region = ((image_width - crop_win_size ) >> 1 , (image_height - crop_win_size) >> 1 ,(image_width + crop_win_size) >> 1 , (image_height + crop_win_size) >> 1)
            return image.crop(random_region)
    
        @staticmethod
        def randomColor(image):
            #对图像进行颜色抖动
            #param image:PIL的图像image
            #return:有颜色色差的图像image
    
            #随机因子
            random_factor = np.random.randint(0, 31) / 10.
            #调整图像的饱和度
            color_image = ImageEnhance.Color(image).enhance(random_factor)
            #随机因子
            random_factor = np.random.randint(10,21) / 10.
            #调整图像的亮度
            brightness_image = ImageEnhance.Brightness(color_image).enhance(random_factor)
            #随机因子
            random_factor = np.random.randint(10,21) / 10.
            #调整图像的对比度
            contrast_image = ImageEnhance.Contrast(brightness_image).enhance(random_factor)
            #随机因子
            random_factor = np.random.randint(0,31) / 10.
            #调整图像锐度
            sharpness_image = ImageEnhance.Sharpness(contrast_image).enhance(random_factor)
            return sharpness_image
    
        @staticmethod
        def randomGaussian(image,mean=0.2,sigma=0.3):
            #对图像进行高斯噪声处理
            #param image:
            #return
    
            def gaussianNoisy(im,mean=0.2,sigma=0.3):
                #对图像做高斯噪音处理
                # param im:单通道图像
                # param mean:偏移量
                # param sigma:标准差
                #return:
                for _i in range(len(im)):
                    im[_i] += random.gauss(mean,sigma)
                return im
    
            #将图像转化为数组
            img = np.asanyarray(image)
            #将数组改为读写模式
            img.flags.writeable = True
            width,height = img.shape[:2]
            #对image的R,G,B三个通道进行分别处理
            img_r = gaussianNoisy(img[:,:,0].flatten(), mean, sigma)
            img_g = gaussianNoisy(img[:,:,1].flatten(), mean, sigma)
            img_b = gaussianNoisy(img[:,:,2].flatten(), mean, sigma)
            img[:,:,0] = img_r.reshape([width,height])
            img[:,:,1] = img_g.reshape([width,height])
            img[:,:,2] = img_b.reshape([width,height])
            return Image.fromarray(np.uint8(img))
    
        @staticmethod
        def saveImage(image,path):
            image.save(path)
    
    def makeDir(path):
        try:
            if not os.path.exists(path):
                if not os.path.isfile(path):
                    os.makdirs(path)
                return 0
            else:
                return 1
        except Exception, e:
            print str(e)
            return -1
    
    def imageOps(func_name, image, des_path, file_name, times = 5):
        funcMap = {"randomRotation": DataAugmentation.randomRotation,
                    "randomCrop":DataAugmentation.randomCrop,
                    "randomColor":DataAugmentation.randomColor,
                    "randomGaussian":DataAugmentation.randomGaussian
                    }
        if funcMap.get(func_name) is None:
            logger.error("%s is not exist" , func_name)
            return -1
    
        for _i in range(0,times,1):
            new_image = funcMap[func_name](image)
            DataAugmentation.saveImage(new_image,os.path.join(des_path,func_name + str(_i) + file_name))
    
    opsList = {"randomRotation", "randomCrop", "randomColor", "randomGaussian"}
    
    def threadOPS(path,new_path):
        #多线程处理事务
        #param src_path:资源文件
        #param des_path:目的地文件
        #return:
    
        if os.path.isdir(path):
            img_names = os.listdir(path)
        else:
            img_names = [path]
        for img_name in img_names:
            print img_name
            tmp_img_name = os.path.join(path,img_name)
            print tmp_img_name
            if os.path.isdir(tmp_img_name):
                if makeDir(os.path.join(new_path,img_name)) != -1:
                    threadOPS(tmp_img_name,os.path.join(new_path,img_name))
                else:
                    print 'create new dir failure'
                    return -1
            elif tmp_img_name.split('.')[1] != "DS_Store":
                image = DataAugmentation.openImage(tmp_img_name)
                threadImage = [0] * 5
                _index = 0
                for ops_name in opsList:
                    threadImage[_index] = threading.Thread(target=imageOps,args=(ops_name,image,new_path,img_name))
                    threadImage[_index].start()
                    _index += 1
                    time.sleep(0.2)
    
    if __name__ == '__main__':
        threadOPS("C:UsersAcheronPycharmProjectsCNNpic-image\trainimages","C:UsersAcheronPycharmProjectsCNNpic-image\train\newimages")

    数据增强实验

    原始的待进行数据增强的图像:

    1.对图像进行颜色抖动

     2.对图像进行高斯噪声处理

  • 相关阅读:
    水木→函数式编程语言→lisp是不是主要用来编网站的?
    OpenMP 维基百科,自由的百科全书
    一个实际的Lisp项目开发心得 albert_lee的产品技术空间 博客频道 CSDN.NET
    ...
    OpenMPI
    Debian下安装NetBeans
    Linux Socket学习(十七)
    Linux Socket学习(十四)
    Debian下安装Latex
    Debian下安装virtualbox
  • 原文地址:https://www.cnblogs.com/fangpengchengbupter/p/7649627.html
Copyright © 2011-2022 走看看