zoukankan      html  css  js  c++  java
  • Pytorch学习笔记11----model.train()与model.eval()的用法、Dropout原理、relu,sigmiod,tanh激活函数、nn.Linear浅析、输出整个tensor的方法

    1.model.train()与model.eval()的用法

    看别人的面经时,浏览到一题,问的就是这个。自己刚接触pytorch时套用别人的框架,会在训练开始之前写上model.trian(),在测试时写上model.eval()。然后自己写的时候也就保留了这个习惯,没有去想其中原因。

    在经过一番查阅之后,总结如下:
    如果模型中有BN层(Batch Normalization)和Dropout,需要在训练时添加model.train(),在测试时添加model.eval()。其中model.train()是保证BN层用每一批数据的均值和方差,而model.eval()是保证BN用全部训练数据的均值和方差;而对于Dropout,model.train()是随机取一部分网络连接来训练更新参数,而model.eval()是利用到了所有网络连接。

    联系Batch Normalization和Dropout的原理之后就不难理解为何要这么做了。

    2.Dropout

    dropout常常用于抑制过拟合,pytorch也提供了很方便的函数。但是经常不知道dropout的参数p是什么意思。在TensorFlow中p叫做keep_prob,就一直以为pytorch中的p应该就是保留节点数的比例,但是实验结果发现反了,实际上表示的是不保留节点数的比例。看下面的例子:

    a = torch.randn(10,1)
    >>> tensor([[ 0.0684],
            [-0.2395],
            [ 0.0785],
            [-0.3815],
            [-0.6080],
            [-0.1690],
            [ 1.0285],
            [ 1.1213],
            [ 0.5261],
            [ 1.1664]])

    P=0.5

    torch.nn.Dropout(0.5)(a)
    >>> tensor([[ 0.0000],  
            [-0.0000],  
            [ 0.0000],  
            [-0.7631],  
            [-0.0000],  
            [-0.0000],  
            [ 0.0000],  
            [ 0.0000],  
            [ 1.0521],  
            [ 2.3328]])

    数值上的变化: 2.3328=1.1664*2

    设置Dropout时,torch.nn.Dropout(0.5), 这里的 0.5 是指该层(layer)的神经元在每次迭代训练时会随机有 50% 的可能性被丢弃(失活),不参与训练

    将上一层数据减少一半传播

    3.relu,sigmiod,tanh激活函数

    在神经网络中原本输入输出都是线性关系,但现实中,许多的问题是非线性的(比如,房价问题中,房价不可能随着房子面积的增加一直线性增加),这个时候就神经网络的线性输出,再经过激励函数,便使得原本线性的关系变成非线性了,增强了神经网络的性能。

    常用的激活函数:relusigmoidtanhsoftmaxsoftplus

    import torch
    import torch.nn.functional as F
    from torch.autograd import Variable
    import matplotlib.pyplot as plt  # 为了方便找展示,用了matplotlib
    
    # 生成数据
    tensor_data = torch.linspace(-5, 5, 200)
    variable_data = Variable(tensor_data)
    np_data = variable_data.data.numpy()
    
    # 激活函数    (转为numpy是为了画图)
    relu_function = torch.relu(variable_data).data.numpy()
    sigmoid_function = torch.sigmoid(variable_data).data.numpy()
    tanh_function = torch.tanh(variable_data).data.numpy()
    softplus_function = F.softplus(variable_data).data.numpy()
    
    # 使用matplotlib作图
    plt.figure(1, figsize=(6, 6))
    
    plt.subplot(221)
    plt.plot(np_data, relu_function, c="green", label="relu")
    plt.ylim(-1, 5)
    plt.legend(loc="best")
    
    plt.subplot(222)
    plt.plot(np_data, sigmoid_function, c="green", label="sigmoid")
    plt.ylim(-0.2, 1.2)
    plt.legend(loc="best")
    
    plt.subplot(223)
    plt.plot(np_data, tanh_function, c="green", label="tanh")
    plt.ylim(-1.2, 1.2)
    plt.legend(loc="best")
    
    plt.subplot(224)
    plt.plot(np_data, softplus_function, c="green", label="softplus")
    plt.ylim(-0.2, 6)
    plt.legend(loc="best")
    
    plt.show()

    结果:

    4.nn.Linear浅析

    对输入数据进行线性变换

    查看源码:

    初始化部分

        def __init__(self, in_features, out_features, bias=True):
            super(Linear, self).__init__()
            self.in_features = in_features
            self.out_features = out_features
            self.weight = Parameter(torch.Tensor(out_features, in_features))
            if bias:
                self.bias = Parameter(torch.Tensor(out_features))
            else:
                self.register_parameter('bias', None)
            self.reset_parameters()

    需要实现的内容:

    参数说明:

    Args:
            in_features: size of each input sample  输入的二维张量的大小
            out_features: size of each output sample 输出的二维张量的大小
            bias: If set to ``False``, the layer will not learn an additive bias.
                Default: ``True``

    举个例子:

    >>> m = nn.Linear(20, 30)
    >>> input = torch.randn(128, 20)
    >>> output = m(input)
    >>> print(output.size())
    torch.Size([128, 30])

    张量的大小由 128 x 20 变成了 128 x 30

    执行的操作是:

    [128,20]×[20,30]=[128,30]

    5.输出整个tensor的方法

    torch.set_printoptions(profile="full")
    print(logit)  # prints the whole tensor
    torch.set_printoptions(profile="default")  # reset
    print(logit)  # prints the truncated tensor

    参考文献:

    https://blog.csdn.net/Qy1997/article/details/106455717

    https://www.cnblogs.com/marsggbo/p/10592643.html

  • 相关阅读:
    专业的户外直播视频传输系统是如何搭建起来的?通过GB28181协议建立的户外直播方案
    Go-注释
    语言的动态性和静态性
    程序&命名-执行环境
    Go-错误栈信息
    Mongo-文档主键-ObjectId
    Mongo-关系型VS非关系型
    数据-CRUD
    Mongo基本操作
    mongo环境搭建
  • 原文地址:https://www.cnblogs.com/luckyplj/p/13424561.html
Copyright © 2011-2022 走看看