  • [Spark Machine Learning Crash Course] Models 07: Gradient-Boosted Trees (Python version)

    Contents

      How gradient-boosted trees work

      Gradient-boosted trees code (Spark Python)


    How gradient-boosted trees work

       To be continued...


    Gradient-boosted trees code (Spark Python)

      Data used in the code: https://pan.baidu.com/s/1jHWKG4I (password: acq1)

    # -*-coding=utf-8 -*-  
    from pyspark import SparkConf, SparkContext
    sc = SparkContext('local')
    
    from pyspark.mllib.tree import GradientBoostedTrees, GradientBoostedTreesModel
    from pyspark.mllib.util import MLUtils
    
    # Load and parse the data file.
    data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
    '''
    Each line of the file represents one labeled sparse feature vector:
    label index1:value1 index2:value2 ...
    Indices in the file are 1-based; loadLibSVMFile converts them to 0-based.

    Example (tempFile is a temporary file opened for writing):
    >>> tempFile.write(b"+1 1:1.0 3:2.0 5:3.0\n-1\n-1 2:4.0 4:5.0 6:6.0")
    >>> tempFile.flush()
    >>> examples = MLUtils.loadLibSVMFile(sc, tempFile.name).collect()
    >>> tempFile.close()
    >>> examples[0]
    LabeledPoint(1.0, (6,[0,2,4],[1.0,2.0,3.0]))
    >>> examples[1]
    LabeledPoint(-1.0, (6,[],[]))
    >>> examples[2]
    LabeledPoint(-1.0, (6,[1,3,5],[4.0,5.0,6.0]))
    '''
    # Split the data into training and test sets (30% held out for testing).
    (trainingData, testData) = data.randomSplit([0.7, 0.3])
    
    # Train a GradientBoostedTrees classification model.
    #  Notes: (a) An empty categoricalFeaturesInfo means all features are continuous.
    #         (b) Use more iterations in practice.
    model = GradientBoostedTrees.trainClassifier(trainingData,
                                                 categoricalFeaturesInfo={}, numIterations=30)
    
    # Evaluate the model on the test instances and compute the test error.
    predictions = model.predict(testData.map(lambda x: x.features))
    labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
    testErr = labelsAndPredictions.filter(
        lambda lp: lp[0] != lp[1]).count() / float(testData.count())
    print('Test Error = ' + str(testErr)) #Test Error = 0.0
    print('Learned classification GBT model:')
    print(model.toDebugString())
    '''
    TreeEnsembleModel classifier with 30 trees
    
      Tree 0:
        If (feature 434 <= 0.0)
         If (feature 100 <= 165.0)
          Predict: -1.0
         Else (feature 100 > 165.0)
          Predict: 1.0
        Else (feature 434 > 0.0)
         Predict: 1.0
      Tree 1:
        If (feature 490 <= 0.0)
         If (feature 549 <= 253.0)
          If (feature 184 <= 0.0)
           Predict: -0.4768116880884702
          Else (feature 184 > 0.0)
           Predict: -0.47681168808847024
         Else (feature 549 > 253.0)
          Predict: 0.4768116880884694
        Else (feature 490 > 0.0)
         If (feature 215 <= 251.0)
          Predict: 0.4768116880884701
         Else (feature 215 > 251.0)
          Predict: 0.4768116880884712
      ...
      Tree 29:
        If (feature 434 <= 0.0)
         If (feature 209 <= 4.0)
          Predict: 0.1335953290513215
         Else (feature 209 > 4.0)
          If (feature 372 <= 84.0)
           Predict: -0.13359532905132146
          Else (feature 372 > 84.0)
           Predict: -0.1335953290513215
        Else (feature 434 > 0.0)
         Predict: 0.13359532905132146
    '''
    # Save and load model
    model.save(sc, "myGradientBoostingClassificationModel")
    sameModel = GradientBoostedTreesModel.load(sc,"myGradientBoostingClassificationModel")
    print(sameModel.predict(data.first().features))  # 0.0
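
The LIBSVM line format described in the listing can be illustrated without Spark. The following is a minimal sketch only: `parse_libsvm_line` is a hypothetical helper written for this example, not part of PySpark, and it assumes well-formed input. It shows how the 1-based indices in the file become the 0-based indices seen in the `LabeledPoint` output.

```python
def parse_libsvm_line(line):
    """Parse 'label index1:value1 index2:value2 ...' into (label, indices, values).

    Hypothetical helper for illustration; not the PySpark implementation.
    """
    parts = line.split()
    label = float(parts[0])
    indices = []
    values = []
    for item in parts[1:]:
        index_text, value_text = item.split(":")
        indices.append(int(index_text) - 1)  # file indices are 1-based
        values.append(float(value_text))
    return label, indices, values


text = "+1 1:1.0 3:2.0 5:3.0\n-1\n-1 2:4.0 4:5.0 6:6.0"
parsed = [parse_libsvm_line(line) for line in text.splitlines()]

# The vector size is inferred from the largest index seen in the data.
num_features = max((max(idx) + 1 for _, idx, _ in parsed if idx), default=0)

for label, indices, values in parsed:
    print(label, (num_features, indices, values))
```

The second line of the sample text contains only a label, so it yields an empty sparse vector, as in the `examples[1]` output above.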

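
The test error computed in the listing is simply the fraction of test instances whose prediction differs from the label. A plain-Python sketch of the same calculation follows; `error_rate` and the sample pairs are made up for illustration and stand in for the `labelsAndPredictions` RDD.

```python
def error_rate(labels_and_predictions):
    """Return the fraction of (label, prediction) pairs that disagree."""
    pairs = list(labels_and_predictions)
    if not pairs:
        raise ValueError("need at least one (label, prediction) pair")
    wrong = sum(1 for label, prediction in pairs if label != prediction)
    return wrong / float(len(pairs))


# Illustrative data only: one of four predictions is wrong.
sample_pairs = [(1.0, 1.0), (0.0, 1.0), (0.0, 0.0), (1.0, 1.0)]
print("Test Error = " + str(error_rate(sample_pairs)))  # Test Error = 0.25
```

Because `randomSplit` is random unless a seed is passed, the test error from the Spark listing can differ between runs.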

  • Original article: https://www.cnblogs.com/itmorn/p/8028435.html