zoukankan      html  css  js  c++  java
  • 张量排序

    Outline

    • Sort/argsort

    • Topk

    • Top-5 Acc.

    Sort/argsort

    一维

    import tensorflow as tf
    
    a = tf.random.shuffle(tf.range(5))
    a
    
    <tf.Tensor: id=59, shape=(5,), dtype=int32, numpy=array([4, 0, 3, 2, 1], dtype=int32)>
    
    tf.sort(a, direction='DESCENDING')
    
    <tf.Tensor: id=69, shape=(5,), dtype=int32, numpy=array([4, 3, 2, 1, 0], dtype=int32)>
    
    # 返回索引
    tf.argsort(a, direction='DESCENDING')
    
    <tf.Tensor: id=81, shape=(5,), dtype=int32, numpy=array([0, 2, 3, 4, 1], dtype=int32)>
    
    idx = tf.argsort(a, direction='DESCENDING')
    tf.gather(a, idx)
    
    <tf.Tensor: id=94, shape=(5,), dtype=int32, numpy=array([4, 3, 2, 1, 0], dtype=int32)>
    

    二维

    a = tf.random.uniform([3, 3], maxval=10, dtype=tf.int32)
    a
    
    <tf.Tensor: id=99, shape=(3, 3), dtype=int32, numpy=
    array([[1, 9, 4],
           [2, 1, 4],
           [3, 6, 0]], dtype=int32)>
    
    tf.sort(a)
    
    <tf.Tensor: id=112, shape=(3, 3), dtype=int32, numpy=
    array([[1, 4, 9],
           [1, 2, 4],
           [0, 3, 6]], dtype=int32)>
    
    tf.sort(a, direction='DESCENDING')
    
    <tf.Tensor: id=122, shape=(3, 3), dtype=int32, numpy=
    array([[9, 4, 1],
           [4, 2, 1],
           [6, 3, 0]], dtype=int32)>
    
    idx = tf.argsort(a)
    idx
    
    <tf.Tensor: id=146, shape=(3, 3), dtype=int32, numpy=
    array([[0, 2, 1],
           [1, 0, 2],
           [2, 0, 1]], dtype=int32)>
    

    Top_k

    • Only return top-k values and indices

    Top_one

    a
    
    <tf.Tensor: id=99, shape=(3, 3), dtype=int32, numpy=
    array([[1, 9, 4],
           [2, 1, 4],
           [3, 6, 0]], dtype=int32)>
    
    # 返回前2个值
    res = tf.math.top_k(a, 2)
    res
    
    TopKV2(values=<tf.Tensor: id=160, shape=(3, 2), dtype=int32, numpy=
    array([[9, 4],
           [4, 2],
           [6, 3]], dtype=int32)>, indices=<tf.Tensor: id=161, shape=(3, 2), dtype=int32, numpy=
    array([[1, 2],
           [2, 0],
           [1, 0]], dtype=int32)>)
    
    res.values
    
    <tf.Tensor: id=160, shape=(3, 2), dtype=int32, numpy=
    array([[9, 4],
           [4, 2],
           [6, 3]], dtype=int32)>
    
    res.indices
    
    <tf.Tensor: id=161, shape=(3, 2), dtype=int32, numpy=
    array([[1, 2],
           [2, 0],
           [1, 0]], dtype=int32)>
    

    Top-k accuracy

    • Prob:[0.1,0.2,0.3,0.4]

    • Label:[2]

    • Only consider top-1 prediction:[3]

    • Only consider top-2 prediction:[3,2]

    • Only consider top-3 prediction:[3,2,1]

    prob = tf.constant([[0.1, 0.2, 0.7], [0.2, 0.7, 0.1]])
    target = tf.constant([2, 0])
    
    # 概率最大的索引在最前面
    k_b = tf.math.top_k(prob, 3).indices
    k_b
    
    <tf.Tensor: id=190, shape=(2, 3), dtype=int32, numpy=
    array([[2, 1, 0],
           [1, 0, 2]], dtype=int32)>
    
    k_b = tf.transpose(k_b, [1, 0])
    k_b
    
    <tf.Tensor: id=193, shape=(3, 2), dtype=int32, numpy=
    array([[2, 1],
           [1, 0],
           [0, 2]], dtype=int32)>
    
    # 将真实标签 broadcast 成与 k_b 相同的形状 [3, 2]，以便与预测索引 k_b 逐元素比较
    target = tf.broadcast_to(target, [3, 2])
    target
    
    <tf.Tensor: id=196, shape=(3, 2), dtype=int32, numpy=
    array([[2, 0],
           [2, 0],
           [2, 0]], dtype=int32)>
    

    示例

    def accuracy(output, target, topk=(1, )):
        """Compute top-k classification accuracy for every k in ``topk``.

        Args:
            output: per-class scores/probabilities; assumed shape
                [batch_size, num_classes] (the transpose below requires rank 2).
            target: integer class labels, one per sample, shape [batch_size].
                Any integer dtype is accepted (int32 or int64).
            topk: iterable of k values. Every k must be <= num_classes,
                otherwise ``tf.math.top_k`` raises.

        Returns:
            A list of Python floats, one accuracy per k, in the order of ``topk``.
        """
        # Materialize once: ``topk`` is read twice (max + loop), which would
        # silently yield no results if a one-shot iterator were passed.
        topk = tuple(topk)
        maxk = max(topk)
        output = tf.convert_to_tensor(output)
        target = tf.convert_to_tensor(target)
        batch_size = target.shape[0]

        # [batch_size, maxk] -> [maxk, batch_size]: row i holds every sample's
        # (i+1)-th most likely class, so correct[:k] covers the top-k guesses.
        pred = tf.math.top_k(output, maxk).indices
        pred = tf.transpose(pred, perm=[1, 0])
        # top_k indices are int32 while labels are often int64 (e.g. produced by
        # tf.argmax); tf.equal rejects mismatched dtypes, so align them first.
        target_ = tf.broadcast_to(tf.cast(target, pred.dtype), tf.shape(pred))
        correct = tf.equal(pred, target_)

        res = []
        for k in topk:
            # A sample's label can appear at most once among its top-k indices,
            # so the number of True entries equals the number of hits.
            correct_k = tf.reduce_sum(tf.cast(correct[:k], dtype=tf.float32))
            res.append(float(correct_k / batch_size))

        return res
    
    # Demo data: 10 samples, 6 classes.
    output = tf.random.normal([10, 6])
    # Softmax along the class axis (axis=1), so EACH sample's six probabilities
    # sum to 1. It does not make the probabilities of all samples sum to 1.
    output = tf.math.softmax(output, axis=1)
    # One ground-truth label per sample, drawn uniformly from [0, 6).
    target = tf.random.uniform([10], maxval=6, dtype=tf.int32)
    print(f'prob: {output.numpy()}')
    # Top-1 prediction per sample, printed next to the labels for comparison.
    pred = tf.argmax(output, axis=1)
    print(f'pred: {pred.numpy()}')
    print(f'label: {target.numpy()}')

    # With k == num_classes (6), the top-k accuracy is always 1.0.
    acc = accuracy(output, target, topk=(1, 2, 3, 4, 5, 6))
    print(f'top-1-6 acc: {acc}')
    
    prob: [[0.12232917 0.18645659 0.27771464 0.17322136 0.14854735 0.09173083]
     [0.02338449 0.01026637 0.11773597 0.69083494 0.03814701 0.11963127]
     [0.05774692 0.1926369  0.49359822 0.10262781 0.10738047 0.0460096 ]
     [0.21298195 0.02826484 0.1813868  0.06380058 0.06848615 0.44507968]
     [0.01364106 0.16782394 0.08621352 0.22500433 0.19081964 0.31649753]
     [0.02917767 0.15526605 0.6310118  0.11471876 0.05473462 0.0150911 ]
     [0.03684716 0.15286008 0.11792535 0.47401306 0.05833342 0.160021  ]
     [0.32859987 0.17415446 0.07394216 0.22221863 0.07559296 0.12549189]
     [0.02662764 0.5529567  0.06995299 0.02131662 0.08664025 0.2425058 ]
     [0.10253917 0.10178788 0.21553555 0.12878521 0.3788466  0.07250563]]
    pred: [2 3 2 5 5 2 3 0 1 4]
    label: [3 4 3 0 4 0 3 2 1 4]
    top-1-6 acc: [0.30000001192092896, 0.4000000059604645, 0.6000000238418579, 0.800000011920929, 0.8999999761581421, 1.0]
  • 相关阅读:
    spring中的Filter使用
    跨站脚本(XSS)攻击
    RepeatSubmitInterceptor extends HandlerInterceptorAdapter
    理解TCP
    Github(第一次尝试)
    MVC(实战二:网址映射)
    MVC(实战一)
    MVC(基础二)
    WinFrom和WebFrom的区别
    MVC(基础一)
  • 原文地址:https://www.cnblogs.com/nickchen121/p/10852639.html
Copyright © 2011-2022 走看看