zoukankan      html  css  js  c++  java
  • 训练识别 数据增强方法

    以下是用于训练识别模型的几种轻微数据增强方法。

    import os
    import cv2
    import numpy as np
    import random
    
    
    def _shift_channel(channel, offset):
        '''
        Return `channel` shifted by `offset` and saturated to [0, 255], as uint8.

        The arithmetic is done in int16 so it neither wraps around in uint8 nor
        trips NumPy 2's stricter casting of `uint8 += np.int64` (NEP 50).
        '''
        shifted = channel.astype(np.int16) + int(offset)
        return np.clip(shifted, 0, 255).astype(np.uint8)


    def colorjitter(img):
        '''
        ### Different Color Jitter ###
        Apply one randomly chosen color jitter to a BGR image.

        img: BGR image (cv2.COLOR_BGR2HSV requires a 3-channel input).
        cj_type: {b: brightness, s: saturation, c: contrast}
          b -- shift the HSV value channel by one of -50/-40/-30/30/40/50
          s -- shift the HSV saturation channel by one of -50/-40/-30/30/40/50
          c -- x * (c / 127 + 1) - c + 10 with c drawn from [40, 100], clipped
        returns: a new uint8 image; `img` itself is not modified.
        '''
        print("==========colorjitter====================")

        list_type = ["b", "s", "c"]
        cj_type = random.choice(list_type)
        if cj_type in ("b", "s"):
            value = np.random.choice(np.array([-50, -40, -30, 30, 40, 50]))
            hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
            h, s, v = cv2.split(hsv)
            # Saturating shift: equivalent to the former "clamp then add/subtract"
            # masks, but without in-place uint8 arithmetic.
            if cj_type == "b":
                v = _shift_channel(v, value)
            else:
                s = _shift_channel(s, value)
            final_hsv = cv2.merge((h, s, v))
            return cv2.cvtColor(final_hsv, cv2.COLOR_HSV2BGR)

        # cj_type == "c": linear contrast stretch with a small brightness lift.
        brightness = 10
        contrast = random.randint(40, 100)
        dummy = np.int16(img)
        dummy = dummy * (contrast / 127 + 1) - contrast + brightness
        dummy = np.clip(dummy, 0, 255)
        return np.uint8(dummy)
    
    
    def noisy(img):
        '''
        ### Adding Noise ###
        Add one randomly chosen kind of noise to an image.

        img: uint8 image of shape (H, W) or (H, W, C); it is not modified.
        noise_type: {gauss: gaussian, sp: salt & pepper}
          gauss -- add N(0, 0.7) noise per element, rounded and saturated to [0, 255]
          sp    -- set about prob/2 of the pixels to black and prob/2 to white (prob = 0.05)
        returns: a new uint8 image of the same shape.
        '''
        print("==========noisy====================")

        list_type = ["gauss", "sp"]
        noise_type = random.choice(list_type)

        if noise_type == "gauss":
            mean = 0
            st = 0.7
            gauss = np.random.normal(mean, st, img.shape)
            # Add in float and saturate once. Casting the raw noise to uint8 first
            # wraps negative samples to ~255 (white speckles) and truncates small
            # positive samples to 0.
            image = np.rint(img.astype(np.float64) + gauss)
            return np.clip(image, 0, 255).astype(np.uint8)

        # noise_type == "sp"
        image = img.copy()
        prob = 0.05
        # Scalars broadcast over any channel count (grayscale, 1-, 2-, 3-channel).
        black, white = 0, 255
        if image.ndim == 3 and image.shape[2] == 4:
            # 4-channel (RGBA/BGRA): keep salt/pepper pixels fully opaque.
            black = np.array([0, 0, 0, 255], dtype='uint8')
            white = np.array([255, 255, 255, 255], dtype='uint8')
        probs = np.random.random(image.shape[:2])
        image[probs < (prob / 2)] = black
        image[probs > 1 - (prob / 2)] = white
        return image
    
    
    def filters(img):
        '''
        ### Filtering ###
        Smooth `img` with one randomly chosen 5x5 filter.

        img: image
        f_type: {blur: blur, gaussian: gaussian, median: median}
          blur     -- 5x5 box filter (cv2.blur)
          gaussian -- 5x5 Gaussian; sigmaX=0 lets OpenCV derive sigma from the size
          median   -- 5x5 median filter
        returns: the filtered image, computed from a copy of `img`.
        '''
        print("==========filters====================")

        ksize = 5
        smoothers = {
            "blur": lambda src: cv2.blur(src, (ksize, ksize)),
            "gaussian": lambda src: cv2.GaussianBlur(src, (ksize, ksize), 0),
            "median": lambda src: cv2.medianBlur(src, ksize),
        }
        chosen = random.choice(["blur", "gaussian", "median"])
        return smoothers[chosen](img.copy())
    
    
    def gaussain_noise(img):
        '''
        Add Gaussian noise with a randomly drawn variance and mean to `img`.

        img: image of any shape, e.g. (H, W) or (H, W, C); converted to uint8 first.
        The variance is drawn from list_var and the mean from list_mean (both via
        random.choice). Noise is sampled per element with np.random.normal.
        returns: a new uint8 image of the same shape, rounded and saturated to [0, 255].
        '''
        print("==========gaussain_noise====================")
        img = img.astype(np.uint8)
        list_var = [0.4, 0.38, 0.22, 20, 26, 2, 3, 4, 5, 6, 12, 14, 16, 17, 9]
        var = random.choice(list_var)
        list_mean = [0, 0.5, 0.08, 0.5, 15, 1, 2, 3, 4, 5, 6, 7, 8]
        mean = random.choice(list_mean)
        sigma = var ** 0.5
        # Using img.shape directly (instead of unpacking h, w, c) also accepts grayscale.
        gauss = np.random.normal(mean, sigma, img.shape)
        # Add in float and saturate. Casting the noise to uint8 and summing two
        # uint8 arrays wrapped around: negative noise became ~255 and bright
        # pixels overflowed to near-black.
        out = np.rint(img.astype(np.float64) + gauss)
        return np.clip(out, 0, 255).astype(np.uint8)
    
    def img_contrast(img):
        '''
        Randomly raise the saturation and value (brightness) of a BGR image in HSV space.

        img: image convertible to uint8 BGR (cv2.COLOR_BGR2HSV requires 3 channels).
        An S offset is drawn with random.randint(min_s, max_s) and a V offset with
        random.randint(min_v, max_v). Both are added and the result is saturated
        to [0, 255].
        returns: a new uint8 BGR image.
        '''
        print("==========img_contrast====================")
        min_s, max_s, min_v, max_v = 0, 25, 0, 30
        img = img.astype(np.uint8)
        # int16 gives headroom so offsets cannot wrap uint8 values around
        # (in uint8, V=250 plus 30 became 24 and turned highlights dark).
        hsv_img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV).astype(np.int16)
        _s = random.randint(min_s, max_s)
        _v = random.randint(min_v, max_v)
        # Offsets are signed, so one addition covers both directions (the old
        # negative-V branch added instead of subtracting).
        hsv_img[:, :, 1] += _s
        hsv_img[:, :, 2] += _v
        # Hue (channel 0) is not modified, so clipping all channels only affects S and V.
        hsv_img = np.clip(hsv_img, 0, 255).astype(np.uint8)
        return cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR)
    
    def rotate_func(img):
        '''
        Rotate `img` about its centre by a small random angle.

        Like PIL, the angle is in degrees, not radians; in OpenCV a positive angle
        means counter-clockwise. The output keeps the input size, and uncovered
        pixels are filled with white (255, 255, 255).
        '''
        print("==========rotate_func====================")
        # Candidate angles in degrees. 0.8 and -0.8 appear twice, which makes them
        # twice as likely to be picked.
        angles = [0.4, -0.4, 0.25, -0.25, 0.3, -0.3, 0.45, -0.45, 0.6, -0.6,
                  0.7, -0.7, 0.8, -0.8, 0.9, -0.9, 0.168, -0.852, 0.5, -0.5,
                  1, -1, 0.8, -0.8, 1.5, -1.5, 1.2, -1.2, 1.4, -1.4]
        angle = random.choice(angles)
        height, width = img.shape[0], img.shape[1]
        rot_mat = cv2.getRotationMatrix2D((width / 2, height / 2), angle, 1)
        return cv2.warpAffine(img, rot_mat, (width, height),
                              borderValue=(255, 255, 255))
    
    
    
    dir_img = "/data_1/everyday/0507/123/"


    def preview_augmentations(img_dir):
        '''
        Interactively preview random augmentations for every image in `img_dir`.

        For each readable file, a randomly chosen augmentation is shown next to
        the source image. Any key shows another random augmentation of the same
        image, 'n' advances to the next image, and 'q' or Esc quits.
        '''
        fun_apply_list = [colorjitter, noisy, filters, gaussain_noise, img_contrast, rotate_func]
        for img_name in sorted(os.listdir(img_dir)):
            path_img = os.path.join(img_dir, img_name)
            img = cv2.imread(path_img)
            if img is None:
                # cv2.imread returns None (rather than raising) for unreadable or non-image files.
                print("skip unreadable file:", path_img)
                continue

            while True:
                fun_apply = random.choice(fun_apply_list)
                img_aug = fun_apply(img)

                cv2.imshow("img_src", img)
                cv2.imshow("img_aug", img_aug)
                key = cv2.waitKey(0) & 0xFF
                if key in (ord("q"), 27):  # 27 == Esc
                    cv2.destroyAllWindows()
                    return
                if key == ord("n"):
                    break
        cv2.destroyAllWindows()


    if __name__ == "__main__":
        preview_augmentations(dir_img)
    
    

    参考的 GitHub 项目:

    https://github.com/Canjie-Luo/Text-Image-Augmentation
    https://github.com/AISangam/Image-Augmentation-Using-OpenCV-and-Python
    https://github.com/FengYen-Chang/Data-Augmentation/blob/master/DataAugmentation/DataAugment.py
    https://github.com/CoinCheung/AutoAugment_opencv/blob/master/AA_classification/functions.py

    好记性不如烂键盘---点滴、积累、进步!
  • 相关阅读:
    类型转换
    Java中this和super的用法总结
    关于网页乱码问题
    用cookie实现记住用户名和密码
    Before start of result set
    jsp页面错误The attribute prefix does not correspond to any imported tag library
    MySql第几行到第几行语句
    servelet跳转页面的路径中一直包含sevelet的解决办法
    <a>标签跳转到Servelet页面并实现参数的传递
    解决网页乱码
  • 原文地址:https://www.cnblogs.com/yanghailin/p/14742464.html
Copyright © 2011-2022 走看看