zoukankan      html  css  js  c++  java
  • 吴裕雄--天生自然TensorFlow2教程:张量排序

    import tensorflow as tf
    
    a = tf.random.shuffle(tf.range(5))
    a
    tf.sort(a, direction='DESCENDING')
    # 返回索引
    tf.argsort(a, direction='DESCENDING')
    idx = tf.argsort(a, direction='DESCENDING')
    tf.gather(a, idx)
    idx = tf.argsort(a, direction='DESCENDING')
    tf.gather(a, idx)
    a = tf.random.uniform([3, 3], maxval=10, dtype=tf.int32)
    a
    tf.sort(a)
    tf.sort(a, direction='DESCENDING')
    idx = tf.argsort(a)
    idx
    # 返回前2个值
    res = tf.math.top_k(a, 2)
    res
    res.values
    res.indices
    prob = tf.constant([[0.1, 0.2, 0.7], [0.2, 0.7, 0.1]])
    target = tf.constant([2, 0])
    # 概率最大的索引在最前面
    k_b = tf.math.top_k(prob, 3).indices
    k_b
    k_b = tf.transpose(k_b, [1, 0])
    k_b
    # 对真实值broadcast,与prod比较
    target = tf.broadcast_to(target, [3, 2])
    target
    def accuracy(output, target, topk=(1, )):
        maxk = max(topk)
        batch_size = target.shape[0]
    
        pred = tf.math.top_k(output, maxk).indices
        pred = tf.transpose(pred, perm=[1, 0])
        target_ = tf.broadcast_to(target, pred.shape)
        correct = tf.equal(pred, target_)
    
        res = []
        for k in topk:
            correct_k = tf.cast(tf.reshape(correct[:k], [-1]), dtype=tf.float32)
            correct_k = tf.reduce_sum(correct_k)
            acc = float(correct_k / batch_size)
            res.append(acc)
    
        return res
    # 10个样本6类
    output = tf.random.normal([10, 6])
    # 使得所有样本的概率加起来为1
    output = tf.math.softmax(output, axis=1)
    # 10个样本对应的标记
    target = tf.random.uniform([10], maxval=6, dtype=tf.int32)
    print(f'prob: {output.numpy()}')
    pred = tf.argmax(output, axis=1)
    print(f'pred: {pred.numpy()}')
    print(f'label: {target.numpy()}')
    
    acc = accuracy(output, target, topk=(1, 2, 3, 4, 5, 6))
    print(f'top-1-6 acc: {acc}')
  • 相关阅读:
    django2.0+连接mysql数据库迁移时候报错
    微信小程序路由跳转
    洛谷P3144 [USACO16OPEN]关闭农场Closing the Farm
    洛谷P3143 [USACO16OPEN]钻石收藏家Diamond Collector
    洛谷P2677 超级书架 2
    洛谷P2676 超级书架
    洛谷P3146 [USACO16OPEN]248
    洛谷P1396 营救
    洛谷P1772 [ZJOI2006]物流运输
    P3102 [USACO14FEB]秘密代码Secret Code
  • 原文地址:https://www.cnblogs.com/tszr/p/12133374.html
Copyright © 2011-2022 走看看