zoukankan      html  css  js  c++  java
  • 多label实现准确率和召回率

    #!/usr/bin/env python
    # -*- coding: utf-8 -*-
    """
    @File : recall.py
    @Author : 郭凯锋
    @Time : 2020/1/12 17:57
    @Software : PyCharm
    @Git-Hub : daguonice
    @博客园: https://www.cnblogs.com/daguonice/

    """
    import pandas as pd
    import numpy as np
    from sklearn.metrics import recall_score


    def right_or_wrong(ypred_ele, ytrue_ele):
    if ypred_ele in ytrue_ele:
    return True
    else:
    return False


    def get_single_score(y_true, y_pred):
    TP = np.sum(np.logical_and(np.equal(y_true, 1), np.equal(y_pred, 1)))
    FP = np.sum(np.logical_and(np.equal(y_true, 0), np.equal(y_pred, 1)))
    TN = np.sum(np.logical_and(np.equal(y_true, 1), np.equal(y_pred, 0)))
    FN = np.sum(np.logical_and(np.equal(y_true, 0), np.equal(y_pred, 0)))
    recall = [TP, (TP + FN)]
    accuracy = [(TP + TN), len(y_pred)]
    return accuracy, recall


    def get_multi_score(y_true, y_pred):
    list_score = []
    for ele in ['相互宝', '健康险', '保险线', '花呗']:
    y_true_temp = np.zeros(len(y_true))
    y_pred_temp = np.zeros(len(y_pred))
    for idx in range(len(y_true)):
    if y_pred[idx] == ele:
    y_pred_temp[idx] = 1.
    if right_or_wrong(ele, y_true[idx]):
    y_true_temp[idx] = 1.
    accuracy, recall = get_single_score(y_true_temp, y_pred_temp) # 变成百分制
    list_score.append([accuracy, recall])
    return list_score # 返回一个三维列表, 分别是相互宝、健康险、保险线的准确率和召回率的分数


    def fraction2decimal(result):
    res = []
    for ele in result:
    list_ele = []
    for lab in ele:
    list_lab = []
    for score in lab:
    if score[1] == 0:
    list_lab.append(0.0)
    else:
    list_lab.append(round(score[0] / score[1], 2))
    list_ele.append(list_lab)
    res.append(list_ele)
    return res


    def func(y_true, y_pred):
    # y_true是一个一维列表, y_pred是一个二维列表
    if len(y_true) != len(y_pred):
    raise Exception("The two input lengths are inconsistent")
    length = len(y_true)
    list_score = get_multi_score(y_true, y_pred)
    result = []
    temp_list = list_score
    for idx in range(length):
    list_score = temp_list
    if len(y_true[idx]) == 1 and y_pred[idx] not in y_true[idx]:
    list_score[1][1][1] += 1
    list_score[0][0][1] += 1
    result.append(list_score)
    elif len(y_true[idx]) == 2 and y_pred[idx] in y_true[idx]:
    list_score[0][0][0] += 1
    list_score[0][0][1] += 1
    list_score[0][1][0] += 1
    list_score[0][1][1] += 1
    result.append(list_score)
    elif len(y_true[idx]) == 2 and y_pred[idx] not in y_true[idx]:
    list_score[1][1][1] += 1
    list_score[2][1][1] += 1
    list_score[0][0][1] += 1
    result.append(list_score)
    else:
    result.append(list_score)
    result = fraction2decimal(result)
    return result


    if __name__ == '__main__':
    alist = ['相互宝', '相互宝', '相互宝', '健康险', '保险线', '花呗']
    blist = ['健康险', ['健康险', '相互宝'], ['健康险', '保险线'], ['相互宝'], ['保险线'], ['花呗']]
    res = func(blist, alist)
    # res = pd.DataFrame(res, columns=['相互宝', '健康险', '保险线', '花呗'])
    # pd.set_option('display.max_columns', None)
    # pd.set_option('display.max_rows', None)
    # pd.set_option('display.width', 100000)
    # pd.set_option('display.unicode.east_asian_width', True)
    print()
    print(res)
    Done is better than perfect.
  • 相关阅读:
    ng-深度学习-课程笔记-1: 介绍深度学习(Week1)
    java发送http请求和多线程
    Spring Cloud Eureka注册中心(快速搭建)
    Spring boot集成Swagger2,并配置多个扫描路径,添加swagger-ui-layer
    springboot在idea的RunDashboard如何显示出来
    Oracle 中select XX_id_seq.nextval from dual 什么意思呢?
    mysql类似to_char()to_date()函数mysql日期和字符相互转换方法date_f
    MySQL的Limit详解
    HikariCP 个人实例
    NBA-2018骑士季后赛
  • 原文地址:https://www.cnblogs.com/daguonice/p/12185479.html
Copyright © 2011-2022 走看看