zoukankan      html  css  js  c++  java
  • (predicted == labels).sum().item()作用

     ⚠️(predicted == labels).sum().item()作用,举个小例子介绍:

    # -*- coding: utf-8 -*-
    import torch import numpy as np data1 = np.array([ [1,2,3], [2,3,4] ]) data1_torch = torch.from_numpy(data1) data2 = np.array([ [1,2,3], [2,3,4] ]) data2_torch = torch.from_numpy(data2) p = (data1_torch == data2_torch) #对比后相同的值会为1,不同则会为0 print p print type(p) d1 = p.sum() #将所有的值相加,得到的仍是tensor类别的int值 print d1 print type(d1) d2 = d1.item() #转成python数字 print d2 print type(d2)

    返回:

    (deeplearning2) userdeMBP:pytorch user$ python test.py
    tensor([[1, 1, 1],
            [1, 1, 1]], dtype=torch.uint8)
    <class 'torch.Tensor'>
    tensor(6)
    <class 'torch.Tensor'>
    6
    <type 'int'>

    即如果有不同的话,会变成:

    # -*- coding: utf-8 -*-
    import torch import numpy
    as np data1 = np.array([ [1,2,3], [2,3,4] ]) data1_torch = torch.from_numpy(data1) data2 = np.array([ [1,2,3], [4,5,6] ]) data2_torch = torch.from_numpy(data2) p = (data1_torch == data2_torch) print p print type(p) d1 = p.sum() print d1 print type(d1) d2 = d1.item() print d2 print type(d2)

    返回:

    (deeplearning2) userdeMBP:pytorch user$ python test.py
    tensor([[1, 1, 1],
            [0, 0, 0]], dtype=torch.uint8)
    <class 'torch.Tensor'>
    tensor(3)
    <class 'torch.Tensor'>
    3
    <type 'int'>
  • 相关阅读:
    面试总结
    CentOS 6.4 yum安装LAMP环境
    windows下XAMPP安装php_memcache扩展
    linux学习笔记
    本地虚拟机LNMP环境安装
    Linux下php安装memcache扩展
    linux下memcached安装以及启动
    阿里云服务器---centos编译安装ffmpeg
    [Yii2.0] 以Yii 2.0风格加载自定义类或命名空间 [配置使用Yii2 autoloader]
    Linux常用命令
  • 原文地址:https://www.cnblogs.com/wanghui-garcia/p/10558819.html
Copyright © 2011-2022 走看看