zoukankan      html  css  js  c++  java
  • pytorch的nn.MSELoss损失函数

    MSE是mean squared error的缩写,即平均平方误差,简称均方误差。

    MSE是逐元素计算的,计算公式为:

    旧版的nn.MSELoss()函数有reduce、size_average两个参数,新版的只有一个reduction参数了,功能是一样的。reduction的意思是维度要不要缩减,以及怎么缩减,有三个选项:

    • 'none': no reduction will be applied.
    • 'mean': the sum of the output will be divided by the number of elements in the output.
    • 'sum': the output will be summed.

    如果不设置reduction参数,默认是'mean'。

    程序示例: 

    import torch
    import torch.nn as nn
    
    a = torch.tensor([[1, 2], [3, 4]], dtype=torch.float)
    b = torch.tensor([[3, 5], [8, 6]], dtype=torch.float)
    
    loss_fn1 = torch.nn.MSELoss(reduction='none')
    loss1 = loss_fn1(a.float(), b.float())
    print(loss1)   # 输出结果:tensor([[ 4.,  9.],
                   #                 [25.,  4.]])
    
    loss_fn2 = torch.nn.MSELoss(reduction='sum')
    loss2 = loss_fn2(a.float(), b.float())
    print(loss2)   # 输出结果:tensor(42.)
    
    
    loss_fn3 = torch.nn.MSELoss(reduction='mean')
    loss3 = loss_fn3(a.float(), b.float())
    print(loss3)   # 输出结果:tensor(10.5000)

     

    对于三维的输入也是一样的:

    a = torch.randint(0, 9, (2, 2, 3)).float()
    b = torch.randint(0, 9, (2, 2, 3)).float()
    print('a:
    ', a)
    print('b:
    ', b)
    
    loss_fn1 = torch.nn.MSELoss(reduction='none')
    loss1 = loss_fn1(a.float(), b.float())
    print('loss_none:
    ', loss1)
    
    loss_fn2 = torch.nn.MSELoss(reduction='sum')
    loss2 = loss_fn2(a.float(), b.float())
    print('loss_sum:
    ', loss2)
    
    
    loss_fn3 = torch.nn.MSELoss(reduction='mean')
    loss3 = loss_fn3(a.float(), b.float())
    print('loss_mean:
    ', loss3)

    运行结果:

     

     参考资料:

    pytorch的nn.MSELoss损失函数

     

  • 相关阅读:
    mysql自增长字段设置
    查看docker的挂载目录
    centos rpm安装jdk1.8
    mybatis-地区三表生成地区树
    post表单、json接口
    git子模块使用
    解决Windows系统80端口被占用
    交换机基础命令
    JMX协议
    WMI协议
  • 原文地址:https://www.cnblogs.com/picassooo/p/13591663.html
Copyright © 2011-2022 走看看