zoukankan      html  css  js  c++  java
  • 深度学习面试题35:RNN梯度消失问题(vanishing gradient)

    目录

      梯度消失原因之一:激活函数

      梯度消失原因之二:初始化权重

      不同损失函数下RNN的梯度消失程度对比

      实践中遇到梯度消失怎么办?

      参考资料


    在实践过程中,RNN的一个缺点是在训练的过程中容易梯度消失。

    梯度消失原因之一:激活函数

     

    sigmod的导函数峰值为0.25,由于反向传播的距离越长,连乘的小数越多,所以sigmod一定会产生梯度消失,并且很严重。但是因为tanh的导函数峰值为1,所以tanh造成的梯度消失的程度比sigmod更小。这一段的结论应该是比较简单易懂的。

    那么Relu类的激活函数呢?

    先给出我的理解,Relu不是解决梯度消失的问题充分条件。分析如下:

    虽然Relu的导函数在x>0的区域等于1,好像再怎么连乘都不会让累计梯度变小。但是这里忽略了在x<=0的区域导函数是等于0的情况,那就意味着和他连接的浅层边上的梯度直接就归0了,那就是梯度消失了。

    我们画图来分析上面的问题,假设有如下网络

    这里假设要计算w1的梯度值,其导函数为

    由链式法则可知损失函数对浅层网络参数的导数应该是多个梯度值的累和,由Relu的导函数的性质可知,z3,z4,z1如果小于等于0,都会影响到w1的梯度值,且z1的影响程度最大,一旦z1小于等于0,那么w1的梯度就直接归0,这就会造成梯度消失。

     返回目录

     

    梯度消失原因之二:初始化权重

     考略下图,如果初始化的模型参数比较大,不要说梯度消失了,可能梯度都爆炸了;但是如果权重参数太小,那也会引起梯度消失。

     返回目录

     

    不同损失函数下RNN的梯度消失程度对比

    从下图可以看到,在RNN中,随着网络层数的变浅参数梯度越来越小,即梯度很难传到浅层的位置

    对应代码

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import numpy as np
    
    torch.manual_seed(5)
    import matplotlib.pyplot as plt
    import matplotlib.pyplot as plt
    
    plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
    plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
    
    
    class RNN(nn.Module):
        def __init__(self, input_size, hidden_size, output_size, active_function):
            super(RNN, self).__init__()
    
            self.hidden_size = hidden_size
            self.active_function = active_function
            self.i2h0 = nn.Linear(input_size + hidden_size, hidden_size, bias=False)
            self.i2h1 = nn.Linear(input_size + hidden_size, hidden_size, bias=False)
            self.i2h2 = nn.Linear(input_size + hidden_size, hidden_size, bias=False)
            self.i2h3 = nn.Linear(input_size + hidden_size, hidden_size, bias=False)
            self.i2h4 = nn.Linear(input_size + hidden_size, hidden_size, bias=False)
            self.i2h5 = nn.Linear(input_size + hidden_size, hidden_size, bias=False)
            self.i2h6 = nn.Linear(input_size + hidden_size, hidden_size, bias=False)
            self.i2h7 = nn.Linear(input_size + hidden_size, hidden_size, bias=False)
            self.i2h8 = nn.Linear(input_size + hidden_size, hidden_size, bias=False)
            self.i2h9 = nn.Linear(input_size + hidden_size, hidden_size, bias=False)
            self.i2o = nn.Linear(hidden_size, output_size, bias=False)
    
        def forward(self, input, hidden):
            combined = torch.cat((hidden, input[0]), 1)
    
            hidden = self.active_function(self.i2h0(combined))
            combined = torch.cat((hidden, input[1]), 1)
            hidden = self.active_function(self.i2h1(combined))
            combined = torch.cat((hidden, input[2]), 1)
            hidden = self.active_function(self.i2h2(combined))
            combined = torch.cat((hidden, input[3]), 1)
            hidden = self.active_function(self.i2h3(combined))
            combined = torch.cat((hidden, input[4]), 1)
            hidden = self.active_function(self.i2h4(combined))
            combined = torch.cat((hidden, input[5]), 1)
            hidden = self.active_function(self.i2h5(combined))
            combined = torch.cat((hidden, input[6]), 1)
            hidden = self.active_function(self.i2h6(combined))
            combined = torch.cat((hidden, input[7]), 1)
            hidden = self.active_function(self.i2h7(combined))
            combined = torch.cat((hidden, input[8]), 1)
            hidden = self.active_function(self.i2h8(combined))
            combined = torch.cat((hidden, input[9]), 1)
            hidden = self.active_function(self.i2h9(combined))
            output = self.active_function(self.i2o(hidden))
    
            return output, hidden
    
        def initHidden(self):
            # return torch.zeros(1, self.hidden_size)
            w = torch.empty(1, self.hidden_size)
            nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
            return w
    
    
    def train(category_tensor, input_tensor):
        hidden = rnn.initHidden()
        rnn.zero_grad()
    
        output, hidden = rnn(input_tensor, hidden)
    
        loss = criterion(output, category_tensor)
        loss.backward()
    
        # Add parameters' gradients to their values, multiplied by learning rate
        lst_params = list(rnn.parameters())[:10]  # 只获取i2h的参数
        lst_x = []
        lst_y = []
        for i, p in enumerate(lst_params):
            # print("梯度值", p.grad.data)
            grad_abs = np.abs(np.array(p.grad.data))
            # np_greater_than_0 = grad_abs.reshape(-1)
            np_greater_than_0 = grad_abs
            # np_greater_than_0 = np_greater_than_0[np_greater_than_0 > 0]
            print(np.max(np_greater_than_0))
            grad_abs_mean_log = np.log10(np_greater_than_0.mean())
            grad_abs_var_log = np.log10(np_greater_than_0.var())
            # print("倒数第{}层的梯度张量绝对值的均值取对数为{},方差取对数为{}".format(i + 1, grad_abs_mean_log, grad_abs_var_log))
            lst_x.append(i + 1)
            lst_y.append(grad_abs_mean_log)
            p.data.add_(p.grad.data, alpha=-learning_rate)
    
        return output, loss.item(), lst_x, lst_y
    
    
    if __name__ == '__main__':
        input_tensor = torch.randn(10, 1, 100)
        input_size = input_tensor.shape[-1]
        hidden_size = 200
        output_size = 2
        # rnn0 = RNN(input_size, hidden_size, output_size, torch.relu)
        # init_weight = rnn0.i2h0._parameters["weight"].data
        init_weight = torch.randn(200, 300)/15
    
        active_function = torch.relu
        active_function = torch.tanh
        active_function = torch.sigmoid
        rnn = RNN(input_size, hidden_size, output_size, active_function)
        # init_weight = rnn.i2h0._parameters["weight"].data
        rnn.i2h1._parameters["weight"].data = init_weight
        rnn.i2h2._parameters["weight"].data = init_weight
        rnn.i2h3._parameters["weight"].data = init_weight
        rnn.i2h4._parameters["weight"].data = init_weight
        rnn.i2h5._parameters["weight"].data = init_weight
        rnn.i2h6._parameters["weight"].data = init_weight
        rnn.i2h7._parameters["weight"].data = init_weight
        rnn.i2h8._parameters["weight"].data = init_weight
        rnn.i2h9._parameters["weight"].data = init_weight
    
        criterion = nn.CrossEntropyLoss()
        learning_rate = 1
        n_iters = 1
        all_losses = []
        lst_x = []
        lst_y = []
        for iter in range(1, n_iters + 1):
            category_tensor = torch.tensor([0])  # 第0类,哑编码:[1, 0]
            output, loss, lst_x, lst_y = train(category_tensor, input_tensor)
            print("迭代次数", iter, output, loss)
    
        print(lst_y)
        lst_sigmod = [-9.315002, -8.481065, -7.7988734, -7.133273, -6.412653, -5.703941, -5.020198, -4.441074, -3.7632055, -3.1263535]
        lst_tanh = [-3.5717661, -3.4407198, -3.1482387, -2.968598, -2.7806234, -2.58508, -2.4179213, -2.3331132, -2.164275, -2.0336704]
        lst_relu = [-4.169364, -4.0102725, -3.6641762, -3.505077, -3.2865758, -3.089403, -2.8985455, -2.762998, -2.503199, -2.368149]
    
        plt.plot(lst_x, lst_sigmod, label="sigmod")
        plt.plot(lst_x, lst_tanh, label="tanh")
        plt.plot(lst_x, lst_relu, label="relu")
        plt.xlabel("第i个时间步")
        plt.ylabel("梯度张量绝对值的均值取对数")
        plt.title("调研:不同激活函数下梯度消失的程度")
        plt.legend(loc="lower left")
        plt.show()
    View Code

     返回目录

     

    实践中遇到梯度消失怎么办?

    注意:梯度消失不一定就代表网络就不能学习!如果不是太深的网络(针对RNN来讲就是时间步较少),即使存在梯度消失的问题,还是可以训练的,只不过时间会久一些。

    在实践过程中我们可以选用Relu系列的激活函数(毕竟导函数有等于1的区域,就意味着每次迭代都能让一些较浅层的参数得到较大的梯度值,可以这样理解,他的通透性比较强,但是每一轮更新只有一些位置都能通过去,循环起来效果还是可以的),并且合理的初始化模型参数(不能太小,但也不能达到让梯度爆炸的程度)

    在训练深层网络的时候,直接训练肯定没戏,浅层梯度基本都等于0了,CNN可以使用带有跳跃连接的模块,比如ResNet;RNN可以使用LSTM的结构,这个在以后的内容中再细谈。

     返回目录

     

    参考资料

    https://zhuanlan.zhihu.com/p/33006526?utm_source=wechat_session&utm_medium=social&utm_oi=829090756730970112&utm_content=first

    https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial

     返回目录

     

  • 相关阅读:
    <转>WCF中出现死锁或者超时
    无连接服务器与面向连接的服务器
    Linux系统调用
    vim文本删除方法 Linux
    深入了解C指针
    linux下c语言实现双进程运行
    *p++、(*p)++、*++p、++*p 的区别
    快速了解yuv4:4:4 yuv4:2:2 yuv 4:1:1 yuv 4:2:0四种YUV格式区别
    文件通过svn updata更新不到,并且svn st显示被删除的解决办法
    [非常重要的总结] Linux C相关函数
  • 原文地址:https://www.cnblogs.com/itmorn/p/13285206.html
Copyright © 2011-2022 走看看