spark决策树API分析

4,136次阅读
没有评论

共计 2094 个字符,预计需要花费 6 分钟才能阅读完成。

此版本是ml版本,区别于mllib版本的决策树api

输入

Param name Type(s) Default Description
labelCol Double “label” 标签
featuresCol Vector “features” 特征向量

 

输出

Param name Type(s) Default Description Notes
predictionCol Double “prediction” 预测结果标签
rawPredictionCol Vector “rawPrediction” Vector of length # classes, with the counts of training instance labels at the tree node which makes the prediction 仅限分类
probabilityCol Vector “probability” Vector of length # classes equal to rawPrediction normalized to a multinomial distribution 仅限分类
varianceCol Double 预测结果方差 仅限回归

API函数

决策树分类器
class pyspark.ml.classification.DecisionTreeClassifier(selffeaturesCol=”features”labelCol=”label”predictionCol=”prediction”probabilityCol=”probability”rawPredictionCol=”rawPrediction”maxDepth=5maxBins=32minInstancesPerNode=1minInfoGain=0.0maxMemoryInMB=256cacheNodeIds=FalsecheckpointInterval=10impurity=”gini”)

maxDepth=5表示树的深度最大为5

maxBins=32表示离散化连续变量分区个数最大值

imputity=”gini”决策树节点特征选择方式,类似还有c45等

代码实例

>>> from pyspark.mllib.linalg import Vectors
>>> from pyspark.ml.feature import StringIndexer
>>> df = sqlContext.createDataFrame([
...     (1.0, Vectors.dense(1.0)),
...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
>>> si_model = stringIndexer.fit(df)
>>> td = si_model.transform(df)
>>> dt = DecisionTreeClassifier(maxDepth=2, labelCol="indexed")
>>> model = dt.fit(td)
>>> model.numNodes
3
>>> model.depth
1
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> result = model.transform(test0).head()
>>> result.prediction
0.0
>>> result.probability
DenseVector([1.0, 0.0])
>>> result.rawPrediction
DenseVector([1.0, 0.0])
>>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
1.0

决策树模型

class pyspark.ml.classification.DecisionTreeClassificationModel(java_model)

模型是通过决策树分类器训练获得

主要的内部成员函数是transform,负责对测试样本预测分类,返回结果是dataframe格式
transform(datasetparams=None)

Transforms the input dataset with optional parameters.

Parameters:
  • dataset – input dataset, which is an instance of pyspark.sql.DataFrame
  • params – an optional param map that overrides embedded params.
Returns:

transformed dataset

New in version 1.3.0.

正文完
请博主喝杯咖啡吧!
post-qrcode
 
admin
版权声明:本站原创文章,由 admin 2017-08-16发表,共计2094字。
转载说明:除特殊说明外本站文章皆由CC-4.0协议发布,转载请注明出处。
评论(没有评论)
验证码