zoukankan      html  css  js  c++  java
  • MWCNN中使用的haar小波变换 pytorch

    1.原理

    小波变换的计算方法:

    1)一维信号:

    例如:有a=[5,7,6,8]四个数,并使用b[4]数组来保存结果.

            则一级Haar小波变换的结果为:

            b[0]=(a[0]+a[1])/2,                       b[2]=(a[0]-a[1])/2

            b[1]=(a[2]+a[3])/2,                       b[3]=(a[2]-a[3])/2

    ⚠️计算差均值时也有看见a[1]-a[0]的,只要保持一致应该都可以

    由此可知,Haar变换采用的原理是:

    A)低频采用和均值,即b[0]和b[1],和均值中均值存储了图像的整体信息

    B)高频采用差均值,即b[2]和b[3],用于记录图像的细节信息,这样在重构时能够恢复图像的全部信息

     因此上面的例子中

    b[0] = (5+7)/2 = 6 , b[1] = (6+8)/2 = 6, b[2] = (5-7)/2 = -1, b[3] = (6-8)/2 = -1 

    如果要继续进行多级的小波变换:

      如上图可见是对低频的信息继续进行haar小波变换

    2)二维

    对于二维haar小波,我们通常一次分解形成了整体图像,水平细节,垂直细节,对角细节。首先我们按照一维haar小波分解的原理,按照行顺序对行进行处理,然后按照列顺序对行处理结果进行同样的处理

    用图像表述如图所示:图中a表示原图,图b表示经过一级小波变换的结果,h1 表示水平反向的细节,v1 表示竖直方向的细节,c1表示对角线方向的细节,b表示下2采样的图像。图c中表示继续进行了三次Haar小波变换的结果:

    详细过程经过下面的代码来解释

    2.实现

    1)

    代码:https://github.com/lpj0/MWCNN_PyTorch/blob/master/model/common.py:

    原图为:

    中间有个问题,就是逆向重构的时候发现并没有成功,得到的结果是:

    于是对操作的数据进行了一番输出:

    #coding:utf-8
    import torch.nn as nn
    import torch
    
    def dwt_init(x):
        print('-------------- origin ---------------')
        print(x[:,0,:,:])
        x01 = x[:, :, 0::2, :] / 2
        x02 = x[:, :, 1::2, :] / 2
        x1 = x01[:, :, :, 0::2]
        x2 = x02[:, :, :, 0::2]
        x3 = x01[:, :, :, 1::2]
        x4 = x02[:, :, :, 1::2]
        x_LL = x1 + x2 + x3 + x4
        x_HL = -x1 - x2 + x3 + x4
        x_LH = -x1 + x2 - x3 + x4
        x_HH = x1 - x2 - x3 + x4
        print('---------------- LL ------------------')
        print(x_LL[:,0,:,:])
        print()
        print('---------------- HL ------------------')
        print(x_HL[:,0,:,:])
        print()
        print('---------------- LH ------------------')
        print(x_LH[:,0,:,:])
        print()
        print('---------------- HH ------------------')
        print(x_HH[:,0,:,:])
        return torch.cat((x_LL, x_HL, x_LH, x_HH), 1)
    
    
    # 使用哈尔 haar 小波变换来实现二维逆向离散小波
    def iwt_init(x):
        r = 2
        in_batch, in_channel, in_height, in_width = x.size()
        # print([in_batch, in_channel, in_height, in_width]) #[1, 12, 56, 56]
        out_batch, out_channel, out_height, out_width = in_batch, int(in_channel / (r**2)), r * in_height, r * in_width
        # print(out_batch, out_channel, out_height, out_width) #1 3 112 112
        x1 = x[:, 0:out_channel, :, :] / 2
        print('-------------- enter iwt ---------------')
        print(x1[:,0,:,:])
        x2 = x[:, out_channel:out_channel * 2, :, :] / 2
        x3 = x[:, out_channel * 2:out_channel * 3, :, :] / 2
        x4 = x[:, out_channel * 3:out_channel * 4, :, :] / 2
        # print(x1.shape) #torch.Size([1, 3, 56, 56])
        # print(x2.shape) #torch.Size([1, 3, 56, 56])
        # print(x3.shape) #torch.Size([1, 3, 56, 56])
        # print(x4.shape) #torch.Size([1, 3, 56, 56])
        # h = torch.zeros([out_batch, out_channel, out_height, out_width]).float().cuda()
        h = torch.zeros([out_batch, out_channel, out_height, out_width]).float()
    
        h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4
        h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4
        h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4
        h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4
        print('-------------- back ---------------')
        print(h[:,0,:,:])
        return h
    
    
    # 二维离散小波
    class DWT(nn.Module):
        def __init__(self):
            super(DWT, self).__init__()
            self.requires_grad = False  # 信号处理,非卷积运算,不需要进行梯度求导
    
        def forward(self, x):
            return dwt_init(x)
    
    
    # 逆向二维离散小波
    class IWT(nn.Module):
        def __init__(self):
            super(IWT, self).__init__()
            self.requires_grad = False
    
        def forward(self, x):
            return iwt_init(x)
    
    if __name__ == '__main__':
        import os, cv2, torchvision
        from PIL import Image
        import numpy as np
        from torchvision import transforms as trans
    
        img = Image.open('./1.jpg')
        transform = trans.Compose([
            trans.ToTensor()
        ])
        img = transform(img).unsqueeze(0)
        dwt = DWT()
        change_img_tensor = dwt(img)
        # print(change_img_tensor.shape) #torch.Size([1, 12, 56, 56])
        print('-------------- after dwt ---------------')
        print(change_img_tensor[:,0,:,:])
    
        for i in range(change_img_tensor.size(1)//3):
            torchvision.utils.save_image(change_img_tensor[:,i*3:i*3+3:,:], os.path.join('./', 'change_{}.jpg'.format(i)))
             
        iwt = IWT()
        back_img_tensor = iwt(change_img_tensor)
        print(back_img_tensor.shape)
    
        torchvision.utils.save_image(back_img_tensor, 'back.jpg')

    返回:

    (deeplearning) bogon:learning user$ python delete.py 
    -------------- origin ---------------
    tensor([[[0.9020, 0.9882, 0.9216,  ..., 0.7176, 0.7843, 0.8431],
             [0.8941, 0.9608, 0.9255,  ..., 0.7490, 0.7569, 0.7490],
             [0.8980, 0.9333, 0.8863,  ..., 0.6941, 0.7333, 0.7608],
             ...,
             [0.1373, 0.1373, 0.1451,  ..., 0.9529, 0.9686, 0.9804],
             [0.1451, 0.1451, 0.1490,  ..., 0.9294, 0.9373, 0.9569],
             [0.1373, 0.1412, 0.1451,  ..., 0.9137, 0.9020, 0.9176]]])
    ---------------- LL ------------------
    tensor([[[1.8725, 1.8569, 1.9314,  ..., 1.4176, 1.4510, 1.5667],
             [1.8294, 1.7588, 1.5118,  ..., 1.4490, 1.4314, 1.5137],
             [1.9039, 1.6059, 1.0412,  ..., 1.2765, 1.4588, 1.5039],
             ...,
             [0.3078, 0.3216, 0.3490,  ..., 1.8784, 1.8333, 1.7627],
             [0.2647, 0.3059, 0.3784,  ..., 1.7510, 1.8725, 1.9137],
             [0.2843, 0.3020, 0.3922,  ..., 1.8941, 1.8549, 1.8569]]])
    
    ---------------- HL ------------------
    tensor([[[ 0.0765,  0.0098,  0.0059,  ...,  0.0098,  0.0157,  0.0255],
             [ 0.0294, -0.0294, -0.0922,  ..., -0.0098,  0.0078,  0.0235],
             [-0.0412, -0.1588, -0.0725,  ...,  0.0569,  0.0314, -0.0059],
             ...,
             [-0.0137,  0.0196,  0.0039,  ...,  0.0275, -0.0412,  0.0098],
             [-0.0020,  0.0196,  0.0176,  ...,  0.0216,  0.0216,  0.0039],
             [ 0.0020,  0.0078,  0.0314,  ...,  0.0039, -0.0118,  0.0176]]])
    
    ---------------- LH ------------------
    tensor([[[-0.0176,  0.0098, -0.0176,  ..., -0.0137,  0.0392, -0.0608],
             [-0.0020, -0.0098, -0.1510,  ...,  0.1078,  0.0745,  0.0196],
             [ 0.0176, -0.1392, -0.0882,  ..., -0.0059,  0.0588,  0.0725],
             ...,
             [-0.0255, -0.0078,  0.0118,  ..., -0.0431,  0.0216, -0.0098],
             [ 0.0098,  0.0000,  0.0020,  ...,  0.0020, -0.0020,  0.0353],
             [-0.0059, -0.0039,  0.0039,  ...,  0.0549,  0.0039, -0.0373]]])
    
    ---------------- HH ------------------
    tensor([[[-0.0098,  0.0059, -0.0216,  ...,  0.0216, -0.0078, -0.0333],
             [-0.0059, -0.0255, -0.0294,  ...,  0.0137, -0.0235, -0.0039],
             [-0.0373, -0.0373,  0.0608,  ..., -0.0020,  0.0196, -0.0098],
             ...,
             [ 0.0059,  0.0039,  0.0039,  ...,  0.0118,  0.0137, -0.0137],
             [ 0.0020, -0.0039,  0.0020,  ..., -0.0098,  0.0137,  0.0078],
             [ 0.0020,  0.0000,  0.0039,  ...,  0.0000, -0.0196, -0.0020]]])
    -------------- after dwt ---------------
    tensor([[[1.8725, 1.8569, 1.9314,  ..., 1.4176, 1.4510, 1.5667],
             [1.8294, 1.7588, 1.5118,  ..., 1.4490, 1.4314, 1.5137],
             [1.9039, 1.6059, 1.0412,  ..., 1.2765, 1.4588, 1.5039],
             ...,
             [0.3078, 0.3216, 0.3490,  ..., 1.8784, 1.8333, 1.7627],
             [0.2647, 0.3059, 0.3784,  ..., 1.7510, 1.8725, 1.9137],
             [0.2843, 0.3020, 0.3922,  ..., 1.8941, 1.8549, 1.8569]]])
    -------------- enter iwt ---------------
    tensor([[[127.5000, 127.5000, 127.5000,  ..., 127.5000, 127.5000, 127.5000],
             [127.5000, 127.5000, 127.5000,  ..., 127.5000, 127.5000, 127.5000],
             [127.5000, 127.5000, 127.5000,  ..., 127.5000, 127.5000, 127.5000],
             ...,
             [ 39.5000,  41.2500,  44.7500,  ..., 127.5000, 127.5000, 127.5000],
             [ 34.0000,  39.2500,  48.5000,  ..., 127.5000, 127.5000, 127.5000],
             [ 36.5000,  38.7500,  50.2500,  ..., 127.5000, 127.5000, 127.5000]]])
    -------------- back ---------------
    tensor([[[117.5000, 137.5000, 125.5000,  ..., 124.5000, 124.0000, 131.0000],
             [117.5000, 137.5000, 126.5000,  ..., 135.0000, 124.0000, 131.0000],
             [123.5000, 131.5000, 127.5000,  ..., 119.0000, 121.5000, 128.0000],
             ...,
             [ 35.0000,  36.0000,  36.7500,  ..., 132.5000, 130.2500, 134.2500],
             [ 36.5000,  36.5000,  37.7500,  ..., 126.7500, 125.0000, 130.0000],
             [ 35.5000,  37.5000,  37.2500,  ..., 128.2500, 125.0000, 130.0000]]])
    torch.Size([1, 3, 112, 112])

    发现输入iwt的结果变化了,突然想起来torchvision.utils.save_image函数是会对数据进行处理的

    解决办法就是调整下顺序即可

    重新运行一遍:

    #coding:utf-8
    import torch.nn as nn
    import torch
    
    def dwt_init(x):
        x01 = x[:, :, 0::2, :] / 2
        x02 = x[:, :, 1::2, :] / 2
        x1 = x01[:, :, :, 0::2]
        x2 = x02[:, :, :, 0::2]
        x3 = x01[:, :, :, 1::2]
        x4 = x02[:, :, :, 1::2]
        x_LL = x1 + x2 + x3 + x4
        x_HL = -x1 - x2 + x3 + x4
        x_LH = -x1 + x2 - x3 + x4
        x_HH = x1 - x2 - x3 + x4
    
        return torch.cat((x_LL, x_HL, x_LH, x_HH), 1)
    
    
    # 使用哈尔 haar 小波变换来实现二维逆向离散小波
    def iwt_init(x):
        r = 2
        in_batch, in_channel, in_height, in_width = x.size()
        # print([in_batch, in_channel, in_height, in_width]) #[1, 12, 56, 56]
        out_batch, out_channel, out_height, out_width = in_batch, int(in_channel / (r**2)), r * in_height, r * in_width
        # print(out_batch, out_channel, out_height, out_width) #1 3 112 112
        x1 = x[:, 0:out_channel, :, :] / 2
        x2 = x[:, out_channel:out_channel * 2, :, :] / 2
        x3 = x[:, out_channel * 2:out_channel * 3, :, :] / 2
        x4 = x[:, out_channel * 3:out_channel * 4, :, :] / 2
        # print(x1.shape) #torch.Size([1, 3, 56, 56])
        # print(x2.shape) #torch.Size([1, 3, 56, 56])
        # print(x3.shape) #torch.Size([1, 3, 56, 56])
        # print(x4.shape) #torch.Size([1, 3, 56, 56])
        # h = torch.zeros([out_batch, out_channel, out_height, out_width]).float().cuda()
        h = torch.zeros([out_batch, out_channel, out_height, out_width]).float()
    
        h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4
        h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4
        h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4
        h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4
        return h
    
    
    # 二维离散小波
    class DWT(nn.Module):
        def __init__(self):
            super(DWT, self).__init__()
            self.requires_grad = False  # 信号处理,非卷积运算,不需要进行梯度求导
    
        def forward(self, x):
            return dwt_init(x)
    
    
    # 逆向二维离散小波
    class IWT(nn.Module):
        def __init__(self):
            super(IWT, self).__init__()
            self.requires_grad = False
    
        def forward(self, x):
            return iwt_init(x)
    
    if __name__ == '__main__':
        import os, cv2, torchvision
        from PIL import Image
        import numpy as np
        from torchvision import transforms as trans
        # img = cv2.imread('./1.jpg')
        # print(img.shape)
        # img = Image.fromarray(img.astype(np.uint8))
        img = Image.open('./1.jpg')
        transform = trans.Compose([
            trans.ToTensor()
        ])
        img = transform(img).unsqueeze(0)
        dwt = DWT()
        change_img_tensor = dwt(img)
        iwt = IWT()
        back_img_tensor = iwt(change_img_tensor)
        print(back_img_tensor.shape)
        # print(change_img_tensor.shape) #torch.Size([1, 12, 56, 56])
        
        
        #合并成一张4格的图
        h = torch.zeros([4,3,change_img_tensor.size(2),change_img_tensor.size(2)]).float()
    
        
        for i in range(change_img_tensor.size(1)//3):
            h[i,:,:,:] = change_img_tensor[:,i*3:i*3+3:,:]
            #分别保存为一个图片
            torchvision.utils.save_image(change_img_tensor[:,i*3:i*3+3:,:], os.path.join('./', 'change_{}.jpg'.format(i)))
        
                
        change_img_grid = torchvision.utils.make_grid(h, 2) #一行2张图片    
        torchvision.utils.save_image(change_img_grid, 'change_img_grid.jpg')
    
        torchvision.utils.save_image(back_img_tensor, 'back.jpg')

    小波变换后的结果为:

    重构的图为:

    2)对代码进行解释

    1》dwt

    def dwt_init(x):
        x01 = x[:, :, 0::2, :] / 2
        x02 = x[:, :, 1::2, :] / 2
        x1 = x01[:, :, :, 0::2]
        x2 = x02[:, :, :, 0::2]
        x3 = x01[:, :, :, 1::2]
        x4 = x02[:, :, :, 1::2]
        x_LL = x1 + x2 + x3 + x4
        x_HL = -x1 - x2 + x3 + x4
        x_LH = -x1 + x2 - x3 + x4
        x_HH = x1 - x2 - x3 + x4
    
        return torch.cat((x_LL, x_HL, x_LH, x_HH), 1)

    首先:

        x01 = x[:, :, 0::2, :] / 2
        x02 = x[:, :, 1::2, :] / 2

    将矩阵分为偶数行和奇数行,并将所有值都除以2,这样后面只要进行求和和求差即可,因为已经求均值了

    然后下面的:

        x1 = x01[:, :, :, 0::2]
        x2 = x02[:, :, :, 0::2]
        x3 = x01[:, :, :, 1::2]
        x4 = x02[:, :, :, 1::2]

    就是分为偶数列和奇数列,假设矩阵为6*6大小,那么就将该矩阵分成了4个3*3大小的x1、x2、x3和x4,如下图所示:

    那么接下来在进行的计算就是进行行、列的和、差变换了:

        x_LL = x1 + x2 + x3 + x4
        x_HL = -x1 - x2 + x3 + x4
        x_LH = -x1 + x2 - x3 + x4
        x_HH = x1 - x2 - x3 + x4

    再用一张图说明:

    2》iwt

    def iwt_init(x):
        r = 2
        in_batch, in_channel, in_height, in_width = x.size()
        # print([in_batch, in_channel, in_height, in_width]) #[1, 12, 56, 56]
        out_batch, out_channel, out_height, out_width = in_batch, int(in_channel / (r**2)), r * in_height, r * in_width
        # print(out_batch, out_channel, out_height, out_width) #1 3 112 112
        x1 = x[:, 0:out_channel, :, :] / 2
        x2 = x[:, out_channel:out_channel * 2, :, :] / 2
        x3 = x[:, out_channel * 2:out_channel * 3, :, :] / 2
        x4 = x[:, out_channel * 3:out_channel * 4, :, :] / 2
        # print(x1.shape) #torch.Size([1, 3, 56, 56])
        # print(x2.shape) #torch.Size([1, 3, 56, 56])
        # print(x3.shape) #torch.Size([1, 3, 56, 56])
        # print(x4.shape) #torch.Size([1, 3, 56, 56])
        # h = torch.zeros([out_batch, out_channel, out_height, out_width]).float().cuda()
        h = torch.zeros([out_batch, out_channel, out_height, out_width]).float()
    
        h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4
        h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4
        h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4
        h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4
        return h

    在这里的行x1=x_LL/2,  x2=x_HL/2,  x3=x_LH/2,  x4=x_HH/2

    所以我们想重构,其实就是从这些值中恢复dwt中的x1,x2,x3,x4,分别放到h对应的位置变为原来的矩阵,如x1对应的是h[:, :, 0::2, 0::2],如下图所示:

    这就是重构的方法

    过程中遇到的一点问题pytorch图像处理的问题

  • 相关阅读:
    Redis其他命令
    Redis 键(key)
    Redis数据类型命令
    配置命令 CONFIG
    Redis简介
    idea破解
    jenkins+allure+testng
    Jenkins
    docker镜像
    docker命令
  • 原文地址:https://www.cnblogs.com/wanghui-garcia/p/12524515.html
Copyright © 2011-2022 走看看