• 为了保证你在浏览本网站时有着更好的体验,建议使用类似Chrome、Firefox之类的浏览器~~
    • 如果你喜欢本站的内容何不Ctrl+D收藏一下呢,与大家一起分享各种编程知识~
    • 本网站研究机器学习、计算机视觉、模式识别~当然不局限于此,生命在于折腾,何不年轻时多折腾一下

spark实现gbdt和lr

bigdata admin 2年前 (2018-03-19) 4901次浏览 2个评论 扫描二维码

spark 对 python 开放的接口实在是有限,只有 scala 是亲生的。查了下 scala 的包和函数,发现提供的功能真的很全,博主从零开始撸 scala 代码,边写边查的节奏,给出以下 example 代码供大家参考

import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.DenseVector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, FeatureType, Strategy}
import org.apache.spark.mllib.tree.model.Node
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SparkSession

object gbdt_lr {

  /**
   * Collect the ids of all leaf nodes of the decision (sub)tree rooted at `node`,
   * in left-to-right traversal order. The position of a leaf id in this array is
   * later used as the index of the one-hot bit for that leaf.
   *
   * @param node root of the (sub)tree to traverse
   * @return array of leaf-node ids
   */
  def getLeafNodes(node: Node): Array[Int] = {
    if (node.isLeaf) {
      Array(node.id)
    } else {
      // Internal GBT nodes always carry both children, so .get is safe here.
      getLeafNodes(node.leftNode.get) ++ getLeafNodes(node.rightNode.get)
    }
  }

  /**
   * Walk a decision tree with `features` and return the id of the leaf node the
   * sample lands in (rather than the leaf's prediction value, as `predict` would).
   *
   * @param node     current tree node (start from the tree's topNode)
   * @param features feature vector of the sample
   * @return id of the matched leaf node
   */
  def predictModify(node: Node, features: DenseVector): Int = {
    if (node.isLeaf) {
      node.id
    } else {
      // Non-leaf nodes always have a split; hoist the .get once.
      val split = node.split.get
      val goLeft =
        if (split.featureType == FeatureType.Continuous)
          features(split.feature) <= split.threshold
        else
          split.categories.contains(features(split.feature))
      predictModify(if (goLeft) node.leftNode.get else node.rightNode.get, features)
    }
  }

  /**
   * Train a GBDT on the watermelon dataset, re-encode every sample as the
   * concatenated one-hot vector of the GBDT leaves it falls into, then train
   * a logistic regression on those new features (classic GBDT+LR pipeline).
   *
   * Usage: an optional first CLI argument overrides the sample CSV path.
   */
  def main(args: Array[String]) {

    val sparkConf = new SparkConf().setAppName("GbdtAndLr")
    sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    // Generalized: the sample path may now come from the command line; the
    // original hard-coded location is kept as the backward-compatible default.
    val sampleDir =
      if (args.nonEmpty) args(0)
      else "/Users/leiyang/IdeaProjects/spark_2.3/src/watermelon3_0_En.csv"
    val sc = new SparkContext(sparkConf)
    val spark = SparkSession.builder.config(sparkConf).getOrCreate()
    val dataFrame = spark.read.format("CSV").option("header", "true").load(sampleDir)

    // Columns 1..6 are integer-coded features; column 9 is the binary label.
    val data = dataFrame.rdd.map { x =>
      LabeledPoint(x(9).toString.toInt, new DenseVector(Array(
        x(1).toString.toInt, x(2).toString.toInt, x(3).toString.toInt,
        x(4).toString.toInt, x(5).toString.toInt, x(6).toString.toInt)))
    }
    val splits = data.randomSplit(Array(0.8, 0.2))
    val train = splits(0)
    val test = splits(1)

    // GBDT model: 2 boosting iterations, depth-5 binary-classification trees.
    val numTrees = 2
    val boostingStrategy = BoostingStrategy.defaultParams("Classification")
    boostingStrategy.setNumIterations(numTrees)
    val treeStrategy = Strategy.defaultStrategy("Classification")
    treeStrategy.setMaxDepth(5)
    treeStrategy.setNumClasses(2)
    boostingStrategy.setTreeStrategy(treeStrategy)
    val gbdtModel = GradientBoostedTrees.train(train, boostingStrategy)

    val labelAndPreds = test.map { point =>
      (point.label, gbdtModel.predict(point.features))
    }
    val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / test.count()
    println("Test Error = " + testErr)

    // Leaf ids per tree; a leaf's index within its tree's array is the
    // position of its one-hot bit in the encoded feature segment.
    val treeLeafArray = new Array[Array[Int]](numTrees)
    for (i <- 0 until numTrees) {
      treeLeafArray(i) = getLeafNodes(gbdtModel.trees(i).topNode)
    }
    for (i <- 0 until numTrees) {
      // BUG FIX: println is not printf — the original printed the tuple
      // ("...%d...", i) verbatim; use string interpolation instead.
      println(s"Printing topNode leaf ids of tree $i")
      // BUG FIX: the original printed the loop index j, not the leaf id.
      for (leafId <- treeLeafArray(i)) {
        println(leafId)
      }
    }

    // Build new one-hot features from the GBDT leaf assignments.
    val newFeatureDataSet = dataFrame.rdd.map { x =>
      (x(9).toString.toInt, new DenseVector(Array(
        x(1).toString.toInt, x(2).toString.toInt, x(3).toString.toInt,
        x(4).toString.toInt, x(5).toString.toInt, x(6).toString.toInt)))
    }.map { case (label, features) =>
      var newFeature = new Array[Double](0)
      for (i <- 0 until numTrees) {
        val leafId = predictModify(gbdtModel.trees(i).topNode, features)
        // A full binary tree with n nodes has (n + 1) / 2 leaves.
        val oneHot = new Array[Double]((gbdtModel.trees(i).numNodes + 1) / 2)
        oneHot(treeLeafArray(i).indexOf(leafId)) = 1
        newFeature = newFeature ++ oneHot
      }
      (label, newFeature)
    }
    val newData = newFeatureDataSet.map(x => LabeledPoint(x._1, new DenseVector(x._2)))
    val splits2 = newData.randomSplit(Array(0.8, 0.2))
    val train2 = splits2(0)
    val test2 = splits2(1)

    // Logistic regression over the GBDT-encoded features.
    // (Removed the dead `model.weights` no-op statement from the original.)
    val model = new LogisticRegressionWithLBFGS().setNumClasses(2).run(train2).setThreshold(0.01)
    val predictionAndLabels = test2.map { case LabeledPoint(label, features) =>
      (model.predict(features), label)
    }
    val metrics = new MulticlassMetrics(predictionAndLabels)
    println("Precision = " + metrics.accuracy)

    sc.stop()
  }
}


Deeplearn, 版权所有丨如未注明 , 均为原创丨本网站采用BY-NC-SA协议进行授权 , 转载请注明spark 实现 gbdt 和 lr
喜欢 (6)
admin
关于作者:
互联网行业码农一枚/业余铲屎官/数码影音爱好者/二次元

您必须 登录 才能发表评论!

(2)个小伙伴在吐槽
  1. 博主 您好 我运行您的这个代码时报错,我推测您是对数据做过预处理,您这个代码对应的数据能提供一下吗?谢谢
    • admin
      这个用的就是周志华机器学习那本书中的西瓜数据集
      admin2019-02-11 19:28