zoukankan      html  css  js  c++  java
  • 多label实现准确率和召回率

    #!/usr/bin/env python
    # -*- coding: utf-8 -*-
    """
    @File : recall.py
    @Author : 郭凯锋
    @Time : 2020/1/12 17:57
    @Software : PyCharm
    @Git-Hub : daguonice
    @博客园: https://www.cnblogs.com/daguonice/

    """
    import pandas as pd
    import numpy as np
    from sklearn.metrics import recall_score


    def right_or_wrong(ypred_ele, ytrue_ele):
        """Return whether a predicted label is one of a sample's true labels.

        Args:
            ypred_ele: A single predicted label.
            ytrue_ele: The ground truth of one sample, either a collection of
                labels or one bare label string.

        Returns:
            bool: True when ``ypred_ele`` matches a ground-truth label.
        """
        # A bare string is ONE label, not a collection of labels: ``in`` on a
        # str performs a substring search and would accept partial label names.
        if isinstance(ytrue_ele, str):
            return ypred_ele == ytrue_ele
        return ypred_ele in ytrue_ele


    def get_single_score(y_true, y_pred):
        """Compute accuracy and recall counts for one binary (0/1) label.

        Args:
            y_true: Sequence or array of 0/1 ground-truth flags.
            y_pred: Sequence or array of 0/1 predicted flags, aligned with
                ``y_true``.

        Returns:
            tuple: ``(accuracy, recall)`` where each item is a two-element list
            ``[numerator, denominator]`` of plain ints:
            ``accuracy = [TP + TN, number of samples]`` and
            ``recall = [TP, TP + FN]``. The fractions are left unevaluated so
            callers can adjust the counts before dividing.
        """
        true_is_pos = np.equal(y_true, 1)
        true_is_neg = np.equal(y_true, 0)
        pred_is_pos = np.equal(y_pred, 1)
        pred_is_neg = np.equal(y_pred, 0)

        # True positive: labelled 1 and predicted 1.
        true_pos = int(np.sum(np.logical_and(true_is_pos, pred_is_pos)))
        # True negative: labelled 0 and predicted 0.
        true_neg = int(np.sum(np.logical_and(true_is_neg, pred_is_neg)))
        # False negative: labelled 1 but predicted 0 (a missed positive).
        false_neg = int(np.sum(np.logical_and(true_is_pos, pred_is_neg)))

        recall = [true_pos, true_pos + false_neg]
        accuracy = [true_pos + true_neg, len(y_pred)]
        return accuracy, recall


    def get_multi_score(y_true, y_pred,
                        labels=('相互宝', '健康险', '保险线', '花呗')):
        """Score every label one-vs-rest and collect the results.

        For each label the samples are turned into 0/1 flag arrays (does the
        prediction equal the label / is the label among the sample's true
        labels) and passed to ``get_single_score``.

        Args:
            y_true: Per-sample ground truth; item ``i`` is handed to
                ``right_or_wrong`` as the true labels of sample ``i``.
            y_pred: One predicted label per sample, indexable in step with
                ``y_true``.
            labels: Labels to score, in output order. Defaults to the four
                labels that were previously hard-coded.

        Returns:
            list: One ``[accuracy, recall]`` entry per label, in ``labels``
            order, where ``accuracy`` and ``recall`` are exactly what
            ``get_single_score`` returns.
        """
        list_score = []
        for label in labels:
            true_flags = np.zeros(len(y_true))
            pred_flags = np.zeros(len(y_pred))
            for idx, true_labels in enumerate(y_true):
                if y_pred[idx] == label:
                    pred_flags[idx] = 1.
                if right_or_wrong(label, true_labels):
                    true_flags[idx] = 1.
            accuracy, recall = get_single_score(true_flags, pred_flags)
            list_score.append([accuracy, recall])
        return list_score


    def fraction2decimal(result):
        """Evaluate nested ``[numerator, denominator]`` pairs as decimals.

        Args:
            result: A three-level nesting of lists whose innermost items are
                two-element ``[numerator, denominator]`` sequences.

        Returns:
            list: The same nesting with every pair replaced by
            ``numerator / denominator`` rounded to two decimal places, or
            ``0.0`` when the denominator is zero. Values are always plain
            Python floats, even when the counters are NumPy scalars.
        """
        def _ratio(pair):
            # A zero denominator means "no samples"; report 0.0 instead of
            # raising ZeroDivisionError.
            if pair[1] == 0:
                return 0.0
            # float() strips NumPy scalar types so the output prints cleanly.
            return float(round(pair[0] / pair[1], 2))

        return [
            [[_ratio(pair) for pair in label_scores] for label_scores in table]
            for table in result
        ]


    def func(y_true, y_pred):
        """Score all labels, apply per-sample count adjustments, return decimals.

        The score table comes from ``get_multi_score`` and is indexed as
        ``table[label][metric][part]`` with ``metric`` 0 = accuracy,
        1 = recall and ``part`` 0 = numerator, 1 = denominator. The label
        indices 0, 1 and 2 used below follow ``get_multi_score``'s label order.

        Args:
            y_true: Per-sample ground truth; each item is a collection of
                labels (the adjustments only apply to items of length 1 or 2).
            y_pred: One predicted label per sample, same length as ``y_true``.

        Returns:
            list: ``fraction2decimal`` applied to a list holding one table per
            sample.

        Raises:
            ValueError: If ``y_true`` and ``y_pred`` differ in length.
        """
        if len(y_true) != len(y_pred):
            raise ValueError("The two input lengths are inconsistent")

        list_score = get_multi_score(y_true, y_pred)
        result = []
        for true_labels, pred_label in zip(y_true, y_pred):
            n_true = len(true_labels)
            # Membership is only evaluated for 1- or 2-label samples; every
            # other sample leaves the table untouched.
            if n_true in (1, 2):
                hit = pred_label in true_labels
                if n_true == 1:
                    if not hit:
                        list_score[1][1][1] += 1
                        list_score[0][0][1] += 1
                elif hit:
                    list_score[0][0][0] += 1
                    list_score[0][0][1] += 1
                    list_score[0][1][0] += 1
                    list_score[0][1][1] += 1
                else:
                    list_score[1][1][1] += 1
                    list_score[2][1][1] += 1
                    list_score[0][0][1] += 1
            # NOTE(review): the SAME table object is appended on every
            # iteration, so adjustments accumulate and all rows of the output
            # are identical (the final state). Confirm whether per-sample
            # snapshots were intended before changing this.
            result.append(list_score)
        return fraction2decimal(result)


    if __name__ == '__main__':
        # Demo data: one predicted label per sample ...
        alist = ['相互宝', '相互宝', '相互宝', '健康险', '保险线', '花呗']
        # ... and a list of true labels per sample. The first item is wrapped
        # in a list like all the others; as a bare string its length would be
        # the character count rather than the number of labels.
        blist = [['健康险'], ['健康险', '相互宝'], ['健康险', '保险线'],
                 ['相互宝'], ['保险线'], ['花呗']]
        res = func(blist, alist)
        # Optional tabular display, one column per label:
        # res = pd.DataFrame(res, columns=['相互宝', '健康险', '保险线', '花呗'])
        # pd.set_option('display.max_columns', None)
        # pd.set_option('display.max_rows', None)
        # pd.set_option('display.width', 100000)
        # pd.set_option('display.unicode.east_asian_width', True)
        print()
        print(res)
    Done is better than perfect.
  • 相关阅读:
    java学习 接口与继承11 默认方法
    java学习 接口与继承10 内部类
    java学习 接口与继承9 抽象类
    java学习 接口与继承8 final
    理解管理信息系统
    vue中的错误日志
    vue中的ref属性
    2.有24颗外观完全一样的小球,其中有一个是空心的,现在只有一个天平,最少称几次能找出这个特殊的球?
    1.有888瓶编了号码的水及10只健康的小白鼠,其中一瓶水有毒,小白鼠饮用毒水一天后会死,最少需要几天可以找到哪瓶水有毒?
    SQL题1两表联查
  • 原文地址:https://www.cnblogs.com/daguonice/p/12185479.html
Copyright © 2011-2022 走看看