zoukankan      html  css  js  c++  java
  • tensorflow计算各个类别的正确率

    import tensorflow as tf
    
    def count_nums(true_labels, num_classes):
        initial_value = 0
        list_length = num_classes
        list_data = [ initial_value for i in range(list_length)]
        for i in range(0, num_classes):
            list_data[i] = true_labels.count(i)
        return list_data
    
    def accuracy(confusion_matrix, true_labels, num_classes):
        # 各个类别的测试样本的个数
        list_data = count_nums(true_labels, num_classes)
    
        # 各个类别正确分类的个数
        initial_value = 0
        list_length = num_classes
        true_pred = [ initial_value for i in range(list_length)]
        for i in range(0,5):
            true_pred[i] = confusion_matrix[i][i]
    
        # 计算各个样本被正确分类的正确率
        acc = []
        for i in range(0, 5):
            acc.append(0)
    
        for i in range(0,5):
            acc[i] = true_pred[i] / list_data[i]
    
        return acc
    
    # 测试数据
    y_true = [0, 1, 2, 3, 1, 2, 3, 4, 1] # 真实的标签
    y_pred = [1, 1, 2, 3, 1, 2, 3, 4, 2] # 预测的标签
    
    # Build graph with tf.confusion_matrix operation
    sess = tf.InteractiveSession()
    op = tf.confusion_matrix(y_true, y_pred)
    # Execute the graph
    print ("confusion matrix in tensorflow: ")
    confusion_matrix = sess.run(op)
    print(confusion_matrix)
    sess.close()
    
    # 计算各个类别的正确率
    acc = accuracy(confusion_matrix, y_true, num_classes = 5)
    print(acc)
    

      

  • 相关阅读:
    英语老师不想让你知道的一些网站分享
    最近三周开发的桌面应用系统
    UML技术沙龙PPT
    Pandas时间处理的一些小方法
    合并函数总结
    开博宣言
    DBGrid中增加一列CHECKBOX
    关于Delphi的Hint
    操作EXCEL
    关于FastReport
  • 原文地址:https://www.cnblogs.com/wylwyl/p/10864217.html
Copyright © 2011-2022 走看看