  • Scikit-learn multi-label classification (large training data, MultiOutputClassifier, partial_fit)

    Core code:

    # from sklearn.linear_model import LogisticRegression
    from sklearn.multioutput import MultiOutputClassifier
    from sklearn.naive_bayes import MultinomialNB
    from utils.data_util import load_pickle
    import os
    from pathConfig import data_dir
    from utils.vocab_util import vocab_to_index_dict
    import numpy as np
    
    # train & test data
    train_dir = os.path.join(data_dir, "train")
    test_dir = os.path.join(data_dir, "test")
    
    # train
    # classifier = SVC(kernel='linear', probability=True)
    # classifier = LogisticRegression()
    classifier = MultinomialNB()
    print("Training classifier ", str(classifier))
    clf = MultiOutputClassifier(classifier, n_jobs=24)
    
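    # Train incrementally: each file in train_dir is one chunk of the training
    # set, so the full data never has to be loaded into memory at once.
    # load_train_file and label_vocab are project helpers not shown in the post.
    for fname in os.listdir(train_dir):
        fpath = os.path.join(train_dir, fname)
        print("loading file ", fpath)
        train_X, train_y = load_train_file(fpath)
        print("partial_fitting...")
        # every label is binary, so each output's classes are [0, 1]
        clf.partial_fit(train_X, train_y, classes=[[0, 1]] * len(label_vocab))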
    
    # test
    # load_test_data is another project helper not shown in the post
    test_X, test_y = load_test_data()
    
    # predict_proba returns a list with one array per label, each of shape
    # [n_test_samples, 2]; column 1 holds P(label = 1) since classes are [0, 1]
    y_pred = clf.predict_proba(test_X)
    
    # Collect each label's positive-class probability into a matrix of shape
    # [n_test_samples, n_labels]. Averaging the two columns instead would
    # always give 0.5, because each row of probabilities sums to 1.
    y_pred_processed = np.stack([proba[:, 1] for proba in y_pred], axis=1)
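    A self-contained sketch of the same approach on synthetic data, assuming scikit-learn and NumPy are installed. The helper make_chunk is hypothetical: it stands in for the post's load_train_file by generating random non-negative count features (MultinomialNB expects non-negative inputs) and a random 0/1 label matrix for each chunk. The sketch trains with one partial_fit call per chunk, then turns the per-label predict_proba output into a score matrix and a thresholded label matrix.

```python
import numpy as np
from sklearn.multioutput import MultiOutputClassifier
from sklearn.naive_bayes import MultinomialNB


def make_chunk(rng, n_samples, n_features, n_labels):
    """Hypothetical stand-in for load_train_file: random count features
    and a random 0/1 label matrix of shape (n_samples, n_labels)."""
    X = rng.integers(0, 5, size=(n_samples, n_features))
    Y = rng.integers(0, 2, size=(n_samples, n_labels))
    return X, Y


rng = np.random.default_rng(0)
n_features, n_labels = 20, 3

clf = MultiOutputClassifier(MultinomialNB())

# Each label is binary, so every output is told its classes are [0, 1];
# partial_fit requires this on the first call because a single chunk
# might not contain both values for some label.
classes = [np.array([0, 1])] * n_labels

# Incremental training: one partial_fit call per chunk, so only one
# chunk needs to be in memory at a time.
for _ in range(5):
    X_chunk, Y_chunk = make_chunk(rng, 100, n_features, n_labels)
    clf.partial_fit(X_chunk, Y_chunk, classes=classes)

X_test, _ = make_chunk(rng, 10, n_features, n_labels)

# predict_proba returns a list with one (n_samples, 2) array per label.
proba = clf.predict_proba(X_test)

# Column 1 is P(label = 1) because each estimator's classes_ is [0, 1].
scores = np.stack([p[:, 1] for p in proba], axis=1)  # (n_samples, n_labels)

# Averaging both columns carries no information: each row sums to 1,
# so 0.5 * P(0) + 0.5 * P(1) is always 0.5.
averaged = np.stack([0.5 * p[:, 0] + 0.5 * p[:, 1] for p in proba], axis=1)

# Threshold the per-label scores into a 0/1 label matrix.
Y_pred = (scores >= 0.5).astype(int)
print(scores.shape, Y_pred.shape)
```

    The 0.5 threshold is only a default; per-label thresholds tuned on held-out data may suit imbalanced labels better.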
    
  • Original post: https://www.cnblogs.com/XBWer/p/13503796.html