zoukankan      html  css  js  c++  java
  • tensorflow(十四):张量排序( Sort/argsort, Topk, Top5 Acc.)

    一、tf.sort()排序,tf.argsort()排序得到元素index

     

     二、top-k之tf.math.top_k()最大的前k个元素

     

     

     三、实战

    import tensorflow as tf
    import os
    
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = '2'
    tf.random.set_seed(2467)
    
    def accuracy(output, target, topk=(1,)):
        maxk = max(topk)                        #这里为6
        batch_size = target.shape[0]            #tensor的行数
    
        pred = tf.math.top_k(output, maxk).indices
        pred = tf.transpose(pred, perm=[1, 0])
        target_ = tf.broadcast_to(target, pred.shape)
        # [10, b]
        correct = tf.equal(pred, target_)
    
        res = []
        for k in topk:
            correct_k = tf.cast(tf.reshape(correct[:k], [-1]), dtype=tf.float32)    #到第几行,参考上一个图片!
            correct_k = tf.reduce_sum(correct_k)
            acc = float(correct_k* (100.0 / batch_size) )
            res.append(acc)
    
        return res
    
    #用正太分布生成10个样本,6类,然后进行一个softmax.
    output = tf.random.normal([10, 6])
    output = tf.math.softmax(output, axis=1)
    target = tf.random.uniform([10], maxval=6, dtype=tf.int32) #0-5之间随机10个数。
    print('prob:', output.numpy())
    pred = tf.argmax(output, axis=1)
    print('pred:', pred.numpy())
    print('label:', target.numpy())
    
    acc = accuracy(output, target, topk=(1,2,3,4,5,6))  #top1~top6的accuracy
    print('top-1-6 acc:', acc)
  • 相关阅读:
    java 求两个数最大值
    java 加法运算
    javs switch 语句
    git合并分支成功,但是push失败(remote: GitLab: You are not allowed to push code to protected branches on this project.)
    python 获取日期以及时间
    1713
    linux shell脚本中的延时
    java 类的继承
    Python3 使用企业微信 API 发送消息
    java if 条件语句
  • 原文地址:https://www.cnblogs.com/zhangxianrong/p/14607911.html
Copyright © 2011-2022 走看看