zoukankan      html  css  js  c++  java
  • scikit-learn使用方法

    1.支持向量机

    #_*_ coding:utf-8 _*_
    from sklearn import datasets
    from sklearn import svm
    
    #装载内部测试数据集
    digits = datasets.load_digits()
    #设置参数
    clf = svm.SVC(gamma = 0.001,C = 100.)
    #训练
    clf.fit(digits.data[:-1],digits.target[:-1])
    #预测
    print clf.predict(digits.data[-1:])
    

    想在scikit中保存模型的话,可以使用python的内置模块pickle

    #_*_ coding:utf-8 _*_
    from sklearn import datasets
    from sklearn import svm
    import pickle
    from sklearn.externals import joblib
    #装载内部测试数据集
    iris = datasets.load_iris()
    X,y = iris.data,iris.target
    #初始化模型
    clf = svm.SVC()
    #训练
    clf.fit(X[:-1],y[:-1])
    #保存模型
    s = pickle.dumps(clf)
    #装载模型
    clf2 = pickle.loads(s)
    #预测
    print clf2.predict(X[-1:])
    

    ※在数据量非常大的时候,我们需要把模型保存在硬盘上,而不是字符串中

    #_*_ coding:utf-8 _*_
    from sklearn import datasets
    from sklearn import svm
    from sklearn.externals import joblib
    #装载内部测试数据集
    iris = datasets.load_iris()
    X,y = iris.data,iris.target
    #初始化模型
    clf = svm.SVC()
    #训练
    clf.fit(X[:-1],y[:-1])
    #保存模型
    joblib.dump(clf,'filename.pkl')
    #装载模型
    clf2 = joblib.load('filename.pkl')
    #预测
    print clf2.predict(X[-1:])

    2.如无特殊说明,输入数据都被转换成float64位,在下面的例子中X可以通过fit_transform(X)转换成float64:

    #_*_ coding:utf-8 _*_
    
    import numpy as np
    from sklearn import random_projection
    
    rng = np.random.RandomState(0)
    
    X = rng.rand(10,2000)
    Y = np.array(X)
    X = np.array(X,dtype='float32')
    print Y.dtype,X.dtype
    
    transformer = random_projection.GaussianRandomProjection()
    X_new = transformer.fit_transform(X)
    print X_new.dtype
    

     3.重新装载并更新参数

    #_*_ coding:utf-8 _*_
    
    import numpy as np
    from sklearn.svm import SVC
    
    rng = np.random.RandomState(0)
    X = rng.rand(100,10)
    y = rng.binomial(1,0.5,100)
    X_test = rng.rand(5,10)
    
    clf = SVC()
    clf.set_params(kernel = 'linear').fit(X,y)
    
    print clf.predict(X_test)
    
    clf.set_params(kernel = 'rbf').fit(X,y)
    print clf.predict(X_test)
  • 相关阅读:
    单机RedHat6.5+JDK1.8+Hadoop2.7.3+Spark2.1.1+zookeeper3.4.6+kafka2.11+flume1.6环境搭建步骤
    kafka_2.11-0.8.2.1+java 生产消费程序demo示例
    Kafka使用log.retention.hours改变消息端的消息保存时间
    Apache Kafka监控之KafkaOffsetMonitor
    Apache Kafka监控之Kafka Web Console
    Kafka三款监控工具比较
    linux查看本机IP、gateway、dns
    kafka_2.11-0.8.2.1生产者producer的Java实现
    linux下杀死进程(kill)的N种方法
    Linux查看硬件配置命令
  • 原文地址:https://www.cnblogs.com/ryuham/p/5266012.html
Copyright © 2011-2022 走看看