  • torch: CrossEntropyLoss is a constructor (a module class), so loss = torch.nn.CrossEntropyLoss()(output, target) is the correct way to call it

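    # Module form: construct the loss object once, then call it on (output, target)
    criteria = nn.CrossEntropyLoss()
    loss = criteria(output, target)
    # Functional form: same loss, no object to construct
    loss = torch.nn.functional.cross_entropy(output, target)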
    import torch
    import torchvision
    import torch.nn as nn
    import torch.nn.functional as F
    
    # input is of size N x C = 3 x 5
    input = torch.randn(3, 5, requires_grad=True)
    # each element in target has to have 0 <= value < C
    target = torch.tensor([1, 0, 4])
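    # dim=1 normalizes over the class dimension; the recorded run below omitted dim and got a UserWarning
    output = F.nll_loss(F.log_softmax(input, dim=1), target)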
    print(output)
    output.backward()
    print(output)
    
    input = torch.tensor([[0.1939, 0.2019, 0.8598],
                          [0.4146, 0.1330, 0.9469],
                          [0.8549, 0.9154, 0.5434]])
    # input = torch.rand(3, 3)
    
    print(input)
    sm = nn.Softmax(dim=1)
    print(sm(input))
    # tensor([[0.2529, 0.2549, 0.4922],
    #         [0.2892, 0.2182, 0.4925],
    #         [0.3578, 0.3801, 0.2620]])
    print(torch.log(sm(input)))
    # tensor([[-1.3748, -1.3668, -0.7089],
    #         [-1.2405, -1.5221, -0.7082],
    #         [-1.0277, -0.9672, -1.3392]])
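    tar = torch.tensor([0, 2, 1])
    # nn.NLLLoss is a class too: use nn.NLLLoss()(input, tar), not nn.NLLLoss(input, tar)
    # nll_loss expects log-probabilities; given the raw input it only returns
    # the negated mean of the selected entries: -(0.1939 + 0.9469 + 0.9154) / 3
    targ = F.nll_loss(input, tar)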
    print(targ)
    D:\ProgramData\Miniconda3\python.exe E:/新脚本主文件夹/训练测试项目/test_torch/nll_loss.py
    E:/新脚本主文件夹/训练测试项目/test_torch/nll_loss.py:10: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.
      output = F.nll_loss(F.log_softmax(input), target)
    tensor(3.2477, grad_fn=<NllLossBackward>)
    tensor(3.2477, grad_fn=<NllLossBackward>)
    tensor([[0.1939, 0.2019, 0.8598],
            [0.4146, 0.1330, 0.9469],
            [0.8549, 0.9154, 0.5434]])
    tensor([[0.2529, 0.2549, 0.4922],
            [0.2892, 0.2182, 0.4925],
            [0.3578, 0.3801, 0.2620]])
    tensor([[-1.3748, -1.3668, -0.7089],
            [-1.2405, -1.5221, -0.7082],
            [-1.0277, -0.9672, -1.3392]])
    tensor(-0.6854)
    
    Process finished with exit code 0
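
The run above can be cross-checked with a short sketch, reusing the post's 3 x 3 tensor and targets. It assumes a recent PyTorch; the variable names (scores, loss_module, raw_nll and so on) are my own and not from the original script. It compares the module form, the functional form, and the nll_loss-on-log_softmax composition, then shows why nll_loss on raw scores printed tensor(-0.6854).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Same 3 x 3 scores and targets as in the post (names here are illustrative).
scores = torch.tensor([[0.1939, 0.2019, 0.8598],
                       [0.4146, 0.1330, 0.9469],
                       [0.8549, 0.9154, 0.5434]])
tar = torch.tensor([0, 2, 1])
rows = torch.arange(scores.size(0))

# Three spellings of the same cross-entropy loss.
loss_module = nn.CrossEntropyLoss()(scores, tar)                # construct, then call
loss_functional = F.cross_entropy(scores, tar)                  # functional form
loss_composed = F.nll_loss(F.log_softmax(scores, dim=1), tar)   # explicit composition

# By hand: take the target-class log-probability in each row, negate, average.
log_probs = F.log_softmax(scores, dim=1)
loss_manual = -log_probs[rows, tar].mean()

# nll_loss applies no softmax or log itself. On raw scores it just negates
# and averages the selected entries, which is where -0.6854 comes from.
raw_nll = F.nll_loss(scores, tar)
raw_manual = -scores[rows, tar].mean()

print(loss_module, loss_functional, loss_composed, loss_manual)
print(raw_nll, raw_manual)
```

The first four values should agree (about 1.017, judging by the log-probabilities printed above), and the last two should both match the recorded -0.6854. So the negative "loss" in the run is not a cross-entropy value: nll_loss needs log-probabilities as input.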
  • Original article: https://www.cnblogs.com/DDBD/p/14063815.html