zoukankan      html  css  js  c++  java
  • data augmentation 总结

    data augmentation 几种方法总结

    在深度学习中,有的时候训练集不够多,或者某一类数据较少,或者为了防止过拟合,让模型更加鲁棒性,data augmentation是一个不错的选择。

    常见方法

    Color Jittering:对颜色的数据增强:图像亮度、饱和度、对比度变化(此处对色彩抖动的理解不知是否得当);

    PCA Jittering:首先按照RGB三个颜色通道计算均值和标准差,再在整个训练集上计算协方差矩阵,进行特征分解,得到特征向量和特征值,用来做PCA Jittering;

    Random Scale:尺度变换;

    Random Crop:采用随机图像差值方式,对图像进行裁剪、缩放;包括Scale Jittering方法(VGG及ResNet模型使用)或者尺度和长宽比增强变换;

    Horizontal/Vertical Flip:水平/垂直翻转;

    Shift:平移变换;

    Rotation/Reflection:旋转/仿射变换;

    Noise:高斯噪声、模糊处理;

    Label shuffle:类别不平衡数据的增广,参见海康威视ILSVRC2016的report;另外,文中提出了一种Supervised Data Augmentation方法,有兴趣的朋友的可以动手实验下。

    部分方法的具体实现

    # -*- coding:utf-8 -*-
    """数据增强
       1. 翻转变换 flip
       2. 随机修剪 random crop
       3. 色彩抖动 color jittering
       4. 平移变换 shift
       5. 尺度变换 scale
       6. 对比度变换 contrast
       7. 噪声扰动 noise
       8. 旋转变换/反射变换 Rotation/reflection
    """
    
    from PIL import Image, ImageEnhance, ImageOps, ImageFile
    import numpy as np
    import random
    import threading, os, time
    import logging
    
    logger = logging.getLogger(__name__)
    ImageFile.LOAD_TRUNCATED_IMAGES = True
    
    
    class DataAugmentation:
        """
        包含数据增强的八种方式
        """
    
    
        def __init__(self):
            pass
    
        @staticmethod
        def openImage(image):
            return Image.open(image, mode="r")
    
        @staticmethod
        def randomRotation(image, mode=Image.BICUBIC):
            """
             对图像进行随机任意角度(0~360度)旋转
            :param mode 邻近插值,双线性插值,双三次B样条插值(default)
            :param image PIL的图像image
            :return: 旋转转之后的图像
            """
            random_angle = np.random.randint(1, 360)
            return image.rotate(random_angle, mode)
    
        @staticmethod
        def randomCrop(image):
            """
            对图像随意剪切,考虑到图像大小范围(68,68),使用一个一个大于(36*36)的窗口进行截图
            :param image: PIL的图像image
            :return: 剪切之后的图像
    
            """
            image_width = image.size[0]
            image_height = image.size[1]
            crop_win_size = np.random.randint(40, 68)
            random_region = (
                (image_width - crop_win_size) >> 1, (image_height - crop_win_size) >> 1, (image_width + crop_win_size) >> 1,
                (image_height + crop_win_size) >> 1)
            return image.crop(random_region)
    
        @staticmethod
        def randomColor(image):
            """
            对图像进行颜色抖动
            :param image: PIL的图像image
            :return: 有颜色色差的图像image
            """
            random_factor = np.random.randint(0, 31) / 10.  # 随机因子
            color_image = ImageEnhance.Color(image).enhance(random_factor)  # 调整图像的饱和度
            random_factor = np.random.randint(10, 21) / 10.  # 随机因子
            brightness_image = ImageEnhance.Brightness(color_image).enhance(random_factor)  # 调整图像的亮度
            random_factor = np.random.randint(10, 21) / 10.  # 随机因1子
            contrast_image = ImageEnhance.Contrast(brightness_image).enhance(random_factor)  # 调整图像对比度
            random_factor = np.random.randint(0, 31) / 10.  # 随机因子
            return ImageEnhance.Sharpness(contrast_image).enhance(random_factor)  # 调整图像锐度
    
        @staticmethod
        def randomGaussian(image, mean=0.2, sigma=0.3):
            """
             对图像进行高斯噪声处理
            :param image:
            :return:
            """
    
            def gaussianNoisy(im, mean=0.2, sigma=0.3):
                """
                对图像做高斯噪音处理
                :param im: 单通道图像
                :param mean: 偏移量
                :param sigma: 标准差
                :return:
                """
                for _i in range(len(im)):
                    im[_i] += random.gauss(mean, sigma)
                return im
    
            # 将图像转化成数组
            img = np.asarray(image)
            img.flags.writeable = True  # 将数组改为读写模式
            width, height = img.shape[:2]
            img_r = gaussianNoisy(img[:, :, 0].flatten(), mean, sigma)
            img_g = gaussianNoisy(img[:, :, 1].flatten(), mean, sigma)
            img_b = gaussianNoisy(img[:, :, 2].flatten(), mean, sigma)
            img[:, :, 0] = img_r.reshape([width, height])
            img[:, :, 1] = img_g.reshape([width, height])
            img[:, :, 2] = img_b.reshape([width, height])
            return Image.fromarray(np.uint8(img))
    
        @staticmethod
        def saveImage(image, path):
            image.save(path)
    
    
    def makeDir(path):
        try:
            if not os.path.exists(path):
                if not os.path.isfile(path):
                    # os.mkdir(path)
                    os.makedirs(path)
                return 0
            else:
                return 1
        except Exception, e:
            print str(e)
            return -2
    
    
    def imageOps(func_name, image, des_path, file_name, times=5):
        funcMap = {"randomRotation": DataAugmentation.randomRotation,
                   "randomCrop": DataAugmentation.randomCrop,
                   "randomColor": DataAugmentation.randomColor,
                   "randomGaussian": DataAugmentation.randomGaussian
                   }
        if funcMap.get(func_name) is None:
            logger.error("%s is not exist", func_name)
            return -1
    
        for _i in range(0, times, 1):
            new_image = funcMap[func_name](image)
            DataAugmentation.saveImage(new_image, os.path.join(des_path, func_name + str(_i) + file_name))
    
    
    opsList = {"randomRotation", "randomCrop", "randomColor", "randomGaussian"}
    
    
    def threadOPS(path, new_path):
        """
        多线程处理事务
        :param src_path: 资源文件
        :param des_path: 目的地文件
        :return:
        """
        if os.path.isdir(path):
            img_names = os.listdir(path)
        else:
            img_names = [path]
        for img_name in img_names:
            print img_name
            tmp_img_name = os.path.join(path, img_name)
            if os.path.isdir(tmp_img_name):
                if makeDir(os.path.join(new_path, img_name)) != -1:
                    threadOPS(tmp_img_name, os.path.join(new_path, img_name))
                else:
                    print 'create new dir failure'
                    return -1
                    # os.removedirs(tmp_img_name)
            elif tmp_img_name.split('.')[1] != "DS_Store":
                # 读取文件并进行操作
                image = DataAugmentation.openImage(tmp_img_name)
                threadImage = [0] * 5
                _index = 0
                for ops_name in opsList:
                    threadImage[_index] = threading.Thread(target=imageOps,
                                                           args=(ops_name, image, new_path, img_name,))
                    threadImage[_index].start()
                    _index += 1
                    time.sleep(0.2)
    
    
    if __name__ == '__main__':
        threadOPS("/home/pic-image/train/12306train",
                  "/home/pic-image/train/12306train3")
    

    参考文献

    深度学习之图像的数据增强
    知乎

  • 相关阅读:
    [Spring框架]Spring 事务管理基础入门总结.
    [Maven]Eclipse插件之Maven配置及问题解析.
    编程哲学的几点感悟
    用ASP.NET Core MVC 和 EF Core 构建Web应用 (十)
    用ASP.NET Core MVC 和 EF Core 构建Web应用 (九)
    C# 中的特性 Attribute
    用ASP.NET Core MVC 和 EF Core 构建Web应用 (八)
    用ASP.NET Core MVC 和 EF Core 构建Web应用 (七)
    C# 调用 WebApi
    用ASP.NET Core MVC 和 EF Core 构建Web应用 (六)
  • 原文地址:https://www.cnblogs.com/zhonghuasong/p/7256498.html
Copyright © 2011-2022 走看看