zoukankan      html  css  js  c++  java
  • pytorch构建自己设计的层

    下面是如何自己构建一个层,分为包含自动反向求导和手动反向求导两种方式,后面会分别构建网络,对比一下结果对不对。

    ----------------------------------------------------------

    关于Pytorch中的结构层级关系。

    最为底层的是torch.relu()、torch.tanh()、torch.ge()这些函数,这些函数个人猜测就是直接用Cuda写成的,并且封装成了python接口给python上层调用。

    部分函数被torch.nn.functional里面的部分函数模块调用。这些函数可能会被更为上层的nn.Module调用。

    下面以BatchNormalization为例进行分析。

    最为底层的是torch.batch_norm()这个函数,是看不到源代码的,应该是对于cuda代码的封装。这个函数会传入(input, weight, bias, running_mean, running_var, training, momentum, eps)。 再往上时torch.nn.functional里面的函数bacth_norm()。再往上就是torch.nn里面的网络层,比如,BatchNorm2d()等等。

    分析一下BatchNorm2d()里面的主要程序。

     

    import torch
    import torch.nn as nn
    from torch.nn import init
    from torch.nn.parameter import Parameter
    
    class BatchNorm(nn.module):
        def __init__(self,num_features):
            super(BatchMMNorm,self).__init__()
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))        
            
        def reset_parameter(self):
            init.uniform_(self.weight)
            init.zeros_(self.bias)

      def forward(self,input):

     

    其中Parameter是用以定义可学习的权重参数的,后面还需要初始化参数。

     

     

     

     

     

     

  • 相关阅读:
    react.js 小记
    docker命令收集
    前端体系
    微信小程序疑问解答
    微信小程序实战笔记
    jQuery.zTree的跳坑记录
    移动web端的react.js组件化方案
    深入理解SQL的四种连接-左外连接、右外连接、内连接、全连接
    字符串转换成数组,去最号的分割号
    linq any()方法实现sql in()方法的效果
  • 原文地址:https://www.cnblogs.com/yanxingang/p/10414693.html
Copyright © 2011-2022 走看看