zoukankan      html  css  js  c++  java
  • torch:CrossEntropy是个构造器,所以loss = torch.nn.CrossEntropyLoss()(output, target)这么写就对了

    criteria = nn.CrossEntropyLoss()
    loss = criteria(output, target)
    loss = torch.nn.functional.cross_entropy(output, target)
    import torch
    import torchvision
    import torch.nn as nn
    import torch.nn.functional as F
    
    # input is of size N x C = 3 x 5
    input = torch.randn(3, 5, requires_grad=True)
    # each element in target has to have 0 <= value < C
    target = torch.tensor([1, 0, 4])
    output = F.nll_loss(F.log_softmax(input), target)
    print(output)
    output.backward()
    print(output)
    
    input = torch.tensor([[0.1939, 0.2019, 0.8598],
                          [0.4146, 0.1330, 0.9469],
                          [0.8549, 0.9154, 0.5434]])
    # input = torch.rand(3, 3)
    
    print(input)
    sm = nn.Softmax(dim=1)
    print(sm(input))
    # tensor([[0.2529, 0.2549, 0.4922],
    #         [0.2892, 0.2182, 0.4925],
    #         [0.3578, 0.3801, 0.2620]])
    print(torch.log(sm(input)))
    # tensor([[-1.3748, -1.3668, -0.7089],
    #         [-1.2406, -1.5221, -0.7082],
    #         [-1.0277, -0.9672, -1.3392]])
    # tar = torch.tensor([0,2,1])
    tar = torch.tensor([0,2,1])
    # targ = nn.NLLLoss(input,tar) #loss = torch.nn.functional.cross_entropy(output, target)
    targ = F.nll_loss(input,tar)
    print(targ)
    D:ProgramDataMiniconda3python.exe E:/新脚本主文件夹/训练测试项目/test_torch/nll_loss.py
    E:/新脚本主文件夹/训练测试项目/test_torch/nll_loss.py:10: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.
      output = F.nll_loss(F.log_softmax(input), target)
    tensor(3.2477, grad_fn=<NllLossBackward>)
    tensor(3.2477, grad_fn=<NllLossBackward>)
    tensor([[0.1939, 0.2019, 0.8598],
            [0.4146, 0.1330, 0.9469],
            [0.8549, 0.9154, 0.5434]])
    tensor([[0.2529, 0.2549, 0.4922],
            [0.2892, 0.2182, 0.4925],
            [0.3578, 0.3801, 0.2620]])
    tensor([[-1.3748, -1.3668, -0.7089],
            [-1.2405, -1.5221, -0.7082],
            [-1.0277, -0.9672, -1.3392]])
    tensor(-0.6854)
    
    Process finished with exit code 0
  • 相关阅读:
    CodeForces 156B Suspects(枚举)
    CodeForces 156A Message(暴力)
    CodeForces 157B Trace
    CodeForces 157A Game Outcome
    HDU 3578 Greedy Tino(双塔DP)
    POJ 2609 Ferry Loading(双塔DP)
    Java 第十一届 蓝桥杯 省模拟赛 19000互质的个数
    Java 第十一届 蓝桥杯 省模拟赛 19000互质的个数
    Java 第十一届 蓝桥杯 省模拟赛 19000互质的个数
    Java 第十一届 蓝桥杯 省模拟赛十六进制转换成十进制
  • 原文地址:https://www.cnblogs.com/DDBD/p/14063815.html
Copyright © 2011-2022 走看看