zoukankan      html  css  js  c++  java
  • [Box] Robust Training and Initialization of Deep Neural Networks: An Adaptive Basis Viewpoint

    Cyr E C, Gulian M, Patel R G, et al. Robust Training and Initialization of Deep Neural Networks: An Adaptive Basis Viewpoint.[J]. arXiv: Learning, 2019.

    @article{cyr2019robust,
    title={Robust Training and Initialization of Deep Neural Networks: An Adaptive Basis Viewpoint.},
    author={Cyr, Eric C and Gulian, Mamikon and Patel, Ravi G and Perego, Mauro and Trask, Nathaniel},
    journal={arXiv: Learning},
    year={2019}}

    这篇文章介绍了一种梯度下降的改进, 以及Box参数初始化方法.

    主要内容

    在这里插入图片描述

    [ ag{6} arg min_{xi^L xi^H} sum_{k=1}^K epsilon_k |mathcal{L}_k[u] - sum_i xi_i^L mathcal{L}_k [Phi_i(x, xi^H)]|^2_{ell_2(mathcal{X}_k)}. ]

    LSGD

    固定(xi^H, mathcal{X}_k), 并令(epsilon_k=1), 则问题(6)退化为一个最小二乘问题

    [arg min_{xi^L} |Axi^L -b|^2_{ell_2 (mathcal{X})}, ]

    其中(b_i = mathcal{L}[u](x_i)), (A_{ij}=mathcal{L} [Phi_j (x_i, xi^H)]), (x_i in mathcal{X}, i=1,ldots, N, j=1, ldots, w).

    所以算法如下

    在这里插入图片描述

    Box 初始化

    该算法期望使得feature-rich,但是我不知道这个rich从何而来.

    假设第(l)层的输入为(x in mathbb{R}^{d_1}), 输出为(y in mathbb{R}^{d_2}), 则该层的权重矩阵(W in mathbb{R}^{d_2 imes d_1}). 我们逐行地定义(W):

    1. 采样(p), (psim U[0 ,1]^{d_1});
    2. 采样(n), (n sim mathcal{N}(0,I_{d_1})), 并令(n=n/|n|);
    3. 求参数(k)使得

    [max_{x in [0, 1]^{d_1}} sigma(k(x-p) cdot n)=1. ]

    1. (W)(i)(w_i=kn^T), (b_i=-kp cdot n).

    其中(sigma)表示激活函数, 文中指的是ReLU.
    求解参数(k):

    1. (p_{max} = max (0, mathrm{sign}(n)));
    2. (k=frac{1}{(p_{max}-p) cdot n})

    (k)即为所需(k), 只需证明(p_{max})是最大化

    [(x - p)cdot n, quad x in [0,1]^{d_1} ]

    的解. 最大化上式, 可以分解为

    [max_{x_i in [0, 1]} x_in_i, ]

    (x_i = max(0, mathrm{sign}(n_i))).

    这个初始化有什么好处呢, 可以发现, 输入(x in[0,1]^{d_1})满足, 则输出(y in [0, 1]^{d_2}), 保证二者的"值域"范围一致, 以此类推整个网络节点值范围近似.

    在这里插入图片描述
    如果, 作者构建了一个2-2-2-2-2-2-2-2的网络, 可以发现, Xavier 和 Kaiming的初始化方法经过一定层数后, 就会塌缩在某个点, 而Box初始化方法能够缓解这一现象.

    下面是文中列出的算法(与这里的符号有一点点不同, 另外(b)作者应该是遗漏了负号).

    在这里插入图片描述

    Box for Resnet

    因为Resnet特殊的结构,

    [y=(W+I)x+b. ]

    假设(x in [0,m]^{d_1}), 则:

    1. 采样(p), (psim U[0 ,m]^{d_1});
    2. 采样(n), (n sim mathcal{N}(0,I_{d_1})), 并令(n=n/|n|);
    3. 求参数(k)使得

    [max_{x in [0, m]^{d_1}} sigma(k(x-p) cdot n)=delta m. ]

    1. (W)(i)(w_i=kn^T), (b_i=-kp cdot n).

    [k=frac{delta m}{(mp_{max}-p) cdot n}. ]

    若第一层输入(x_i in [0,1]), 去(delta=1/L), 其中(L)为总的层数, 则

    [[0,1] ightarrow [0,1+frac{1}{L}] ightarrow [0,(1+frac{1}{L})^2] ightarrow cdots ]

    在这里插入图片描述

    代码

    
    
    
    '''
    initialization.py
    '''
    import torch
    import torch.nn as nn
    import warnings
    
    
    
    
    
    def generate(size, m, delta):
        p = torch.rand(size) * m
        n = torch.randn(size)
        temp = 1 / torch.norm(n, p=2, dim=1, keepdim=True)
        n = temp * n
        pmax = nn.functional.relu(torch.sign(n)) * m
        temp = (pmax - p) * n
        k = (m * delta) / temp.sum(dim=1, keepdim=True)
        w = k * n
        b = -(w * p).sum(dim=1)
        return w, b
    
    def box_init(module, m=1, delta=1):
        if isinstance(module, nn.Linear):
            w, b = generate(module.weight.shape, m, delta)
            try:
                module.weight.data = w
                module.bias.data = b
            except AttributeError as e:
                s = "Error: 
    " + str(e) + "
     stops the initialization" 
                                           " for this module: {}".format(module)
                warnings.warn(s)
    
        elif isinstance(module, nn.Conv2d):
            outc, inc, h, w = module.weight.size()
            w, b = generate((outc, inc * h * w), m, delta)
            try:
                module.weight.data = w.reshape(module.weight.size())
                module.bias.data = b
            except AttributeError as e:
                s = "Error: 
    " + str(e) + "
     stops the initialization" 
                                           " for this module: {}".format(module)
                warnings.warn(s)
    
        else:
            pass
    
    
    
    
    
    """config.py"""
    
    nums = 10
    layers = 6
    method = "kaiming"  #box/xavier/kaiming
    net = "Net"  #Net/ResNet
    
    
    
    
    
    
    
    """
    测试
    """
    
    
    
    import torch
    import torch.nn as nn
    import config
    from initialization import box_init
    
    
    
    class Net(nn.Module):
    
        def __init__(self, l):
            super(Net, self).__init__()
    
            self.linears = []
            for i in range(l):
                name = "linear" + str(i)
                self.__setattr__(name, nn.Sequential(nn.Linear(2, 2),
                                                     nn.ReLU()))
                self.linears.append(self.__getattr__(name))
            if config.method == 'box':
                self.box_init()
            elif config.method == "xavier":
                self.xavier_init()
            else:
                self.kaiming_init()
    
        def box_init(self):
            for module in self.modules():
                box_init(module)
    
        def xavier_init(self):
            for module in self.modules():
                if isinstance(module, (nn.Conv2d, nn.Linear)):
                    nn.init.xavier_normal_(module.weight)
    
        def kaiming_init(self):
            for module in self.modules():
                if isinstance(module, (nn.Conv2d, nn.Linear)):
                    nn.init.kaiming_normal_(module.weight)
    
        def forward(self, x):
            out = []
            temp = x
            for linear in self.linears:
                temp = linear(temp)
                out.append(temp)
            return out
    
    
    
    class ResNet(nn.Module):
    
        def __init__(self, l):
            super(ResNet, self).__init__()
    
            self.linears = []
            for i in range(l):
                name = "linear" + str(i)
                self.__setattr__(name, nn.Sequential(nn.Linear(2, 2),
                                                     nn.ReLU()))
                self.linears.append(self.__getattr__(name))
            if config.method == 'box':
                self.box_init(l)
            elif config.method == "xavier":
                self.xavier_init()
            else:
                self.kaiming_init()
    
        def box_init(self, layers):
            delta = 1 / layers
            m = 1. + delta
            l = 0
            for module in self.modules():
                if isinstance(module, (nn.Linear)):
                    if l == 0:
                        box_init(module, 1, 1)
                    else:
                        box_init(module, m ** l, delta)
                    l += 1
    
        def xavier_init(self):
            for module in self.modules():
                if isinstance(module, (nn.Conv2d, nn.Linear)):
                    nn.init.xavier_normal_(module.weight)
    
        def kaiming_init(self):
            for module in self.modules():
                if isinstance(module, (nn.Conv2d, nn.Linear)):
                    nn.init.kaiming_normal_(module.weight)
    
        def forward(self, x):
            out = []
            temp = x
            for linear in self.linears:
                temp = linear(temp) + temp
                out.append(temp)
            return out
    
    
    if config.net == "Net":
        net = Net(config.layers)
    else:
        net = ResNet(config.layers)
    
    x = torch.linspace(0, 1, config.nums)
    y = torch.linspace(0, 1, config.nums)
    
    grid_x, grid_y = torch.meshgrid(x, y)
    
    x = grid_x.flatten()
    y = grid_y.flatten()
    data = torch.stack((x, y), dim=1)
    outs = net(data)
    
    
    import  matplotlib.pyplot as plt
    
    
    def axplot(x, y, ax):
        x = x.detach().numpy()
        y = y.detach().numpy()
        ax.scatter(x, y)
    
    def plot(x, y, outs):
        fig, axs = plt.subplots(1, config.layers+1, sharey=True, figsize=(12, 2))
        axs[0].scatter(x, y)
        axs[0].set(title="layer0")
        for i in range(config.layers):
            ax = axs[i+1]
            out = outs[i]
            x = out[:, 0]
            y = out[:, 1]
            axplot(x, y, ax)
            ax.set(title="layer"+str(i+1))
        plt.tight_layout()
        plt.savefig("C:/Users/pkavs/Desktop/fig.png")
        #plt.show()
    plot(x, y, outs)
    
    
    
    
    
    
    
    
  • 相关阅读:
    【BZOJ1006】神奇的国度(弦图)
    弦图
    【BZOJ2946】公共串(后缀数组)
    【POJ1743】Musical Theme(后缀数组)
    JAVA和Tomcat运维整理
    linux shell 之if-------用if做判断
    Linux curl命令详解
    Intel HEX文件解析
    Linux bridge-utils tunctl 使用
    怎样查询锁表的SQL
  • 原文地址:https://www.cnblogs.com/MTandHJ/p/12760783.html
Copyright © 2011-2022 走看看