zoukankan      html  css  js  c++  java
  • sklearn里计算roc_auc_score,报错ValueError: bad input shape

    用sklearn的DecisionTreeClassifer训练模型,然后用roc_auc_score计算模型的auc。代码如下

    clf = DecisionTreeClassifier(criterion='gini', max_depth=6, min_samples_split=10, min_samples_leaf=2)
    clf.fit(X_train, y_train)
    y_pred = clf.predict_proba(X_test)
    roc_auc = roc_auc_score(y_test, y_pred)

    报错信息如下

    /Users/wgg/anaconda/lib/python2.7/site-packages/sklearn/metrics/ranking.pyc in _binary_clf_curve(y_true, y_score, pos_label, sample_weight)
        297     check_consistent_length(y_true, y_score)
        298     y_true = column_or_1d(y_true)
    --> 299     y_score = column_or_1d(y_score)
        300     assert_all_finite(y_true)
        301     assert_all_finite(y_score)
    
    /Users/wgg/anaconda/lib/python2.7/site-packages/sklearn/utils/validation.pyc in column_or_1d(y, warn)
        560         return np.ravel(y)
        561 
    --> 562     raise ValueError("bad input shape {0}".format(shape))
        563 
        564 
    
    ValueError: bad input shape (900, 2)

    目测是你的y_pred出了问题,你的y_pred是(900, 2)的array,也就是有两列。

    因为predict_proba返回的是两列。predict_proba的用法参考这里

    简而言之,你上面的代码改成这样就可以了。

    y_pred = clf.predict_proba(X_test)[:, 1]
    roc_auc = roc_auc_score(y_test, y_pred)

    原文:http://sofasofa.io/forum_main_post.php?postid=1001678

  • 相关阅读:
    MongoDB插入时间不正确的问题
    json 字符串转换成对象,对象转换成json字符串
    sqlServer sa用户登陆失败的解决办法
    基于web工作流开发
    javascript ajax的语法
    收藏和设为首页的方法
    asp.net收藏和设为首页的代码
    设计模式
    设计模式
    设计模式
  • 原文地址:https://www.cnblogs.com/anovana/p/11750285.html
Copyright © 2011-2022 走看看