zoukankan      html  css  js  c++  java
  • TensorFlow之tf.metrics

    import tensorflow as tf
    import numpy as np
    
    N_TRUE_P = 0
    N_PRED_P = 0
    
    def reset_running_variables():
        """Zero both running counters so a fresh evaluation pass can begin.

        Sets the module-level ``N_TRUE_P`` (true-positive count) and
        ``N_PRED_P`` (predicted-positive count) back to 0.
        """
        global N_TRUE_P, N_PRED_P
        N_TRUE_P = N_PRED_P = 0
    
    def update_running_variables(labs, preds):
        """Accumulate true-positive and predicted-positive counts for one batch.

        An element counts as a predicted positive when ``preds > 0`` and as a
        true positive when both ``labs > 0`` and ``preds > 0``. The counts are
        added to the module-level ``N_TRUE_P`` and ``N_PRED_P`` counters.

        Args:
            labs: array-like of ground-truth labels for the batch.
            preds: array-like of predictions, broadcast-compatible with ``labs``.
        """
        global N_TRUE_P, N_PRED_P
        labs = np.asarray(labs)
        preds = np.asarray(preds)
        pred_pos = preds > 0
        # Compare each array to zero separately rather than testing
        # ``labs * preds > 0``: the product can wrap around for small unsigned
        # dtypes (16 * 16 == 0 in uint8), and it is positive when both inputs
        # are negative, which would count a non-positive prediction as a true
        # positive.
        N_TRUE_P += int(np.count_nonzero((labs > 0) & pred_pos))
        N_PRED_P += int(np.count_nonzero(pred_pos))
    
    def calculate_precision():
        """Return precision = true positives / predicted positives.

        Reads the module-level counters ``N_TRUE_P`` and ``N_PRED_P``. When
        nothing has been predicted positive (``N_PRED_P == 0``) the ratio is
        undefined. In that case return 0.0, which is what
        ``tf.metrics.precision`` reports, instead of raising
        ``ZeroDivisionError`` or producing NaN for NumPy-typed counters.
        """
        if N_PRED_P == 0:
            return 0.0
        return float(N_TRUE_P) / N_PRED_P
    
    if __name__ == '__main__':
        # Four "batches" (one per row) of binary labels and predictions.
        labels = np.array([[1, 1, 1, 0],
                           [1, 1, 1, 0],
                           [1, 1, 1, 0],
                           [1, 1, 1, 0]], dtype=np.uint8)

        predictions = np.array([[1, 0, 0, 0],
                                [1, 1, 0, 0],
                                [1, 1, 1, 0],
                                [0, 1, 1, 1]], dtype=np.uint8)

        # --- NumPy reference implementation ---------------------------------
        # Streaming precision computed with the helpers defined above. For the
        # data above it is 8 true positives / 9 predicted positives ~= 0.8889,
        # and the TensorFlow score printed below should match it.
        reset_running_variables()
        for labs, preds in zip(labels, predictions):
            update_running_variables(labs=labs, preds=preds)
        print("[NP] SCORE: %1.4f" % calculate_precision())

        # --- TensorFlow (1.x graph-mode API) ---------------------------------
        graph = tf.Graph()
        with graph.as_default():
            # Placeholders that take in one batch of data at a time.
            tf_label = tf.placeholder(dtype=tf.int32, shape=[None])
            tf_prediction = tf.placeholder(dtype=tf.int32, shape=[None])

            # tf.metrics.precision returns a pair: the metric value tensor and
            # an update op that folds the current batch into running counts.
            tf_metric, tf_metric_update = tf.metrics.precision(tf_label,
                                                               tf_prediction,
                                                               name="my_metric")

            # The metric keeps its running counts in local variables under the
            # "my_metric" scope; collect them so they can be (re)initialized.
            running_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                             scope="my_metric")

            # Initializer that resets only the metric's running variables.
            running_vars_initializer = tf.variables_initializer(var_list=running_vars)

        with tf.Session(graph=graph) as session:
            session.run(tf.global_variables_initializer())

            # Initialize/reset the running variables before accumulating.
            session.run(running_vars_initializer)

            for labs, preds in zip(labels, predictions):
                # Update the running variables with this batch of samples.
                session.run(tf_metric_update,
                            feed_dict={tf_label: labs, tf_prediction: preds})

            # Read the accumulated score.
            score = session.run(tf_metric)
            print("[TF] SCORE: %1.4f" % score)
    

     参考:https://zhuanlan.zhihu.com/p/43359894

  • 相关阅读:
    .Net Core 5.x Api开发笔记 -- 消息队列RabbitMQ实现事件总线EventBus(二)
    .Net Core 5.x Api开发笔记 -- 消息队列RabbitMQ实现事件总线EventBus(一)
    SQL 入门教程:创建视图
    微信小程序-企业微信PC端,对接echarts图无法显示
    SQL查看表结构以及表说明
    Skoruba.IdentityServer4.STS.Identity 踩坑
    Docker部署文档
    eCharts图形在IE11中不能渲染
    Cookie中文乱码问题
    Blazor Webassembly多标签页实现
  • 原文地址:https://www.cnblogs.com/liutianrui1/p/12172968.html
Copyright © 2011-2022 走看看