zoukankan      html  css  js  c++  java
  • 转sklearn保存模型

    训练好了一个Model 以后总需要保存和再次预测, 所以保存和读取我们的sklearn model也是同样重要的一步。

    比如,我们根据房源样本数据训练了一下房价模型,当用户输入自己的房子后,我们就需要根据训练好的房价模型来预测用户房子的价格。

    这样就需要在训练模型后把模型保存起来,在使用模型时把模型读取出来对输入的数据进行预测。

    这里保存和读取模型有两种方法,都非常简单,差别在于保存和读取速度的快慢上,因为有一个是利用了多进程机制,下面我们分别来看一下。

    创建模型

    首先我们创建模型并训练数据:

    from sklearn.datasets import load_digits
    from sklearn.svm import SVC
    
    # 加载数据
    digits = load_digits()
    X = digits.data
    y = digits.target
    
    model = SVC()
    model.fit(X, y)

    用pickle读写模型

    pickle是python中用于数据序列化的模块,因此,对于模型的序列化也可以用此模块来进行:

    import pickle
    # 以写二进制的方式打开文件
    file = open("D:/data/python/model.pickle", "wb")
    # 把模型写入到文件中
    pickle.dump(model, file)
    # 关闭文件
    file.close()

    这样会创建D:/data/python/model.pickle的文件,大家可以自己去尝试下看看,我这边生成的文件大概1M左右。

    有了模型文件之后,在进行预测时我们就不需要进行训练了,而只要把这个训练好的模型文件读取出来,然后直接进行预测就可以:

    import pickle
    # 以读二进制的方式打开文件
    file = open("D:/data/python/model.pickle", "rb")
    # 把模型从文件中读取出来
    model = pickle.load(file)
    # 关闭文件
    file.close()
    
    # 用模型进行预测
    from sklearn.datasets import load_digits
    digits = load_digits()
    X = digits.data
    y = digits.target
    
    print("预测值:", model.predict(X[15:20]))
    print("实际值:", y[15:20])

    输出为:

    预测值: [5 6 7 8 9]
    实际值: [5 6 7 8 9]

    用joblib进行模型的读写

    直接上代码:

    from sklearn.datasets import load_digits
    from sklearn.svm import SVC
    
    # 用模型进行训练
    digits = load_digits()
    X = digits.data
    y = digits.target
    model = SVC()
    model.fit(X, y)
    
    # 用joblib保存模型
    from sklearn.externals import joblib
    joblib.dump(model, "D:/data/python/model.joblib")

    这样就会生成D:/data/python/model.joblib文件,看起来比pickle生成的文件大一点点。

    读取模型:

    # 用joblib读取模型
    from sklearn.externals import joblib
    model = joblib.load("D:/data/python/model.joblib")
    
    # 对数据进行预测
    from sklearn.datasets import load_digits
    digits = load_digits()
    X = digits.data
    y = digits.target
    
    print("预测值:", model.predict(X[15:20]))
    print("实际值:", y[15:20])

    输出为:

    预测值: [5 6 7 8 9]
    实际值: [5 6 7 8 9]

    看起来也很简单,同pickle的区别是joblib会以多进程方式来进行,据说性能会好些。

  • 相关阅读:
    【Henu ACM Round#17 A】Simple Game
    【Henu ACM Round #12 E】Thief in a Shop
    【Henu ACM Round#16 D】Bear and Two Paths
    【Henu ACM Round#16 A】 Bear and Game
    P4824 [USACO15FEB]Censoring (Silver) 审查(银)
    P4001 [BJOI2006]狼抓兔子
    P2444 [POI2000]病毒
    P3966 [TJOI2013]单词
    P3796 【模板】AC自动机
    P4574 [CQOI2013]二进制A+B
  • 原文地址:https://www.cnblogs.com/onemorepoint/p/8144207.html
Copyright © 2011-2022 走看看