Spark Decision Tree API Analysis

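This post covers the decision tree API in the DataFrame-based spark.ml package, which is separate from the RDD-based decision tree API in spark.mllib.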

Input

Param name    Type(s)  Default     Description
labelCol      Double   "label"     Label to predict
featuresCol   Vector   "features"  Feature vector

Output

Param name        Type(s)  Default          Description                                                               Notes
predictionCol     Double   "prediction"     Predicted label
rawPredictionCol  Vector   "rawPrediction"  Vector of length # classes, with the counts of training instance labels   Classification only
                                            at the tree node which makes the prediction
probabilityCol    Vector   "probability"    Vector of length # classes equal to rawPrediction normalized to a         Classification only
                                            multinomial distribution
varianceCol       Double   (none)           Variance of the prediction                                                Regression only
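To make the relationship between these columns concrete, here is a minimal pure-Python sketch (not Spark code) that computes them for a single leaf. It assumes, as the table describes, that rawPrediction holds the per-class label counts at the leaf and probability is those counts normalized; the prediction is taken as the index of the largest count, and the variance is computed as a biased sample variance (dividing by n), which is how the Spark docs describe varianceCol. The function names are hypothetical.

```python
from collections import Counter


def classification_outputs(leaf_labels, num_classes):
    """Return (rawPrediction, probability, prediction) for one classification leaf.

    leaf_labels: hypothetical list of training labels (0.0, 1.0, ...) in the leaf.
    """
    counts = Counter(leaf_labels)
    # rawPrediction: per-class counts of training labels at the leaf
    raw = [float(counts.get(float(k), 0)) for k in range(num_classes)]
    total = sum(raw)
    # probability: rawPrediction normalized so it sums to 1
    prob = [c / total for c in raw]
    # prediction: index of the most probable class (ties go to the lowest index)
    prediction = float(max(range(num_classes), key=lambda k: prob[k]))
    return raw, prob, prediction


def regression_outputs(leaf_labels):
    """Return (prediction, variance) for one regression leaf."""
    n = len(leaf_labels)
    mean = sum(leaf_labels) / n
    # Biased sample variance: divide by n rather than n - 1
    variance = sum((y - mean) ** 2 for y in leaf_labels) / n
    return mean, variance
```

For a leaf holding three label-0 samples and one label-1 sample, this gives rawPrediction [3.0, 1.0], probability [0.75, 0.25] and prediction 0.0.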

API functions

Decision tree classifier:

class pyspark.ml.classification.DecisionTreeClassifier(self, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini")

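maxDepth=5: the maximum depth of the tree is 5. Depth 0 means a single leaf node; depth 1 means one internal node plus two leaf nodes.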
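As a quick sanity check on what maxDepth bounds, the sketch below (plain Python, hypothetical helper names) computes the maximum node and leaf counts a binary tree can reach at a given depth. In the code example later in this post, maxDepth=2 would allow up to 7 nodes, but the fitted tree stops at depth 1 with 3 nodes because a single split already separates the two training points.

```python
def max_nodes(max_depth: int) -> int:
    """Most nodes a binary tree can have when limited to max_depth.

    Depth 0 is a single leaf; each extra level can at most double the
    number of nodes on the deepest level, giving 2**(d + 1) - 1 in total.
    """
    if max_depth < 0:
        raise ValueError("maxDepth must be >= 0")
    return 2 ** (max_depth + 1) - 1


def max_leaves(max_depth: int) -> int:
    """Most leaf (prediction) nodes a binary tree of depth max_depth can have."""
    if max_depth < 0:
        raise ValueError("maxDepth must be >= 0")
    return 2 ** max_depth
```

With the default maxDepth=5, a tree can grow to at most 63 nodes, 32 of them leaves.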

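maxBins=32: the maximum number of bins used when discretizing continuous features. It must be at least 2, and at least the number of categories of any categorical feature.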
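The sketch below illustrates the idea behind maxBins with simple equal-frequency binning: a continuous feature is reduced to at most maxBins - 1 candidate thresholds, so the tree only has to evaluate those splits instead of one per distinct value. This binning rule is a simplified assumption for illustration; Spark chooses its candidate splits from a sample of the data, and its exact procedure differs. The function name is hypothetical.

```python
def candidate_thresholds(values, max_bins):
    """Return at most max_bins - 1 split thresholds for one continuous feature.

    Simplified equal-frequency binning (an illustrative assumption, not
    Spark's exact procedure): with few distinct values, use midpoints
    between neighbors; otherwise take values at evenly spaced ranks.
    A split at threshold t sends x <= t left and x > t right.
    """
    if max_bins < 2:
        raise ValueError("maxBins must be >= 2")
    distinct = sorted(set(values))
    if len(distinct) <= max_bins:
        return [(a + b) / 2 for a, b in zip(distinct, distinct[1:])]
    ordered = sorted(values)
    n = len(ordered)
    thresholds = []
    for i in range(1, max_bins):
        t = ordered[i * n // max_bins]
        # Skip duplicates and thresholds that would leave the right side empty
        if t < distinct[-1] and (not thresholds or t > thresholds[-1]):
            thresholds.append(t)
    return thresholds
```

For the values 0.0 to 99.0 and max_bins=4, this yields the thresholds 25.0, 50.0 and 75.0, i.e. four bins of 25 values each. A larger maxBins gives finer candidate splits at the cost of more computation and memory.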

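impurity="gini": the criterion used to choose the split at each node. DecisionTreeClassifier supports "gini" (Gini impurity) and "entropy" (information gain); C4.5 is a separate tree algorithm and is not available as an impurity option.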
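For reference, here is how the two impurity measures and the gain of a split can be computed. This is a plain-Python sketch for illustration, not Spark's implementation; entropy is measured in bits (log base 2), and info_gain is the weighted impurity decrease, the quantity a parameter such as minInfoGain is compared against. The function names are hypothetical.

```python
import math
from collections import Counter


def gini(labels):
    """Gini impurity: 1 - sum_k p_k^2."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def entropy(labels):
    """Entropy in bits: -sum_k p_k * log2(p_k)."""
    n = len(labels)
    return -sum((c / n) * math.log2(c / n) for c in Counter(labels).values())


def info_gain(parent, left, right, impurity=gini):
    """Impurity decrease from splitting parent into left and right.

    Each child's impurity is weighted by its share of the parent's samples.
    """
    n = len(parent)
    weighted = (len(left) / n) * impurity(left) + (len(right) / n) * impurity(right)
    return impurity(parent) - weighted
```

A perfectly balanced two-class node [0, 0, 1, 1] has Gini impurity 0.5 and entropy 1.0; splitting it into [0, 0] and [1, 1] removes all impurity, so the gain equals the parent's impurity under either measure.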

Code example

>>> # Assumes a PySpark 1.x shell, where sqlContext already exists.
>>> # On Spark 2.0+, import Vectors from pyspark.ml.linalg instead.
>>> from pyspark.mllib.linalg import Vectors
>>> from pyspark.ml.feature import StringIndexer
>>> from pyspark.ml.classification import DecisionTreeClassifier
>>> df = sqlContext.createDataFrame([
...     (1.0, Vectors.dense(1.0)),
...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> # StringIndexer adds label metadata (the number of classes) for the classifier
>>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
>>> si_model = stringIndexer.fit(df)
>>> td = si_model.transform(df)
>>> dt = DecisionTreeClassifier(maxDepth=2, labelCol="indexed")
>>> model = dt.fit(td)
>>> model.numNodes
3
>>> model.depth
1
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> result = model.transform(test0).head()
>>> result.prediction
0.0
>>> result.probability
DenseVector([1.0, 0.0])
>>> result.rawPrediction
DenseVector([1.0, 0.0])
>>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
1.0
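To see why the fitted model above has 3 nodes and depth 1, the sketch below fits a one-split tree (a decision stump) on the same two training points in plain Python and reproduces the prediction, probability and rawPrediction values from the example. It is only an illustration: it uses the raw labels directly instead of the StringIndexer output, picks the midpoint 0.5 as the threshold (Spark's actual threshold may differ), and the helper names are hypothetical.

```python
from collections import Counter


def gini(labels):
    """Gini impurity of a list of class labels (0.0 for an empty list)."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def label_counts(labels, num_classes):
    """Per-class counts, i.e. the rawPrediction stored at a leaf."""
    c = Counter(labels)
    return [float(c.get(k, 0)) for k in range(num_classes)]


def fit_stump(xs, ys, num_classes=2):
    """Fit a one-split tree (1 internal node + 2 leaves) on a single feature.

    Candidate thresholds are midpoints between neighboring distinct values;
    the one with the lowest weighted Gini impurity is kept.
    """
    distinct = sorted(set(xs))
    if len(distinct) < 2:
        raise ValueError("need at least two distinct feature values to split")
    best = None
    for a, b in zip(distinct, distinct[1:]):
        t = (a + b) / 2
        left = [y for x, y in zip(xs, ys) if x <= t]
        right = [y for x, y in zip(xs, ys) if x > t]
        score = (len(left) * gini(left) + len(right) * gini(right)) / len(ys)
        if best is None or score < best[0]:
            best = (score, t, left, right)
    _, t, left, right = best
    return t, label_counts(left, num_classes), label_counts(right, num_classes)


def predict(stump, x):
    """Return (prediction, rawPrediction, probability) for one feature value."""
    t, left_raw, right_raw = stump
    raw = left_raw if x <= t else right_raw
    total = sum(raw)
    probability = [c / total for c in raw]
    prediction = float(max(range(len(raw)), key=lambda k: raw[k]))
    return prediction, raw, probability
```

Fitting on the features [1.0, 0.0] with labels [1.0, 0.0] gives one pure leaf per class, so predict(stump, -1.0) returns prediction 0.0 with rawPrediction and probability [1.0, 0.0], and predict(stump, 1.0) returns prediction 1.0.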

Decision tree model

class pyspark.ml.classification.DecisionTreeClassificationModel(java_model)

The model is obtained by training (fitting) a DecisionTreeClassifier.

Its main method is transform, which predicts the class of each test sample and returns the result as a DataFrame: transform(dataset, params=None)

Transforms the input dataset with optional parameters.

Parameters:
  • dataset – input dataset, which is an instance of pyspark.sql.DataFrame
  • params – an optional param map that overrides embedded params.
Returns:
  • transformed dataset

New in version 1.3.0.
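The params argument is easiest to understand as "copy the model with some params overridden, then transform". The toy class below sketches that behavior in plain Python. ToyModel and its string-keyed params dict are hypothetical (in PySpark the keys are Param objects such as model.predictionCol), and the simple thresholding rule stands in for a real tree. The point is that the override applies to one call only and leaves the model's embedded params unchanged.

```python
from copy import copy as shallow_copy


class ToyModel:
    """Hypothetical stand-in for a fitted model (not the pyspark class)."""

    def __init__(self, threshold, **params):
        self.threshold = threshold
        # Embedded params with defaults, optionally overridden at construction
        self.params = {"featuresCol": "features", "predictionCol": "prediction"}
        self.params.update(params)

    def copy(self, extra=None):
        """Return a copy whose params are self.params merged with extra."""
        new = shallow_copy(self)
        new.params = {**self.params, **(extra or {})}
        return new

    def transform(self, rows, params=None):
        """Append a prediction column to each row (a dict); inputs are not mutated."""
        # Override params on a copy, never on self
        model = self.copy(params) if params else self
        feat = model.params["featuresCol"]
        pred = model.params["predictionCol"]
        return [
            {**row, pred: 1.0 if row[feat] > model.threshold else 0.0}
            for row in rows
        ]
```

For example, ToyModel(0.5).transform(rows, {"predictionCol": "pred"}) writes its predictions to a pred column, while the model itself still has predictionCol="prediction".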

admin
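Copyright notice: original article on this site, published by admin on 2017-08-16; 2,094 characters in total.
Reprint notice: unless otherwise stated, articles on this site are published under the CC 4.0 license. Please credit the source when reposting.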