zoukankan      html  css  js  c++  java
  • 朴素贝叶斯算法原理及Spark MLlib实例(Scala/Java/Python)

    朴素贝叶斯

    算法介绍:

    朴素贝叶斯法是基于贝叶斯定理与特征条件独立假设的分类方法。

    朴素贝叶斯的思想基础是这样的:对于给出的待分类项,求解在此项出现的条件下各个类别出现的概率,在没有其它可用信息下,我们会选择条件概率最大的类别作为此待分类项应属的类别。

    朴素贝叶斯分类的正式定义如下:

    1、设 为一个待分类项,而每个a为x的一个特征属性。

    2、有类别集合 。

    3、计算 。

    4、如果 ,则 。

    那么现在的关键就是如何计算第3步中的各个条件概率。我们可以这么做:

    1、找到一个已知分类的待分类项集合,这个集合叫做训练样本集。

    2、统计得到在各类别下各个特征属性的条件概率估计。即 

    3、如果各个特征属性是条件独立的,则根据贝叶斯定理有如下推导:

     

    因为分母对于所有类别为常数,因为我们只要将分子最大化皆可。又因为各特征属性是条件独立的,所以有:

     

    spark.ml现在支持多项朴素贝叶斯和伯努利朴素贝叶斯。

    参数:

    featuresCol:

    类型:字符串型。

    含义:特征列名。

    labelCol:

    类型:字符串型。

    含义:标签列名。

    modelType:

    类型:字符串型。

    含义:模型类型(区分大小写)。

    predictionCol:

    类型:字符串型。

    含义:预测结果列名。

    probabilityCol:

    类型:字符串型。

    含义:用以预测类别条件概率的列名。

    rawPredictionCol:

    类型:字符串型。

    含义:原始预测。

    smoothing:

    类型:双精度型。

    含义:平滑参数。

    thresholds:

    类型:双精度数组型。

    含义:多分类预测的阀值,以调整预测结果在各个类别的概率。

    示例:

    Scala:

    import org.apache.spark.ml.classification.NaiveBayes  
    import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator  
      
    // Load the data stored in LIBSVM format as a DataFrame.  
    val data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")  
      
    // Split the data into training and test sets (30% held out for testing)  
    val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3), seed = 1234L)  
      
    // Train a NaiveBayes model.  
    val model = new NaiveBayes()  
      .fit(trainingData)  
      
    // Select example rows to display.  
    val predictions = model.transform(testData)  
    predictions.show()  
      
    // Select (prediction, true label) and compute test error  
    val evaluator = new MulticlassClassificationEvaluator()  
      .setLabelCol("label")  
      .setPredictionCol("prediction")  
      .setMetricName("accuracy")  
    val accuracy = evaluator.evaluate(predictions)  
    println("Accuracy: " + accuracy)  

    Java:

    import org.apache.spark.ml.classification.NaiveBayes;  
    import org.apache.spark.ml.classification.NaiveBayesModel;  
    import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;  
    import org.apache.spark.sql.Dataset;  
    import org.apache.spark.sql.Row;  
    import org.apache.spark.sql.SparkSession;  
      
    // Load training data  
    Dataset<Row> dataFrame =  
      spark.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");  
    // Split the data into train and test  
    Dataset<Row>[] splits = dataFrame.randomSplit(new double[]{0.6, 0.4}, 1234L);  
    Dataset<Row> train = splits[0];  
    Dataset<Row> test = splits[1];  
      
    // create the trainer and set its parameters  
    NaiveBayes nb = new NaiveBayes();  
    // train the model  
    NaiveBayesModel model = nb.fit(train);  
    // compute accuracy on the test set  
    Dataset<Row> result = model.transform(test);  
    Dataset<Row> predictionAndLabels = result.select("prediction", "label");  
    MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()  
      .setMetricName("accuracy");  
    System.out.println("Accuracy = " + evaluator.evaluate(predictionAndLabels));  

    Python:

    from pyspark.ml.classification import NaiveBayes  
    from pyspark.ml.evaluation import MulticlassClassificationEvaluator  
      
    # Load training data  
    data = spark.read.format("libsvm")   
        .load("data/mllib/sample_libsvm_data.txt")  
    # Split the data into train and test  
    splits = data.randomSplit([0.6, 0.4], 1234)  
    train = splits[0]  
    test = splits[1]  
      
    # create the trainer and set its parameters  
    nb = NaiveBayes(smoothing=1.0, modelType="multinomial")  
      
    # train the model  
    model = nb.fit(train)  
    # compute accuracy on the test set  
    result = model.transform(test)  
    predictionAndLabels = result.select("prediction", "label")  
    evaluator = MulticlassClassificationEvaluator(metricName="accuracy")  
    print("Accuracy: " + str(evaluator.evaluate(predictionAndLabels)))  
  • 相关阅读:
    echo e 在SHELL脚本和命令行中表现不同一例问题排查
    Linux 中修改网卡名称【ubuntu + Centos7】
    ESXI上实施ORACLE 10G RAC+LINUX+ASM
    Linux crontab下关于使用date命令的坑
    SkiaSharp跨平台绘图研究1WPF桌面应用
    编译原理 实验一 词法分析
    计算机组成原理(上)_第一章测试题
    计算机组成原理(上)_第三章测试题
    SQL Server 2017 下载及安装详细教程
    计算机组成原理(上)_第四章测试题(上)
  • 原文地址:https://www.cnblogs.com/itboys/p/9172653.html
Copyright © 2011-2022 走看看