zoukankan      html  css  js  c++  java
  • pytorch之 RNN 参数解释

    上次通过pytorch实现了RNN模型,简易的完成了使用RNN完成mnist的手写数字识别,但是里面的参数有点不了解,所以对问题进行总结归纳来解决。

    总述:
    第一次看到这个函数时,脑袋有点懵,总结了下总共有五个问题:

    1.这个input_size是啥?要输入啥?feature num又是啥?

    2.这个hidden_size是啥?要输入啥?feature num又是啥?

    3.不是说RNN会有很多个节点连在一起的吗?这怎么定义连接的节点数呢?

    4.num_layer中说的stack是怎么stack的?

    5.怎么输出会有两个东西呀output,hn

    pytorch中RNN的一些参数,并且解决以上五个问题

    1.Pytorch中的RNN

     


    2.input_size是啥?
    说白了input_size无非就是你输入RNN的维度,比如说NLP中你需要把一个单词输入到RNN中,这个单词的编码是300维的,那么这个input_size就是300.这里的input_size其实就是规定了你的输入变量的维度。用f(wX+b)来类比的话,这里输入的就是X的维度。

    3.hidden_size是啥?
    和最简单的BP网络一样的,每个RNN的节点实际上就是一个BP嘛,包含输入层,隐含层,输出层。这里的hidden_size呢,你可以看做是隐含层中,隐含节点的个数。

     

     

    那个输入层的三个节点代表输入维度为3,也就是input_size=3,然后这个hidden_size就是5了。当然这是是对于RNN某一个节点而言的,那么如何规定RNN的节点个数呢?

    4.如何规定节点个数?

    事实上,节点个数并不需要规定,你的输入序列是这样子的,[x1,x2,x3,x4,x5],那么input_size呢就是你的xi的维度,而你的RNN的节点数呢,就是由你的序列长度决定的,在这里我们的序列长度是5,所以会有5个节点。那么问题来了,我咋知道你的序列长度呢?pytorch里面不是只有input_size的参数吗?实际上,你声明RNN是这样声明的

    self.encoder = nn.RNN(input_size=300,hidden_size=128,dropout=0.5)
    但是你用的时候;

    output,hn = self.encoder(encoder_input,encoder_hidden)
    你会把你的数据丢进去吧,也就是你把encoder_input这一整个序列丢进去了,那么序列长度他不就知道了?

    5.num_layers是啥?
    一开始你是不是以为这个就是RNN的节点数呀,hhh,然而并不是:),如果num_layer=2的话,表示两个RNN堆叠在一起。那么怎么堆叠的呢?

    如果是num_layer==1的话:

     

    如果num_layer==2的话:

     

    ok了~最后再来看看最后一个问题

    6.hn,output分别是啥?

      hidden的输出size为[ num_layers* num_directions, batch_size, n_hidden].

      说白了,hidden就是每个方向,每个层的 隐藏单元的输出,所以是n_hidden个。

      output的size(如果RNN设定的batch_first=True),那么就是[batch_size,seq_len,n_hidden],对于分类任务如果要取得最后一个output,只需添加下标  [ :,-1,:]

    看图找答案:

     

    hn就是RNN的最后一个隐含状态,output就是RNN最终得到的结果。

  • 相关阅读:
    BZOJ 3744 Gty的妹子序列
    BZOJ 3872 Ant colony
    BZOJ 1087 互不侵犯
    BZOJ 1070 修车
    BZOJ 2654 tree
    BZOJ 3243 向量内积
    1003 NOIP 模拟赛Day2 城市建设
    CF865D Buy Low Sell High
    CF444A DZY Loves Physics
    Luogu 4310 绝世好题
  • 原文地址:https://www.cnblogs.com/dhName/p/11760610.html
Copyright © 2011-2022 走看看