zoukankan      html  css  js  c++  java
  • K-fold Train Version3

    # config.py
    # Central place for the file-system locations used by the training script.
    # CSV read by train.py; run() expects it to contain `label` and `kfold` columns.
    TRAINING_FILE = "../input/mnist_train_folds.csv"
    # Directory that train.py writes the fitted models into.
    MODEL_OUTPUT = "../models/"
    # model_dispatcher.py
    from sklearn import tree
    from sklearn import ensemble
    # Registry of unfitted estimators, looked up by the name given to --model.
    models = dict(
        decision_tree_gini=tree.DecisionTreeClassifier(criterion="gini"),
        decision_tree_entropy=tree.DecisionTreeClassifier(criterion="entropy"),
        rf=ensemble.RandomForestClassifier(),
    )
    # train.py
    import argparse
    import os
    import joblib
    import pandas as pd
    from sklearn import metrics
    import config
    import model_dispatcher

    def run(fold, model):
        """Train one model with `fold` held out and report validation accuracy.

        Rows whose `kfold` value differs from `fold` form the training set;
        rows whose `kfold` equals `fold` form the validation set.

        Args:
            fold: value of the `kfold` column to hold out for validation.
            model: key of the estimator to use in `model_dispatcher.models`.

        Raises:
            KeyError: if `model` is not a key of `model_dispatcher.models`.

        Side effects:
            Prints the fold number and accuracy, and dumps the fitted
            estimator to `dt_<fold>.bin` inside `config.MODEL_OUTPUT`.
        """
        # read the training data with folds
        df = pd.read_csv(config.TRAINING_FILE)

        # training data is where kfold is not equal to provided fold;
        # validation data is where kfold is equal to provided fold.
        # the index is reset so both frames start from 0
        df_train = df[df.kfold != fold].reset_index(drop=True)
        df_valid = df[df.kfold == fold].reset_index(drop=True)

        # `label` is the target and `kfold` is only split bookkeeping, so
        # neither may be used as a feature. Leaving `kfold` in would feed the
        # model a column whose validation value never occurs in training and
        # would make the saved model expect one column more than real data has.
        non_feature_columns = ["label", "kfold"]

        x_train = df_train.drop(non_feature_columns, axis=1).values
        y_train = df_train.label.values

        # similarly, for validation, we have
        x_valid = df_valid.drop(non_feature_columns, axis=1).values
        y_valid = df_valid.label.values

        # fetch the model from model_dispatcher
        clf = model_dispatcher.models[model]

        # fit the model on training data
        clf.fit(x_train, y_train)

        # create predictions for validation samples
        preds = clf.predict(x_valid)

        # calculate & print accuracy
        accuracy = metrics.accuracy_score(y_valid, preds)
        print(f"Fold={fold}, Accuracy={accuracy}")

        # save the model
        # NOTE(review): the file name uses the fixed prefix "dt_" and does not
        # include `model`, so different models trained on the same fold write
        # to the same path. Kept unchanged in case other scripts load it.
        joblib.dump(
            clf,
            os.path.join(config.MODEL_OUTPUT, f"dt_{fold}.bin")
        )


    if __name__ == "__main__":
        # Command-line entry point: choose the fold to hold out and the model
        # to train, then hand both to run().
        parser = argparse.ArgumentParser()

        # Both arguments are mandatory. argparse would otherwise supply None
        # for a missing one, and run() cannot work with a None fold or model.
        parser.add_argument(
            "--fold",
            type=int,
            required=True
        )
        # Restrict --model to the names registered in model_dispatcher so a
        # typo is reported by argparse instead of surfacing as a KeyError.
        parser.add_argument(
            "--model",
            type=str,
            required=True,
            choices=sorted(model_dispatcher.models)
        )

        # read the arguments from the command line
        args = parser.parse_args()

        # run the fold and model specified by command line arguments
        run(
            fold=args.fold,
            model=args.model
        )
    ================================================
    #!/bin/sh
    # run.sh
    # Train the "rf" model once per fold (0 through 4), one run after another.
    for fold in 0 1 2 3 4; do
        python train.py --fold "$fold" --model rf
    done
    ================================================
    sh run.sh
  • 相关阅读:
    terminal下历史命令自动完成功能history auto complete
    Shell中while循环的done 后接一个重定向<
    python 链接hive
    shell 学习基地
    c++ 获取本地ip地址
    c++ 如何实现,readonly
    c++ 语法
    c++ 信号量
    vim插件介绍
    c++ memset 函数 及 坑
  • 原文地址:https://www.cnblogs.com/songyuejie/p/14789476.html
Copyright © 2011-2022 走看看