zoukankan      html  css  js  c++  java
  • torch.nn.Linear()函数的理解

    import torch

    x = torch.randn(128, 20) # 输入的维度是(128,20)
    m = torch.nn.Linear(20, 30) # 20,30是指维度
    output = m(x)
    print('m.weight.shape: ', m.weight.shape)
    print('m.bias.shape: ', m.bias.shape)
    print('output.shape: ', output.shape)

    # ans = torch.mm(input,torch.t(m.weight))+m.bias 等价于下面的
    ans = torch.mm(x, m.weight.t()) + m.bias
    print('ans.shape: ', ans.shape)

    print(torch.equal(ans, output))
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    m.weight.shape:
    torch.Size([30, 20])
    m.bias.shape:
    torch.Size([30])
    output.shape:
    torch.Size([128, 30])
    ans.shape:
    torch.Size([128, 30])
    True
    1
    2
    3
    4
    5
    6
    7
    8
    9
    为什么 m.weight.shape = (30,20)?

    答:因为线性变换的公式是:

    y=xAT+b y=xA^T+b
    y=xA
    T
    +b

    先生成一个(30,20)的weight,实际运算中再转置,这样就能和x做矩阵乘法了
    ---------------------
    作者:m0_37586991
    来源:CSDN
    原文:https://blog.csdn.net/m0_37586991/article/details/87861418
    版权声明:本文为博主原创文章,转载请附上博文链接!

  • 相关阅读:
    【2019-12-13】泛型
    【2019-12-12】函数
    【2019-12-10】类
    【2019-12-05】接口
    【2019-12-3】变量声明
    【2019-11-24】基础类型
    【2019-11-20】服务与DI简介
    【2019-11-20】组件简介
    android之ListView与Adapter(结合JavaBean)
    android基类Adapter
  • 原文地址:https://www.cnblogs.com/jfdwd/p/11068544.html
Copyright © 2011-2022 走看看