zoukankan      html  css  js  c++  java
  • pytorch中的nn.CrossEntropyLoss()计算原理

    生成随机矩阵

    x = np.random.rand(2,3) 
    

    array([[0.10786477, 0.56611762, 0.10557245], [0.4596513 , 0.13174377, 0.82373043]])

    计算softmax

    在numpy中

    y = np.exp(x) / np.sum(np.exp(x), axis=1, keepdims=True)
    

    array([[0.27940617, 0.44182742, 0.27876641], [0.31649398, 0.22801164, 0.45549437]])

    在pytorch中

    torch_x = torch.from_numpy(x) torch_y = nn.Softmax(dim=-1)(torch_x)
    

    tensor([[0.2794, 0.4418, 0.2788], [0.3165, 0.2280, 0.4555]], dtype=torch.float64)

    计算log_softmax

    在numpy中

    import numpy as np 
    x = np.array([[-0.7715, -0.6205,-0.2562]]) 
    y = np.exp(x) / np.sum(np.exp(x), axis=1, keepdims=True) 
    y = np.log(y) 
    

    array([[-1.27508877, -0.81683591, -1.27738109], [-1.15045104, -1.47835858, -0.78637192]])

    在pytorch中

    torch_x = torch.from_numpy(x) 
    torch_y = nn.LogSoftmax(dim=-1)(torch_x) 
    

    tensor([[-1.2751, -0.8168, -1.2774], [-1.1505, -1.4784, -0.7864]], dtype=torch.float64)

    计算NLLLoss

    说明,就是在计算log_softmax之后,根据每个样本的真实标签取得其对应的值。默认权重都是1,而且采取求均值的方式。这里就是-(-1.27508877 + -0.78637192) / 2,即取出第0行的第0个和第1行的第2个,正好对应[0, 2]。

    在numpy中

    targets = np.array([0, 2]) 
    nll_loss = -(np.sum(np.choose(targets, y.T)) / y.shape[0]) 
    

    1.0307303437846973

    在pytorch中

    首先我们来看下官方代码:

     |      >>> m = nn.LogSoftmax(dim=1)
     |      >>> loss = nn.NLLLoss()
     |      >>> # input is of size N x C = 3 x 5
     |      >>> input = torch.randn(3, 5, requires_grad=True)
     |      >>> # each element in target has to have 0 <= value < C
     |      >>> target = torch.tensor([1, 0, 4])
     |      >>> output = loss(m(input), target)
     |      >>> output.backward()
    

    发现其也是在计算LogSoftmax之后计算NLLLoss()。
    我们在看下pytorch的计算结果:

    torch_targets = torch.tensor([0, 2])
    torch_nll_loss = nn.NLLLoss()(torch_y, torch_targets)
    

    tensor(1.0307, dtype=torch.float64)
    与我们一步一步利用numpy计算的保持一致。
    最后我们在利用更直观的一种形式来看看:

    import torch.nn.functional as F 
    output = F.nll_loss(F.log_softmax(torch_x, dim=1), torch_targets, reduction='mean')
    

    tensor(1.0307, dtype=torch.float64)
    结果也符合我们的预期。

    https://www.gentlecp.com/articles/874.html
    https://blog.csdn.net/qq_28418387/article/details/95918829 https://blog.csdn.net/yyhhlancelot/article/details/83142255 https://blog.csdn.net/Jeremy_lf/article/details/102725285

  • 相关阅读:
    学习Linux二(创建、删除文件和文件夹命令)
    合理的需求
    两种事件触发的jquery导航菜单
    JS中this关键字
    Hibernate的session问题
    JQUERY图片特效
    学习Linux一(安装VMware和Ubuntu)
    A标签跳转问题
    WEBSERVICE简介
    IE下设置Cursor的一点记录
  • 原文地址:https://www.cnblogs.com/xiximayou/p/15029277.html
Copyright © 2011-2022 走看看