zoukankan      html  css  js  c++  java
  • 『PyTorch x TensorFlow』第八弹_基本nn.Module层函数

    『TensorFlow』网络操作API_上  

    『TensorFlow』网络操作API_中

    『TensorFlow』网络操作API_下

    之前也说过,tf 和 t 的层本质区别就是 tf 的是层函数,调用即可,t 的是类,需要初始化后再调用实例(实例都是callable的)

    卷积

    tensorflow.nn.conv2d

    import tensorflow as tf
    
    sess = tf.Session()
    input = tf.Variable(tf.random_normal([1,3,3,5]))
    
    # 卷积核尺寸*2,输入通道,输出通道,
    filter = tf.Variable(tf.random_normal([1,1,5,1])) # 《-----卷积核初始化
    
    conv = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='VALID')
    
    sess.run(tf.global_variables_initializer())
    print(sess.run(conv).shape)
    
    (1, 3, 3, 1)
    

    torch.nn.Conv2d

    troch集成了初始化核的部分,所以自行初始化时需要直接修改变量的data

    本篇很多例子中都对module的属性直接操作,其大多数是可学习参数,一般会随着学习的进行而不断改变。实际使用中除非需要使用特殊的初始化,应尽量不要直接修改这些参数。

    import torch as t
    input = t.normal(means=t.zeros([1,5,3,3]), std=t.Tensor([0.1]).expand([1,5,3,3]))
    input = t.autograd.Variable(input)
    
    # 输入通道,输出通道,卷积核尺寸,步长,是否偏执
    conv = t.nn.Conv2d(5, 1, (1, 1), 1, bias=False)
    
    # 输出通道,输入通道,卷积核尺寸*2
    print([n for n,p in conv.named_parameters()])
    conv.weight.data = t.ones([1,5,1,1]) # 《-----卷积核初始化,可有可无
    
    out = conv(input)
    print(out.size())
    
    ['weight']
    torch.Size([1, 1, 3, 3])
    

    池化

    tensorflow.nn.avg_pool

    torch.nn.AvgPool2d

    可以验证没有学习参数

    pool = nn.AvgPool2d(2,2)
    list(pool.parameters())
    
    []

    线性

    torch.nn.Linear

    # 输入 batch_size=2,维度3
    input = V(t.randn(2, 3))
    linear = nn.Linear(3, 4)
    h = linear(input)
    print(h)
    
    Variable containing:
    -1.4189 -0.2045  1.2143 -1.5404
     0.8471 -0.3154 -0.5855  0.0153
    [torch.FloatTensor of size 2x4]

    BatchNorm

    『TensorFlow』批处理类

    torch.nn.BatchNorm1d

    BatchNorm:批规范化层,分为1D、2D和3D。除了标准的BatchNorm之外,还有在风格迁移中常用到的InstanceNorm层。

    # 4 channel,初始化标准差为4,均值为0
    bn = nn.BatchNorm1d(4)
    print([n for n,p in bn.named_parameters()])
    bn.weight.data = t.ones(4) * 4
    bn.bias.data = t.zeros(4)
    
    bn_out = bn(h)
    # 注意输出的均值和方差
    # 方差是标准差的平方,计算无偏方差分母会减1
    # 使用unbiased=False 分母不减1
    bn_out.size(), bn_out.mean(0), bn_out.var(0, unbiased=False)
    
    ['weight', 'bias']
    
    (torch.Size([2, 4]), 

    Variable containing: 1.00000e-06 * 0.0000 -1.0729 0.0000 0.1192 [torch.FloatTensor of size 4],

    Variable containing: 15.9999 15.9481 15.9998 15.9997 [torch.FloatTensor of size 4])

    Dropout

    tensorflow.nn.dropout

    torch.nn.Dropout

    dropout层,用来防止过拟合,同样分为1D、2D和3D。 下面通过例子来说明它们的使用。

     # 每个元素以0.5的概率舍弃
    dropout = nn.Dropout(0.5)
    o = dropout(bn_out)
    o # 有一半左右的数变为0
    
    Variable containing:
    -7.9895 -7.9931  7.9991  7.9973
     0.0000  0.0000 -7.9991 -7.9973
    [torch.FloatTensor of size 2x4]

    激活函数

    PyTorch实现了常见的激活函数,其具体的接口信息可参见官方文档^3,这些激活函数可作为独立的layer使用。这里将介绍最常用的激活函数ReLU,其数学表达式为:

    relu = nn.ReLU(inplace=True)
    input = V(t.randn(2, 3))
    print(input)
    output = relu(input)
    print(output) # 小于0的都被截断为0
    # 等价于input.clamp(min=0)
    
    Variable containing:
    -0.8472  1.0046  0.7245
     0.3567  0.0032 -0.5200
    [torch.FloatTensor of size 2x3]
    
    Variable containing:
     0.0000  1.0046  0.7245
     0.3567  0.0032  0.0000
    [torch.FloatTensor of size 2x3]
    

    有关inplace:

    ReLU函数有个inplace参数,如果设为True,它会把输出直接覆盖到输入中,这样可以节省内存/显存。之所以可以覆盖是因为在计算ReLU的反向传播时,只需根据输出就能够推算出反向传播的梯度。但是只有少数的autograd操作支持inplace操作(如variable.sigmoid_()),除非你明确地知道自己在做什么,否则一般不要使用inplace操作。

    交叉熵

    import torch as t
    from torch.autograd import Variable as V
    
    score = V(t.randn(3,2))
    label = V(t.Tensor([1,0,1])).long()
    loss_fn = t.nn.CrossEntropyLoss()
    loss = loss_fn(score,label)
    print(loss)
    

    Variable containing:
     1.3535
    [torch.FloatTensor of size 1]

    损失函数和nn.Module的其他class没什么不同,不过实际使用时往往单独提取出来(书上语)。

    ReLU(x)=max(0,x)
  • 相关阅读:
    持续集成(转)
    Java中前台JSP请求Servlet实例(http+Servlet)
    Java中Map集合的四种访问方式(转)
    Python中字符串操作
    Python中的range函数用法
    Python学习资料下载地址(转)
    Linux性能工具介绍
    性能问题定位方法
    录制脚本前需要理解的几个基本概念
    Python 硬件底层基础理论
  • 原文地址:https://www.cnblogs.com/hellcat/p/8474254.html
Copyright © 2011-2022 走看看