zoukankan      html  css  js  c++  java
  • scikit-learn使用方法

    1.支持向量机

    #_*_ coding:utf-8 _*_
    from sklearn import datasets
    from sklearn import svm
    
    #装载内部测试数据集
    digits = datasets.load_digits()
    #设置参数
    clf = svm.SVC(gamma = 0.001,C = 100.)
    #训练
    clf.fit(digits.data[:-1],digits.target[:-1])
    #预测
    print clf.predict(digits.data[-1:])
    

    想在scikit中保存模型的话,可以使用python的内置模块pickle

    #_*_ coding:utf-8 _*_
    from sklearn import datasets
    from sklearn import svm
    import pickle
    from sklearn.externals import joblib
    #装载内部测试数据集
    iris = datasets.load_iris()
    X,y = iris.data,iris.target
    #初始化模型
    clf = svm.SVC()
    #训练
    clf.fit(X[:-1],y[:-1])
    #保存模型
    s = pickle.dumps(clf)
    #装载模型
    clf2 = pickle.loads(s)
    #预测
    print clf2.predict(X[-1:])
    

    ※在数据量非常大的时候,我们需要把模型保存在硬盘上,而不是字符串中

    #_*_ coding:utf-8 _*_
    from sklearn import datasets
    from sklearn import svm
    from sklearn.externals import joblib
    #装载内部测试数据集
    iris = datasets.load_iris()
    X,y = iris.data,iris.target
    #初始化模型
    clf = svm.SVC()
    #训练
    clf.fit(X[:-1],y[:-1])
    #保存模型
    joblib.dump(clf,'filename.pkl')
    #装载模型
    clf2 = joblib.load('filename.pkl')
    #预测
    print clf2.predict(X[-1:])

    2.如无特殊说明,输入数据都被转换成float64位,在下面的例子中X可以通过fit_transform(X)转换成float64:

    #_*_ coding:utf-8 _*_
    
    import numpy as np
    from sklearn import random_projection
    
    rng = np.random.RandomState(0)
    
    X = rng.rand(10,2000)
    Y = np.array(X)
    X = np.array(X,dtype='float32')
    print Y.dtype,X.dtype
    
    transformer = random_projection.GaussianRandomProjection()
    X_new = transformer.fit_transform(X)
    print X_new.dtype
    

     3.重新装载并更新参数

    #_*_ coding:utf-8 _*_
    
    import numpy as np
    from sklearn.svm import SVC
    
    rng = np.random.RandomState(0)
    X = rng.rand(100,10)
    y = rng.binomial(1,0.5,100)
    X_test = rng.rand(5,10)
    
    clf = SVC()
    clf.set_params(kernel = 'linear').fit(X,y)
    
    print clf.predict(X_test)
    
    clf.set_params(kernel = 'rbf').fit(X,y)
    print clf.predict(X_test)
  • 相关阅读:
    vue_路由
    vue_列表动画
    vue生命周期详细
    Vue_过渡和动画
    vue_品牌列表案例(添加删除搜索过滤)
    vue_简单的添加删除
    v-if v-show
    vue_简单的添加数据
    JSON.parse()和JSON.stringify()
    vue_计算器
  • 原文地址:https://www.cnblogs.com/ryuham/p/5266012.html
Copyright © 2011-2022 走看看