spark随机森林算法 由多个决策树构成的森林,算法分类结果由这些决策树投票得到,决策树在生成的过程当中分别在行方向和列方向上添加随机过程,行方向上构建决策树时采用放回抽样(bootstraping)得到训练数据,列方向上采用无放回随机抽样得到特征子集,并据此得到其最优切分点,这便是随机森林算法的基本原理。图 1给出了随机森林算法分类原理,从图中可以看到,随机森林是一个组合模型,内部仍然是基于决策树,同单一的决策树分类不同的是,随机森林通过多个决策树投票结果进行分类,算法不容易出现过度拟合问题。 图 1 随机森林示意图
随机森林在分布式环境下的优化策略 随机森林算法在单机环境下很容易实现,但在分布式环境下特别是在 Spark 平台上,传统单机形式的迭代方式必须要进行相应改进才能适用于分布式环境,这是因为在分布式环境下,数据也是分布式的(如图 3所示),算法设计不得当会生成大量的 IO 操作,例如频繁的网络数据传输,从而影响算法效率。 图 2单机环境下数据存储
图3 分布式环境下数据存储
此,在 Spark 上进行随机森林算法的实现,需要进行一定的优化,Spark 中的随机森林算法主要实现了三个优化策略: . 切分点抽样统计,如图 4所示。在单机环境下的决策树对连续变量进行切分点选择时,一般是通过对特征点进行排序,然后取相邻两个数之间的点作为切分点,这在单机环境下是可行的,但如果在分布式环境下如此操作的话,会带来大量的网络传输操作,特别是当数据量达到 PB 级时,算法效率将极为低下。为避免该问题,Spark 中的随机森林在构建决策树时,会对各分区采用一定的子特征策略进行抽样,然后生成各个分区的统计数据,并最终得到切分点。 . 特征装箱(Binning),如图 5 所示。决策树的构建过程就是对特征的取值不断进行划分的过程,对于离散的特征,如果有 M 个值,最多$$2^{m-1}-1$$
个划分,如果值是有序的,那么就最多 M-1 个划分。比如年龄特征,有老,中,少 3 个值,如果无序有 个,即 3 种划分:老|中,少;老,中|少;老,少|中;如果是有序的,即按老,中,少的序,那么只有 m-1 个,即 2 种划分,老|中,少;老,中|少。对于连续的特征,其实就是进行范围划分,而划分的点就是 split(切分点),划分出的区间就是 bin。对于连续特征,理论上 split 是无数的,在分布环境下不可能取出所有的值,因此它采用的是(1)中的切点抽样统计方法。 举个例子说明箱化处理过程在数据量很大的情况下,对每个可能取值进行排序成本就太大了,一个惯用的近似技巧——“箱化” (Binning,我们觉得取值太多了,一个个处理太麻烦,打包起来,处理就简单多了)。 假如原来特征的取值是: 1,2,3,4,5,6,7,8,9,…. 经过打包后可能就变成如下: (1,2),(3,4),(5,6),(7,8),…. 然后以箱子为单位,根据每个箱子的最小值和最大值,可以确定划分边界,然后按照信息增益或者其他衡量方式确定最终分裂边界。那么采用这种技巧之后,Bin个数最多就是所有取值的情况的个数,也就是M(当于没有使用”箱化”技巧),而Split的个数就是M-1 . 逐层训练(level-wise training),如图 8 所示。单机版本的决策数生成过程是通过递归调用(本质上是深度优先)的方式构造树,在构造树的同时,需要移动数据,将同一个子节点的数据移动到一起。此方法在分布式数据结构上无法有效的执行,而且也无法执行,因为数据太大,无法放在一起,所以在分布式环境下采用的策略是逐层构建树节点(本质上是广度优先),这样遍历所有数据的次数等于所有树中的最大层数。每次遍历时,只需要计算每个节点所有切分点统计参数,遍历完后,根据节点的特征划分,决定是否切分,以及如何切分。 图 4切分点抽样统计
图5 特征装箱
图 6逐层训练
算法源码分析 在对决策树、随机森林算法原理及 Spark 上的优化策略的理解基础上,本节将对 Spark MLlib 中的随机森林算法源码进行分析。首先给出了官网上的算法使用 demo,然后再深入到对应方法源码中,对实现原理进行分析。 参数解析 checkpointInterval: 类型:整数型。 含义:设置检查点间隔(>=1),或不设置检查点(-1)。 featuresCol: 类型:字符串型。 含义:特征列名。 impurity: 类型:字符串型。 含义:计算信息增益的准则(不区分大小写)。 labelCol: 类型:字符串型。 含义:标签列名。 maxBins: 类型:整数型。 含义:连续特征离散化的最大数量,以及选择每个节点分裂特征的方式。 maxDepth: 类型:整数型。 含义:树的最大深度(>=0)。 minInfoGain: 类型:双精度型。 含义:分裂节点时所需最小信息增益。 minInstancesPerNode: 类型:整数型。 含义:分裂后自节点最少包含的实例数量。 predictionCol: 类型:字符串型。 含义:预测结果列名。 seed: 类型:长整型。 含义:随机种子。 varianceCol: 类型:字符串型。 含义:预测的有偏样本偏差的列名。 numTrees 类型:整形 含义:训练的决策树的个数
实验代码如下 平台spark1.6
from pyspark.mllib.tree import RandomForest, RandomForestModel from pyspark.mllib.util import MLUtils from pyspark import SparkConf,SparkContext # Load and parse the data file into an RDD of LabeledPoint. conf=SparkConf().setAppName('decsion').setMaster('local') sc=SparkContext(conf=conf) #加载本地文件 data = MLUtils.loadLibSVMFile(sc, ‘file:///yourpath/sample_libsvm_data.txt') #训练集占据70%,测试集30% (trainingData, testData) = data.randomSplit([0.7, 0.3]) #训练随机森林模型 # categoricalFeaturesInfo 为空表示属性为连续属性 # 在实际应用中使用较大的numTrees参数 # featureSubsetStrategy="auto" 让算法自动选取每个节点划分时需要考虑的属性个数 model = RandomForest.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={}, numTrees=3, featureSubsetStrategy="auto", impurity='gini', maxDepth=4, maxBins=32) #计算在测试集上的错误率并打印随机森林 predictions = model.predict(testData.map(lambda x: x.features)) labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(testData.count()) print('Test Error = ' + str(testErr)) print('Learned classification forest model:') print(model.toDebugString()) 数据结果 Test Error = 0.0434782608696 Learned classification forest model: TreeEnsembleModel classifier with 3 trees Tree 0: If (feature 435 <= 0.0) If (feature 377 <= 103.0) Predict: 0.0 Else (feature 377 > 103.0) Predict: 1.0 Else (feature 435 > 0.0) Predict: 1.0 Tree 1: If (feature 525 <= 0.0) If (feature 461 <= 0.0) If (feature 551 <= 186.0) If (feature 324 <= 3.0) Predict: 0.0 Else (feature 324 > 3.0) Predict: 1.0 Else (feature 551 > 186.0) Predict: 1.0 Else (feature 461 > 0.0) Predict: 1.0 Else (feature 525 > 0.0) Predict: 0.0 Tree 2: If (feature 468 <= 0.0) If (feature 462 <= 0.0) Predict: 0.0 Else (feature 462 > 0.0) Predict: 1.0 Else (feature 468 > 0.0) Predict: 0.0
参考
1、https://www.ibm.com/developerworks/cn/opensource/os-cn-spark-random-forest/index.html
2、http://blog.csdn.net/shenxiaoming77/article/details/58131099
3、http://spark.apache.org/docs/1.6.0/api/python/pyspark.mllib.html#pyspark.mllib.tree.RandomForest