zoukankan      html  css  js  c++  java
  • TensorFlow之tf.metrics

    import tensorflow as tf
    import numpy as np
    
    # Running tallies accumulated batch by batch by update_running_variables():
    # N_TRUE_P counts true positives, N_PRED_P counts predicted positives.
    N_TRUE_P = N_PRED_P = 0
    
    def reset_running_variables():
        """Set the module-level tallies N_TRUE_P and N_PRED_P back to zero.

        Intended to be called before accumulating a fresh pass over the batches.
        """
        global N_TRUE_P, N_PRED_P
        N_TRUE_P, N_PRED_P = 0, 0
    
    def update_running_variables(labs, preds):
        """Add one batch's true-positive and predicted-positive counts to the tallies.

        An element counts as a predicted positive when ``preds > 0`` and as a
        true positive when both ``labs > 0`` and ``preds > 0``, so the
        true-positive count can never exceed the predicted-positive count.

        Args:
            labs: array-like of ground-truth labels for the batch.
            preds: array-like of predictions, same shape as ``labs``.
        """
        global N_TRUE_P, N_PRED_P
        labs = np.asarray(labs)
        preds = np.asarray(preds)
        predicted_pos = preds > 0
        # Compare element-wise instead of multiplying: ``labs * preds`` wraps
        # around for small unsigned dtypes (uint8: 16 * 16 == 0), and for signed
        # inputs two negatives multiply to a positive that was never counted as
        # a predicted positive.
        N_TRUE_P += np.logical_and(labs > 0, predicted_pos).sum()
        N_PRED_P += predicted_pos.sum()
    
    def calculate_precision():
        """Return precision, i.e. N_TRUE_P / N_PRED_P, from the running tallies.

        Returns 0.0 when nothing has been predicted positive. This matches the
        value tf.metrics.precision reports for an empty denominator, instead of
        raising ZeroDivisionError (Python int tallies) or yielding nan (numpy
        integer tallies).
        """
        if N_PRED_P == 0:
            return 0.0
        return float(N_TRUE_P) / N_PRED_P
    
    if __name__ == '__main__':

        # Each row is treated as one batch of four samples.
        labels = np.array([[1, 1, 1, 0],
                           [1, 1, 1, 0],
                           [1, 1, 1, 0],
                           [1, 1, 1, 0]], dtype=np.uint8)

        predictions = np.array([[1, 0, 0, 0],
                                [1, 1, 0, 0],
                                [1, 1, 1, 0],
                                [0, 1, 1, 1]], dtype=np.uint8)

        # NumPy reference: accumulate the running tallies batch by batch.
        reset_running_variables()
        for lab_batch, pred_batch in zip(labels, predictions):
            update_running_variables(labs=lab_batch, preds=pred_batch)
        print("[NP] SCORE: %1.4f" % calculate_precision())

        # TensorFlow: the graph/session metric API lives under tf.compat.v1.
        # That namespace exists in TF >= 1.13 and in TF 2.x, where the bare
        # tf.placeholder / tf.metrics.precision / tf.Session names were removed.
        tf_v1 = tf.compat.v1
        graph = tf.Graph()
        with graph.as_default():
            # Placeholders that take in one batch of data at a time.
            tf_label = tf_v1.placeholder(dtype=tf.int32, shape=[None])
            tf_prediction = tf_v1.placeholder(dtype=tf.int32, shape=[None])

            # tf_metric reads the current precision; tf_metric_update folds one
            # batch into the metric's hidden running variables.
            tf_metric, tf_metric_update = tf_v1.metrics.precision(
                tf_label, tf_prediction, name="my_metric")

            # The metric keeps its running counts in local variables under its
            # name scope; collect them so they can be (re)initialized explicitly.
            running_vars = tf_v1.get_collection(
                tf_v1.GraphKeys.LOCAL_VARIABLES, scope="my_metric")
            running_vars_initializer = tf_v1.variables_initializer(
                var_list=running_vars)

        with tf_v1.Session(graph=graph) as session:
            session.run(tf_v1.global_variables_initializer())

            # Initialize/reset the running variables before accumulating.
            session.run(running_vars_initializer)

            for lab_batch, pred_batch in zip(labels, predictions):
                # Update the running variables with a new batch of samples.
                feed_dict = {tf_label: lab_batch, tf_prediction: pred_batch}
                session.run(tf_metric_update, feed_dict=feed_dict)

            # Read the score accumulated over all batches.
            score = session.run(tf_metric)
            print("[TF] SCORE: %1.4f" % score)
    

     参考:https://zhuanlan.zhihu.com/p/43359894

  • 相关阅读:
    正则表达式
    数据结构与算法-串
    数据结构与算法-优先级队列
    数据结构与算法-词典
    数据结构与算法-高级搜索树
    数据结构与算法-二叉搜索树
    数据结构与算法-图
    数据结构与算法-二叉树
    数据结构与算法-栈与队列
    数据结构与算法-列表
  • 原文地址:https://www.cnblogs.com/liutianrui1/p/12172968.html
Copyright © 2011-2022 走看看