zoukankan      html  css  js  c++  java
  • Python+Spark2.0+hadoop学习笔记——Python Spark MLlib Naive Bayes二分类

    朴素贝叶斯(Naive Bayes)是一种经典的分类方法,其原理(贝叶斯定理加上特征条件独立假设)在概率论课程中已有较多介绍。下面介绍如何在Spark环境下使用MLlib中的Naive Bayes,对网站性质(ephemeral短暂型 / evergreen长青型)进行二分类判断。

    第一步:导入库函数

    import sys
    from time import time
    import pandas as pd
    import matplotlib.pyplot as plt
    from pyspark import SparkConf, SparkContext
    from pyspark.mllib.classification import NaiveBayes
    from pyspark.mllib.regression import LabeledPoint
    import numpy as np
    from pyspark.mllib.evaluation import BinaryClassificationMetrics
    from pyspark.mllib.feature import StandardScaler

    第二步:数据准备

    def get_mapping(rdd, idx):
        """Map each distinct value of column ``idx`` of ``rdd`` to a 0-based index.

        ``rdd`` holds records already split into columns. The result is a
        driver-side dict {value: index}. Indices come from ``zipWithIndex``
        over ``distinct()``, so they are not in sorted order.
        """
        column = rdd.map(lambda fields: fields[idx])
        indexed = column.distinct().zipWithIndex()
        return indexed.collectAsMap()

    def extract_label(record):
        """Return the last column of a parsed record, converted to float, as its label."""
        return float(record[-1])

    def extract_features(field, categoriesMap, featureEnd):
        """Build the feature vector for one parsed TSV record.

        The vector is a one-hot encoding of the category in column 3,
        followed by the numeric columns ``field[4:featureEnd]``, each passed
        through ``convert_float``.

        Args:
            field: the record split into columns (list of strings).
            categoriesMap: {category value: one-hot position}. A category
                missing from the map yields an all-zero one-hot block.
            featureEnd: exclusive end index of the numeric columns. In this
                file, training rows pass ``len(r) - 1`` (last column is the
                label) and test rows pass ``len(r)``.

        Returns:
            A 1-D numpy array of length
            ``len(categoriesMap) + len(field[4:featureEnd])``.
        """
        categoryFeatures = np.zeros(len(categoriesMap))
        # A category never seen when categoriesMap was built (e.g. one that
        # only appears in test.tsv) used to raise KeyError. Encode it as
        # "no known category" instead.
        categoryIdx = categoriesMap.get(field[3])
        if categoryIdx is not None:
            categoryFeatures[categoryIdx] = 1
        numericalFeatures = [convert_float(value) for value in field[4:featureEnd]]
        return np.concatenate((categoryFeatures, numericalFeatures))

    def convert_float(x):
        """Parse one TSV cell as a non-negative number.

        The missing-value marker "?" becomes the int 0, and negative values
        are clamped to the int 0. The clamping is presumably there because
        MLlib Naive Bayes requires non-negative feature values. Any other
        value is returned as a float.
        """
        if x == "?":
            return 0
        value = float(x)
        if value < 0:
            return 0
        return value

    def PrepareData(sc):
        """Load train.tsv, build standardized LabeledPoints and split them 8:1:1.

        Reads ``Path + "data/train.tsv"``, where ``Path`` is the module-level
        global set by SetPath. It drops the header row, strips double quotes
        and splits each line on tabs. Column 3 is one-hot encoded via
        ``categoriesMap``, columns 4 up to (but excluding) the last are
        numeric features, and the last column is the label. Features are
        scaled to unit standard deviation (no mean centering) before being
        wrapped in LabeledPoint objects.

        Returns:
            (trainData, validationData, testData, categoriesMap), where the
            three RDDs come from ``randomSplit([8, 1, 1])``.
        """
        print("Data loading...")
        rawDataWithHeader = sc.textFile(Path + "data/train.tsv")
        header = rawDataWithHeader.first()
        rawData = rawDataWithHeader.filter(lambda x: x != header)
        # Strip the double quotes that wrap every field.
        rData = rawData.map(lambda x: x.replace('"', ""))
        # The file is tab-separated; splitting on spaces would break up the
        # free-text columns.
        lines = rData.map(lambda x: x.split("\t"))
        print("The number of data" + str(lines.count()))
        print("Before normalization:")
        categoriesMap = get_mapping(lines, 3)
        labelRDD = lines.map(lambda r: extract_label(r))
        featureRDD = lines.map(lambda r: extract_features(r, categoriesMap, len(r) - 1))
        # Print the first vector on a single line as "v1,v2,...,". In Python 3
        # the Python 2 idiom `print(x),` would print one line per element.
        print("".join(str(i) + "," for i in featureRDD.first()))
        print("After normalization:")
        stdScaler = StandardScaler(withMean=False, withStd=True).fit(featureRDD)
        ScalerFeatureRDD = stdScaler.transform(featureRDD)
        print("".join(str(i) + "," for i in ScalerFeatureRDD.first()))

        # labelRDD and ScalerFeatureRDD are both derived record-by-record from
        # `lines`; zip pairs each label with that record's scaled features.
        labelpoint = labelRDD.zip(ScalerFeatureRDD)
        labelpointRDD = labelpoint.map(lambda r: LabeledPoint(r[0], r[1]))

        (trainData, validationData, testData) = labelpointRDD.randomSplit([8, 1, 1])
        print("trainData:" + str(trainData.count()) +
              "validationData:" + str(validationData.count()) +
              "testData:" + str(testData.count()))
        return (trainData, validationData, testData, categoriesMap)

    第三步:使用训练好的模型对test.tsv中的数据进行预测

    def PredictData(sc, model, categoriesMap, scaler=None):
        """Predict and print the class of the first 10 records of test.tsv.

        Reads ``Path + "data/test.tsv"``, where ``Path`` is the module-level
        global set by SetPath. Each line is parsed the same way as in
        PrepareData (header dropped, quotes stripped, split on tabs).

        Args:
            sc: the SparkContext.
            model: a trained model with ``predict(vector)`` returning 0/1.
            categoriesMap: category -> one-hot position map from PrepareData.
            scaler: optional fitted StandardScalerModel. PrepareData trains on
                standardized features, so passing the same scaler makes the
                prediction inputs consistent with training. When None (the
                default), the raw unscaled features are used, as before.
        """
        print("Data loading...")
        rawDataWithHeader = sc.textFile(Path + "data/test.tsv")
        header = rawDataWithHeader.first()
        rawData = rawDataWithHeader.filter(lambda x: x != header)
        # Strip the double quotes that wrap every field.
        rData = rawData.map(lambda x: x.replace('"', ""))
        # The file is tab-separated.
        lines = rData.map(lambda x: x.split("\t"))
        print("The number of data" + str(lines.count()))
        # featureEnd is len(r) (not len(r) - 1 as in PrepareData), presumably
        # because test rows carry no trailing label column.
        dataRDD = lines.map(lambda r: (r[0],
                                       extract_features(r, categoriesMap, len(r))))
        DescDict = {
            0: "ephemeral",
            1: "evergreen",
        }
        for url, features in dataRDD.take(10):
            if scaler is not None:
                features = scaler.transform(features)
            predictResult = model.predict(features)
            print("Web:" + str(url) + " " +
                  "Predict:" + str(predictResult) +
                  "Illustration:" + DescDict[predictResult] + " ")

    第四步:训练并评估模型(NB模型只需要调节一个参数lambda,即加法平滑参数)

    def evaluateModel(model, validationData):
        """Return the area under the ROC curve of ``model`` on ``validationData``.

        ``validationData`` is an RDD of LabeledPoint. The model's predictions,
        converted to float, are used as scores and paired with the true
        labels for BinaryClassificationMetrics.
        """
        features = validationData.map(lambda point: point.features)
        predicted = model.predict(features).map(lambda value: float(value))
        actual = validationData.map(lambda point: float(point.label))
        metrics = BinaryClassificationMetrics(predicted.zip(actual))
        return metrics.areaUnderROC

    def trainEvaluateModel(trainData, validationData, lambdaParam):
        """Train Naive Bayes with smoothing ``lambdaParam`` and measure validation AUC.

        Prints one progress line with the lambda value, elapsed seconds and AUC.

        Returns:
            (AUC, duration in seconds for training + evaluation,
             lambdaParam, trained model)
        """
        started = time()
        model = NaiveBayes.train(trainData, lambdaParam)
        AUC = evaluateModel(model, validationData)
        elapsed = time() - started
        print(" lambda={} time={} AUC = {}".format(lambdaParam, elapsed, AUC))
        return AUC, elapsed, lambdaParam, model

    def evalParameter(trainData, validationData, evalparm,
                      lambdaParamList):
        """Train one model per lambda value and chart validation AUC and duration.

        Args:
            trainData: RDD of LabeledPoint used for training.
            validationData: RDD of LabeledPoint used to compute AUC.
            evalparm: name of the swept parameter, used as the chart title and
                x-axis label. It is now honoured as passed rather than being
                overwritten inside the function.
            lambdaParamList: lambda values to try; also the DataFrame index
                (the chart's x-axis).
        """
        metrics = [trainEvaluateModel(trainData, validationData, regParam)
                   for regParam in lambdaParamList]
        df = pd.DataFrame(metrics, index=lambdaParamList,
                          columns=['AUC', 'duration', 'lambdaParam', 'model'])
        # Bars: AUC on a fixed 0.5-0.7 y-range; line: training+evaluation time.
        showchart(df, evalparm, 'AUC', 'duration', 0.5, 0.7)

    def showchart(df, evalparm, barData, lineData, yMin, yMax):
        """Plot column ``barData`` of ``df`` as bars and ``lineData`` as a red line.

        The bars use the left y-axis, limited to [yMin, yMax]. The line is
        drawn on a twin right-hand axis that shares the x-axis. The figure is
        then displayed with plt.show().
        """
        barAxis = df[barData].plot(kind='bar', title=evalparm, figsize=(10, 6),
                                   legend=True, fontsize=12)
        barAxis.set_xlabel(evalparm, fontsize=12)
        barAxis.set_ylabel(barData, fontsize=12)
        barAxis.set_ylim([yMin, yMax])
        lineAxis = barAxis.twinx()
        lineAxis.plot(df[[lineData]].values, linestyle='-', marker='o',
                      linewidth=2.0, color='r')
        plt.show()

    def evalAllParameter(training_RDD, validation_RDD, lambdaParamList):
        """Train one model per lambda value and return the one with the best AUC.

        Args:
            training_RDD: RDD of LabeledPoint used for training.
            validation_RDD: RDD of LabeledPoint used to compute AUC.
            lambdaParamList: candidate lambda values (must be non-empty).

        Returns:
            The trained model with the highest validation AUC. On a tie, the
            earliest value in lambdaParamList wins.

        Raises:
            ValueError: if lambdaParamList is empty.
        """
        if not lambdaParamList:
            raise ValueError("lambdaParamList must contain at least one value")
        # Use this function's own parameters, not the module-level
        # trainData/validationData globals that only exist under __main__.
        metrics = [trainEvaluateModel(training_RDD, validation_RDD, lambdaParam)
                   for lambdaParam in lambdaParamList]
        # max() keeps the first maximal entry, matching a stable descending sort.
        bestParameter = max(metrics, key=lambda k: k[0])

        print("lambdaParam:" + str(bestParameter[2]) +
              "AUC = " + str(bestParameter[0]))
        return bestParameter[3]

    def parametersEval(trainData, validationData):
        """Chart validation AUC and training time over a fixed sweep of lambda values."""
        print("For evaluating lambdaParam")
        candidates = [1.0, 3.0, 5.0, 15.0, 25.0, 30.0, 35.0, 40.0, 45.0, 50.0, 60.0]
        evalParameter(trainData, validationData, "lambdaParam",
                      lambdaParamList=candidates)

    第五步:Spark相关设置

    def SetLogger(sc):
        """Set the JVM log4j level to ERROR for the "org", "akka" and root loggers.

        Reaches log4j through the SparkContext's private ``_jvm`` gateway.
        """
        log4j = sc._jvm.org.apache.log4j
        errorLevel = log4j.Level.ERROR
        for loggerName in ("org", "akka"):
            log4j.LogManager.getLogger(loggerName).setLevel(errorLevel)
        log4j.LogManager.getRootLogger().setLevel(errorLevel)

    def SetPath(sc):
        """Set the module-level ``Path`` prefix for data files from the Spark master.

        Masters whose URL starts with "local" read from the local filesystem;
        any other master reads from HDFS.
        """
        global Path
        runningLocally = sc.master.startswith("local")
        Path = ("file:/home/jorlinlee/pythonsparkexample/PythonProject/"
                if runningLocally
                else "hdfs://master:9000/user/jorlinlee/")

    def CreateSparkContext():
        """Create and configure the SparkContext for this script.

        The app name is "NB" and the console progress bar is disabled. It
        also calls SetLogger and SetPath, so the global ``Path`` is set when
        this returns. The caller owns the returned context and should call
        ``sc.stop()`` once finished with it.
        """
        # Parenthesised so the chained builder calls form one statement.
        sparkConf = (SparkConf()
                     .setAppName("NB")
                     .set("spark.ui.showConsoleProgress", "false"))
        sc = SparkContext(conf=sparkConf)
        print("master=" + sc.master)
        SetLogger(sc)
        SetPath(sc)
        return sc

    第六步:运行主程序

    if __name__ == "__main__":
        # Usage: no argument trains with lambda=60; "-e" charts a lambda
        # sweep; "-a" replaces the model with the best lambda from the sweep.
        print("NB")
        sc = CreateSparkContext()
        try:
            print("Preparing")
            (trainData, validationData, testData, categoriesMap) = PrepareData(sc)
            trainData.persist()
            validationData.persist()
            testData.persist()
            print("Evaluating")
            (AUC, duration, lambdaParam, model) = trainEvaluateModel(
                trainData, validationData, 60.0)
            if len(sys.argv) == 2 and sys.argv[1] == "-e":
                parametersEval(trainData, validationData)
            elif len(sys.argv) == 2 and sys.argv[1] == "-a":
                print("Best parameter")
                model = evalAllParameter(
                    trainData, validationData,
                    [1.0, 3.0, 5.0, 15.0, 25.0, 30.0, 35.0, 40.0, 45.0, 50.0, 60.0])
            print("Test")
            auc = evaluateModel(model, testData)
            print("AUC:" + str(auc))
            print("Predict")
            PredictData(sc, model, categoriesMap)
        finally:
            # Release the SparkContext even if one of the steps above fails.
            sc.stop()

    结果(对test.tsv前10条数据的预测):
    注意:PredictData中的特征没有经过训练时使用的StandardScaler标准化,这可能是下面10条数据全部被预测为evergreen的原因之一。

    Web:http://www.lynnskitchenadventures.com/2009/04/homemade-enchilada-sauce.html
    Predict:1.0Illustration:evergreen

    Web:http://lolpics.se/18552-stun-grenade-ar
    Predict:1.0Illustration:evergreen

    Web:http://www.xcelerationfitness.com/treadmills.html
    Predict:1.0Illustration:evergreen

    Web:http://www.bloomberg.com/news/2012-02-06/syria-s-assad-deploys-tactics-of-father-to-crush-revolt-threatening-reign.html
    Predict:1.0Illustration:evergreen

    Web:http://www.wired.com/gadgetlab/2011/12/stem-turns-lemons-and-limes-into-juicy-atomizers/
    Predict:1.0Illustration:evergreen

    Web:http://www.latimes.com/health/boostershots/la-heb-fat-tax-denmark-20111013,0,2603132.story
    Predict:1.0Illustration:evergreen

    Web:http://www.howlifeworks.com/a/a?AG_ID=1186&cid=7340ci
    Predict:1.0Illustration:evergreen

    Web:http://romancingthestoveblog.wordpress.com/2010/01/13/sweet-potato-ravioli-with-lemon-sage-brown-butter-sauce/
    Predict:1.0Illustration:evergreen

    Web:http://www.funniez.net/Funny-Pictures/turn-men-down.html
    Predict:1.0Illustration:evergreen

    Web:http://youfellasleepwatchingadvd.com/
    Predict:1.0Illustration:evergreen

  • 相关阅读:
    工作流二次开发之新增表单实践(二)
    layui表格及工作流二次开发实践(一)
    记一个递归封装树形结构
    SpringCloud微服务之宏观了解
    统一结果返回&统一异常处理
    mybatis-Plus 实践篇之CRUD操作
    修改MariaDB-root密码
    iftop-监控服务器实时带宽情况
    Wordpress安装-报错说明
    MariaDB忘记root密码
  • 原文地址:https://www.cnblogs.com/zhuozige/p/12629537.html
Copyright © 2011-2022 走看看