  • Use Spark to train sklearn models up to 100x faster

    skdist: https://github.com/Ibotta/sk-dist
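
    sk-dist distributes scikit-learn meta-estimators such as grid search, random search, and ensembles across a Spark cluster, so each cross-validation fit runs on its own executor instead of queuing up on a single machine. The package is typically installed from PyPI (pip install sk-dist) alongside a working pyspark installation.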

    import time
    from sklearn import datasets, svm
    from skdist.distribute.search import DistGridSearchCV
    from pyspark.sql import SparkSession 
    # instantiate spark session
    spark = (   
        SparkSession    
        .builder    
        .getOrCreate()    
        )
    sc = spark.sparkContext 
    # the digits dataset
    digits = datasets.load_digits()
    X = digits["data"]
    y = digits["target"] 
    # create a classifier: a support vector classifier
    classifier = svm.SVC()
    param_grid = {
        "C": [0.01, 0.01, 0.1, 1.0, 10.0, 20.0, 50.0], 
        "gamma": ["scale", "auto", 0.001, 0.01, 0.1], 
        "kernel": ["rbf", "poly", "sigmoid"]
        }
    scoring = "f1_weighted"
    cv = 10
    # hyperparameter optimization
    start = time.time()
    model = DistGridSearchCV(    
        classifier, param_grid,     
        sc=sc, cv=cv, scoring=scoring,
        verbose=True    
        )
    model.fit(X, y)
    print("Train time: {0}".format(time.time() - start))
    print("Best score: {0}".format(model.best_score_))
    ------------------------------
    Spark context found; running with spark
    Fitting 10 folds for each of 105 candidates, totalling 1050 fits
    Train time: 3.380601406097412
    Best score: 0.981450024203508
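
    The 1050 fits reported above follow directly from the grid: 7 values of C × 5 of gamma × 3 kernels = 105 candidates, each evaluated over 10 folds. For a sense of the speedup, a minimal single-node baseline is sketched below. It assumes it runs in the same session as the snippet above (reusing classifier, param_grid, cv, scoring, X, and y) and parallelizes only across local CPU cores with n_jobs=-1, so all 1050 fits stay on one machine.

    from sklearn.model_selection import GridSearchCV
    # single-node baseline: same estimator, grid, scoring, and folds,
    # parallelized only across local CPU cores
    start = time.time()
    local_model = GridSearchCV(
        classifier, param_grid,
        cv=cv, scoring=scoring,
        n_jobs=-1, verbose=1
        )
    local_model.fit(X, y)
    print("Train time: {0}".format(time.time() - start))
    print("Best score: {0}".format(local_model.best_score_))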
  • Original article: https://www.cnblogs.com/similarface/p/13031759.html