  • What softmax in torch.nn.functional does, and its parameters

     Reference: https://pytorch-cn.readthedocs.io/zh/latest/package_references/functional/#_1

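    class torch.nn.Softmax(dim=None)

    or:

    torch.nn.functional.softmax(input, dim=None)

    The module form takes dim when it is constructed and is then called on the input, e.g. nn.Softmax(dim=1)(input); the functional form takes the input and dim in a single call.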
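To make the two call styles concrete, here is a minimal sketch checking that they agree. The random 2×3 input, the seed, and the nn.Linear(3, 4) layer are arbitrary choices for illustration, not part of the original post.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 3)

# Module form: dim is fixed once, when the layer is constructed.
softmax_layer = nn.Softmax(dim=1)
out_module = softmax_layer(x)

# Functional form: dim is passed on every call.
out_functional = F.softmax(x, dim=1)
print(torch.allclose(out_module, out_functional))  # True

# The module form is convenient inside containers such as nn.Sequential.
model = nn.Sequential(nn.Linear(3, 4), nn.Softmax(dim=1))
probs = model(x)  # shape (2, 4); each row sums to 1
```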

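    Applies the Softmax function to an n-dimensional input tensor, rescaling the elements along dimension dim so that each lies in the range (0, 1) and every slice along dim sums to 1. Softmax is defined as:

    $$\mathrm{Softmax}(x_i) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}$$

    where the sum runs over all elements x_j in the same slice along dim as x_i.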
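The definition can be checked directly against the built-in. The helper softmax_from_definition below is hypothetical, written only for this comparison; it applies the formula literally. The second half hints at why you would still use F.softmax in practice: the literal formula overflows for large inputs, while the built-in is computed in a numerically safer way (commonly by subtracting the slice maximum first, which does not change the result; treat that as an implementation detail).

```python
import torch
import torch.nn.functional as F

def softmax_from_definition(x: torch.Tensor, dim: int) -> torch.Tensor:
    """Hypothetical helper: exp(x_i) / sum_j exp(x_j), summing over `dim`."""
    exps = torch.exp(x)
    # keepdim=True keeps the summed axis (size 1) so the division broadcasts.
    return exps / exps.sum(dim=dim, keepdim=True)

x = torch.tensor([[1., 2., 3., 4.],
                  [0., 0., 0., 0.]])
manual = softmax_from_definition(x, dim=1)
print(torch.allclose(manual, F.softmax(x, dim=1)))  # True

# The literal formula overflows: exp(1000.) is inf in float32,
# so inf / inf gives nan, while the built-in still returns finite values.
big = torch.tensor([1000., 1001.])
print(softmax_from_definition(big, dim=0))  # tensor([nan, nan])
print(F.softmax(big, dim=0))
```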

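    Parameters:

      dim: the dimension along which softmax is computed; every slice along dim sums to 1. For a 2-D input, dim=0 works column by column (each column sums to 1) and dim=1 works row by row (each row sums to 1). Relying on the default dim is deprecated, so always pass dim explicitly; otherwise you get this warning: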

    UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.
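A quick way to see what dim means is to sum the output along the same dimension. A minimal sketch with an arbitrary 2×3 tensor:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.]])

col_probs = F.softmax(x, dim=0)  # normalise down each of the 3 columns
row_probs = F.softmax(x, dim=1)  # normalise across each of the 2 rows

# Summing along the same dim that softmax used gives all ones.
print(torch.allclose(col_probs.sum(dim=0), torch.ones(3)))  # True
print(torch.allclose(row_probs.sum(dim=1), torch.ones(2)))  # True
```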

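    Shape:

    • Input: (*), any number of dimensions, e.g. (N, L)
    • Output: (*), the same shape as the input

    The result is a tensor with the same dimensions and shape as the input, and every element lies in the range (0, 1).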
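Softmax is not limited to (N, L) inputs. A sketch with a 3-D tensor, where the (8, 5, 10) shape is a made-up example of per-position class scores; dim=-1 selects the last dimension, so it is equivalent to dim=2 here:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Hypothetical scores: 8 sequences x 5 positions x 10 classes.
logits = torch.randn(8, 5, 10)

# dim=-1 means "the last dimension", i.e. dim=2 for this 3-D tensor.
probs = F.softmax(logits, dim=-1)

print(probs.shape)                                          # torch.Size([8, 5, 10])
print(torch.allclose(probs, F.softmax(logits, dim=2)))      # True
print(torch.allclose(probs.sum(dim=-1), torch.ones(8, 5)))  # True: each length-10 slice sums to 1
```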

    Example:

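    import torch
    
    from torch import nn
    # autograd.Variable is deprecated since PyTorch 0.4; plain tensors work directly
    
    m = nn.Softmax()  # dim omitted on purpose: triggers the implicit-dimension warning
    input = torch.randn(2, 3)
    print(input)
    print(m(input))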

    Output:

    (deeplearning) userdeMBP:pytorch user$ python test.py 
    tensor([[ 0.2854,  0.1708,  0.4308],
            [-0.1983,  2.0705,  0.1549]])
    test.py:9: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.
      print(m(input))
    tensor([[0.3281, 0.2926, 0.3794],
            [0.0827, 0.7996, 0.1177]])

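    As the output shows, each row of the result sums to 1: for this 2-D input the implicit choice is row-wise, i.e. dim=1.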
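The implicit choice is not always dim=1. In current PyTorch source the fallback picks dim=0 for 0-, 1- and 3-D inputs and dim=1 otherwise; this is an internal detail rather than a documented guarantee, so treat the sketch below as an observation about current releases. It calls nn.Softmax() without dim on inputs of different ranks, records the warning, and finds which explicit dim reproduces each result:

```python
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
m = nn.Softmax()  # no dim: PyTorch picks one from the input's number of dimensions

implicit_dim = {}  # ndim -> explicit dim(s) that reproduce the implicit result
warned = {}        # ndim -> whether the deprecation UserWarning was raised

for shape in [(4,), (2, 3), (2, 3, 4), (2, 3, 4, 5)]:
    x = torch.randn(*shape)
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # record the warning on every call
        out = m(x)
    implicit_dim[x.dim()] = [d for d in range(x.dim())
                             if torch.allclose(out, F.softmax(x, dim=d))]
    warned[x.dim()] = any(issubclass(w.category, UserWarning) for w in caught)

print(implicit_dim)  # {1: [0], 2: [1], 3: [0], 4: [1]}
```

Because the fallback depends on the input's rank, the same layer can silently normalise over a different axis when the input shape changes, which is a good reason to always pass dim explicitly.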

    A clearer example:

    import torch
    
    import torch.nn.functional as F
    
    x = torch.tensor([[1., 2., 3., 4.], [1., 2., 3., 4.], [1., 2., 3., 4.]])
    
    y1 = F.softmax(x, dim=0)  # softmax over each column
    print(y1)
    
    y2 = F.softmax(x, dim=1)  # softmax over each row
    print(y2)
    
    x1 = torch.tensor([1., 2., 3., 4.])
    print(x1)
    
    y3 = F.softmax(x1, dim=0)  # a 1-D tensor needs dim=0; dim=1 raises an error
    print(y3)

    Output:

    (deeplearning) userdeMBP:pytorch user$ python test.py 
    tensor([[0.3333, 0.3333, 0.3333, 0.3333],
            [0.3333, 0.3333, 0.3333, 0.3333],
            [0.3333, 0.3333, 0.3333, 0.3333]])
    tensor([[0.0321, 0.0871, 0.2369, 0.6439],
            [0.0321, 0.0871, 0.2369, 0.6439],
            [0.0321, 0.0871, 0.2369, 0.6439]])
    tensor([1., 2., 3., 4.])
    tensor([0.0321, 0.0871, 0.2369, 0.6439])

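    Every column holds identical values, so softmax over a column (dim=0) gives each element the same share, 1/3 ≈ 0.3333. Every row is [1, 2, 3, 4], so softmax over a row (dim=1) gives the same result, [0.0321, 0.0871, 0.2369, 0.6439], for each row.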
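Both results follow directly from the definition. If all n values in a slice equal some c, each gets exp(c) / (n·exp(c)) = 1/n. The row values can be reproduced by hand with Python's math module; a small sketch:

```python
import math

import torch
import torch.nn.functional as F

# Column case: n equal values c give exp(c) / (n * exp(c)) = 1/n each.
col = torch.tensor([1., 1., 1.])
print(F.softmax(col, dim=0))  # all three entries equal 1/3

# Row case: apply the definition by hand to [1, 2, 3, 4].
row = [1.0, 2.0, 3.0, 4.0]
total = sum(math.exp(v) for v in row)  # e^1 + e^2 + e^3 + e^4, about 84.79
by_hand = [round(math.exp(v) / total, 4) for v in row]
print(by_hand)  # [0.0321, 0.0871, 0.2369, 0.6439]
```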

    Using dim=1 on a 1-D tensor raises an error:

    RuntimeError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
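A 1-D tensor only has dimension 0 (also addressable as -1). Passing dim=-1, which always means the last dimension, works for both the 1-D and the 2-D tensors in the example above. The exception type for an out-of-range dim has varied across PyTorch versions (RuntimeError in the output above, typically IndexError in recent releases), so this sketch catches both:

```python
import torch
import torch.nn.functional as F

x1 = torch.tensor([1., 2., 3., 4.])        # 1-D: valid dims are 0 and -1
x2 = torch.tensor([[1., 2., 3., 4.]] * 3)  # 2-D: valid dims are 0, 1, -2, -1

# dim=-1 always means "the last dimension", so it works for both tensors.
print(F.softmax(x1, dim=-1))
print(F.softmax(x2, dim=-1))

raised = False
try:
    F.softmax(x1, dim=1)  # a 1-D tensor has no dimension 1
except (IndexError, RuntimeError) as err:  # exception type varies by version
    raised = True
    print(type(err).__name__, err)
```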
  • Original post: https://www.cnblogs.com/wanghui-garcia/p/10675588.html