zoukankan      html  css  js  c++  java
  • 深度学习中的图像增强

    转自:https://www.cnblogs.com/gongxijun/p/6117588.html?utm_source=itdadao&utm_medium=referral

      1 # -*- coding:utf-8 -*-
      2 """数据增强
      3    1. 翻转变换 flip
      4    2. 随机修剪 random crop
      5    3. 色彩抖动 color jittering
      6    4. 平移变换 shift
      7    5. 尺度变换 scale
      8    6. 对比度变换 contrast
      9    7. 噪声扰动 noise
     10    8. 旋转变换/反射变换 Rotation/reflection
     11    author: XiJun.Gong
     12    date:2016-11-29
     13 """
     14 
     15 from PIL import Image, ImageEnhance, ImageOps, ImageFile
     16 import numpy as np
     17 import random
     18 import threading, os, time
     19 import logging
     20 
     21 logger = logging.getLogger(__name__)
     22 ImageFile.LOAD_TRUNCATED_IMAGES = True
     23 
     24 
     25 class DataAugmentation:
     26     """
     27     包含数据增强的八种方式
     28     """
     29 
     30 
     31     def __init__(self):
     32         pass
     33 
     34     @staticmethod
     35     def openImage(image):
     36         return Image.open(image, mode="r")
     37 
     38     @staticmethod
     39     def randomRotation(image, mode=Image.BICUBIC):
     40         """
     41          对图像进行随机任意角度(0~360度)旋转
     42         :param mode 邻近插值,双线性插值,双三次B样条插值(default)
     43         :param image PIL的图像image
     44         :return: 旋转转之后的图像
     45         """
     46         random_angle = np.random.randint(1, 360)
     47         return image.rotate(random_angle, mode)
     48 
     49     @staticmethod
     50     def randomCrop(image):
     51         """
     52         对图像随意剪切,考虑到图像大小范围(68,68),使用一个一个大于(36*36)的窗口进行截图
     53         :param image: PIL的图像image
     54         :return: 剪切之后的图像
     55 
     56         """
     57         image_width = image.size[0]
     58         image_height = image.size[1]
     59         crop_win_size = np.random.randint(40, 68)
     60         random_region = (
     61             (image_width - crop_win_size) >> 1, (image_height - crop_win_size) >> 1, (image_width + crop_win_size) >> 1,
     62             (image_height + crop_win_size) >> 1)
     63         return image.crop(random_region)
     64 
     65     @staticmethod
     66     def randomColor(image):
     67         """
     68         对图像进行颜色抖动
     69         :param image: PIL的图像image
     70         :return: 有颜色色差的图像image
     71         """
     72         random_factor = np.random.randint(0, 31) / 10.  # 随机因子
     73         color_image = ImageEnhance.Color(image).enhance(random_factor)  # 调整图像的饱和度
     74         random_factor = np.random.randint(10, 21) / 10.  # 随机因子
     75         brightness_image = ImageEnhance.Brightness(color_image).enhance(random_factor)  # 调整图像的亮度
     76         random_factor = np.random.randint(10, 21) / 10.  # 随机因1子
     77         contrast_image = ImageEnhance.Contrast(brightness_image).enhance(random_factor)  # 调整图像对比度
     78         random_factor = np.random.randint(0, 31) / 10.  # 随机因子
     79         return ImageEnhance.Sharpness(contrast_image).enhance(random_factor)  # 调整图像锐度
     80 
     81     @staticmethod
     82     def randomGaussian(image, mean=0.2, sigma=0.3):
     83         """
     84          对图像进行高斯噪声处理
     85         :param image:
     86         :return:
     87         """
     88 
     89         def gaussianNoisy(im, mean=0.2, sigma=0.3):
     90             """
     91             对图像做高斯噪音处理
     92             :param im: 单通道图像
     93             :param mean: 偏移量
     94             :param sigma: 标准差
     95             :return:
     96             """
     97             for _i in range(len(im)):
     98                 im[_i] += random.gauss(mean, sigma)
     99             return im
    100 
    101         # 将图像转化成数组
    102         img = np.asarray(image)
    103         img.flags.writeable = True  # 将数组改为读写模式
    104         width, height = img.shape[:2]
    105         img_r = gaussianNoisy(img[:, :, 0].flatten(), mean, sigma)
    106         img_g = gaussianNoisy(img[:, :, 1].flatten(), mean, sigma)
    107         img_b = gaussianNoisy(img[:, :, 2].flatten(), mean, sigma)
    108         img[:, :, 0] = img_r.reshape([width, height])
    109         img[:, :, 1] = img_g.reshape([width, height])
    110         img[:, :, 2] = img_b.reshape([width, height])
    111         return Image.fromarray(np.uint8(img))
    112 
    113     @staticmethod
    114     def saveImage(image, path):
    115         image.save(path)
    116 
    117 
    118 def makeDir(path):
    119     try:
    120         if not os.path.exists(path):
    121             if not os.path.isfile(path):
    122                 # os.mkdir(path)
    123                 os.makedirs(path)
    124             return 0
    125         else:
    126             return 1
    127     except Exception, e:
    128         print str(e)
    129         return -2
    130 
    131 
    132 def imageOps(func_name, image, des_path, file_name, times=5):
    133     funcMap = {"randomRotation": DataAugmentation.randomRotation,
    134                "randomCrop": DataAugmentation.randomCrop,
    135                "randomColor": DataAugmentation.randomColor,
    136                "randomGaussian": DataAugmentation.randomGaussian
    137                }
    138     if funcMap.get(func_name) is None:
    139         logger.error("%s is not exist", func_name)
    140         return -1
    141 
    142     for _i in range(0, times, 1):
    143         new_image = funcMap[func_name](image)
    144         DataAugmentation.saveImage(new_image, os.path.join(des_path, func_name + str(_i) + file_name))
    145 
    146 
    147 opsList = {"randomRotation", "randomCrop", "randomColor", "randomGaussian"}
    148 
    149 
    150 def threadOPS(path, new_path):
    151     """
    152     多线程处理事务
    153     :param src_path: 资源文件
    154     :param des_path: 目的地文件
    155     :return:
    156     """
    157     if os.path.isdir(path):
    158         img_names = os.listdir(path)
    159     else:
    160         img_names = [path]
    161     for img_name in img_names:
    162         print img_name
    163         tmp_img_name = os.path.join(path, img_name)
    164         if os.path.isdir(tmp_img_name):
    165             if makeDir(os.path.join(new_path, img_name)) != -1:
    166                 threadOPS(tmp_img_name, os.path.join(new_path, img_name))
    167             else:
    168                 print 'create new dir failure'
    169                 return -1
    170                 # os.removedirs(tmp_img_name)
    171         elif tmp_img_name.split('.')[1] != "DS_Store":
    172             # 读取文件并进行操作
    173             image = DataAugmentation.openImage(tmp_img_name)
    174             threadImage = [0] * 5
    175             _index = 0
    176             for ops_name in opsList:
    177                 threadImage[_index] = threading.Thread(target=imageOps,
    178                                                        args=(ops_name, image, new_path, img_name,))
    179                 threadImage[_index].start()
    180                 _index += 1
    181                 time.sleep(0.2)
    182 
    183 
    184 if __name__ == '__main__':
    185     threadOPS("/home/pic-image/train/12306train",
    186               "/home/pic-image/train/12306train3")
  • 相关阅读:
    jdbc连接Sql server数据库,并查询数据
    HttpClient,post请求,发送json,并接收数据
    SQL SERVER存储过程一
    HttpClient,get请求,发送并接收数据
    工作中操作数据库实例
    存储过程的实例(公司)
    eclipse发布项目后,项目所在的位置
    SQLSERVER存储过程基本语法
    SAXReader解析
    导包
  • 原文地址:https://www.cnblogs.com/abella/p/10315700.html
Copyright © 2011-2022 走看看