zoukankan      html  css  js  c++  java
  • 数据增强---CutMix

    CutMix

    CutMix是在论文《CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features》被提出的数据增强方式,常用于分类任务和检测任务。



    什么是CutMix

    Cut指切割出图片中的一小块,MIx指将这一小块贴到其他图片中,并且label也会进行混合。

    从下图可以看出CutMix对模型分类准确率和定位准确率有明显的提升。

    CutMix的操作可以用如下公式表示:

    [egin{align} ar x &= M odot x_A + (1-M)odot x_B \ ar y &= lambda y_A + (1-lambda)y_B end{align}]

    其中的符号解释如下:

    • (M)是一个二值Mask。对于(x_A)(M=1)部分的图像会被保留。对于(x_B)(M=0)的部分会被保留
    • (x_A,x_B)分别是两张图片
    • (y_A,y_B)是对应的label
    • (ar x ,ar y)是CutMix后的图像
    • (odot)表示按元素相乘
    • (lambda)和Mixup中的一样,服从(eta(alpha,alpha))分布(论文中设置(alpha)为1)

    Mask的生成

    (M)的取值是随机生成一个bounding box来得到的,这个bbox的参数为(B=(r_x,r_y,r_w,r_y)),通过下面公式计算得到

    [r_x sim ext{Unif}(0,W) \ r_y sim ext{Unif}(0,H) \ r_w=Wsqrt{1-lambda} \ r_h=Hsqrt{1-lambda}]

    (M)这个矩阵的大小和图像一样,bbox内的值为0,其他值为1


    label的融合

    当前图片内容在融合后面积的占比决定了label的值,假设分别用两张图的30%和70%融合在一起,原始label分别是([1,0])([0,1]),则融合label为([0.3,0.7])

    从上面公式可以计算出生成的bbox大小为

    [r_w*r_h= Wsqrt{1-lambda}*Hsqrt{1-lambda} =WH(1-lambda)]

    bbox和原图的面积比例就为

    [WH(1-lambda)/(WH) = 1-lambda ]

    从公式(1)可以看出图A保留了bbox以外的部分,因此(y_A)的系数为(lambda)



    代码实现

    代码实现中有一些不同的是,生成bbox的中心点是在全图范围随机,如果中心点靠近图像边缘,那么bbox的面积和原图的比可能就不是(1-lambda)。因此这个面积比例是重新计算的。

    图像之间的对应关系是随机的,有可能对应到自己本身,就不会进行cutmix,多执行几次能看到效果。

    import matplotlib.pyplot as plt
    import numpy as np
    
    plt.rcParams['figure.figsize'] = [10, 10]
    
    import cv2
    
    def rand_bbox(size, lamb):
        """
        生成随机的bounding box
        :param size:
        :param lamb:
        :return:
        """
        W = size[0]
        H = size[1]
    
        # 得到一个bbox和原图的比例
        cut_ratio = np.sqrt(1.0 - lamb)
        cut_w = int(W * cut_ratio)
        cut_h = int(H * cut_ratio)
    
        # 得到bbox的中心点
        cx = np.random.randint(W)
        cy = np.random.randint(H)
    
        bbx1 = np.clip(cx - cut_w // 2, 0, W)
        bby1 = np.clip(cy - cut_h // 2, 0, H)
        bbx2 = np.clip(cx + cut_w // 2, 0, W)
        bby2 = np.clip(cy + cut_h // 2, 0, H)
    
        return bbx1, bby1, bbx2, bby2
    
    def cutmix(image_batch, image_batch_labels, alpha=1.0):
        # 决定bbox的大小,服从beta分布
        lam = np.random.beta(alpha, alpha)
    
        #  permutation: 如果输入x是一个整数,那么输出相当于打乱的range(x)
        rand_index = np.random.permutation(len(image_batch))
    
        # 对应公式中的y_a,y_b
        target_a = image_batch_labels
        target_b = image_batch_labels[rand_index]
    
        # 根据图像大小随机生成bbox
        bbx1, bby1, bbx2, bby2 = rand_bbox(image_batch[0].shape, lam)
    
        image_batch_updated = image_batch.copy()
    
        # image_batch的维度分别是 batch x 宽 x 高 x 通道
        # 将所有图的bbox对应位置, 替换为其他任意一张图像
        # 第一个参数rand_index是一个list,可以根据这个list里索引去获得image_batch的图像,也就是将图片乱序的对应起来
        image_batch_updated[:, bbx1: bbx2, bby1:bby2, :] = image_batch[rand_index, bbx1:bbx2, bby1:bby2, :]
    
        # 计算 1 - bbox占整张图像面积的比例
        lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1)) / (image_batch.shape[1] * image_batch.shape[2])
        # 根据公式计算label
        label = target_a * lam + target_b * (1. - lam)
    
        return image_batch_updated, label
    
    if __name__ == '__main__':
        cat = cv2.cvtColor(cv2.imread("data/neko.png"), cv2.COLOR_BGR2RGB)
        dog = cv2.cvtColor(cv2.imread("data/inu.png"), cv2.COLOR_BGR2RGB)
    
        
        updated_img, label = cutmix(np.array([cat, dog]), np.array([[0, 1], [1, 0]]), 0.5)
        print(label)
    
        fig, axs = plt.subplots(nrows=1, ncols=2, squeeze=False)
        ax1 = axs[0, 0]![](https://img2020.cnblogs.com/blog/1621431/202108/1621431-20210814233143289-396153337.png)
    
        ax2 = axs[0, 1]
        ax1.imshow(updated_img[0])
        ax2.imshow(updated_img[1])
        plt.show()
    

    参考资料

    CutMix Augmentation in Python

    原论文

  • 相关阅读:
    C语言的特点与缺点
    C语言的特点与缺点
    HDU1234 开门人和关门人
    HDU1234 开门人和关门人
    B00014 C++实现的AC自动机
    B00014 C++实现的AC自动机
    HDU4716 A Computer Graphics Problem
    HDU4716 A Computer Graphics Problem
    I00029 C语言程序-打印九九乘法表
    I00029 C语言程序-打印九九乘法表
  • 原文地址:https://www.cnblogs.com/hikari-1994/p/15142301.html
Copyright © 2011-2022 走看看