This page covers the spark.ml (DataFrame-based) decision tree API, as opposed to the RDD-based spark.mllib decision tree API.
Input
Param name | Type(s) | Default | Description |
---|---|---|---|
labelCol | Double | "label" | Label to predict |
featuresCol | Vector | "features" | Feature vector |
Output
Param name | Type(s) | Default | Description | Notes |
---|---|---|---|---|
predictionCol | Double | "prediction" | Predicted label | |
rawPredictionCol | Vector | "rawPrediction" | Vector of length # classes, with the counts of training instance labels at the tree node which makes the prediction | Classification only |
probabilityCol | Vector | "probability" | Vector of length # classes equal to rawPrediction normalized to a multinomial distribution | Classification only |
varianceCol | Double | | Biased sample variance of the prediction | Regression only |
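The relationship between the three classification output columns can be sketched in plain Python. This is a minimal illustration, not Spark's implementation. It assumes rawPrediction holds the raw per-class label counts at the leaf, as the table describes, and it ignores any `thresholds` setting Spark may apply when choosing the predicted label. The helper names `raw_to_probability` and `predict_label` are hypothetical.

```python
from typing import List


def raw_to_probability(raw_prediction: List[float]) -> List[float]:
    """Normalize per-class label counts at a leaf into a probability vector.

    Hypothetical helper illustrating probabilityCol; not Spark's code.
    """
    total = sum(raw_prediction)
    if total == 0:
        # Assumption: a leaf with no counts yields an all-zero vector.
        return [0.0 for _ in raw_prediction]
    return [count / total for count in raw_prediction]


def predict_label(raw_prediction: List[float]) -> float:
    """Hypothetical helper: index of the class with the largest count (first one on ties)."""
    best = max(range(len(raw_prediction)), key=lambda i: raw_prediction[i])
    return float(best)


# Hypothetical leaf: 3 training rows of class 0 and 1 row of class 1.
leaf_counts = [3.0, 1.0]
print(raw_to_probability(leaf_counts))  # [0.75, 0.25]
print(predict_label(leaf_counts))       # 0.0
```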
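API
Decision tree classifier: class pyspark.ml.classification.DecisionTreeClassifier(self, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini")
maxDepth=5: the maximum depth of the tree is 5 (depth 0 means a single leaf node; depth 1 means one internal node plus two leaves).
maxBins=32: the maximum number of bins used to discretize continuous features; it must be at least 2 and at least the number of categories of any categorical feature.
impurity="gini": the criterion used to choose the split at each node. For classification, the only other supported value is "entropy"; criteria such as C4.5's gain ratio are not available.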
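To make the `impurity` option concrete, here is a sketch of the two supported formulas and the information gain of a binary split. Spark computes these quantities from aggregated label counts rather than raw label lists, and splits whose gain falls below `minInfoGain` are not considered. The function names are hypothetical, and this is not Spark's code. Entropy is measured in bits (log base 2).

```python
import math
from collections import Counter
from typing import Callable, Sequence


def gini(labels: Sequence[float]) -> float:
    """Gini impurity: 1 - sum_k p_k^2, where p_k is the share of class k."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def entropy(labels: Sequence[float]) -> float:
    """Entropy in bits: -sum_k p_k * log2(p_k)."""
    n = len(labels)
    if n == 0:
        return 0.0
    return -sum((c / n) * math.log2(c / n) for c in Counter(labels).values())


def info_gain(
    parent: Sequence[float],
    left: Sequence[float],
    right: Sequence[float],
    impurity: Callable[[Sequence[float]], float] = gini,
) -> float:
    """Parent impurity minus the size-weighted impurity of the two children."""
    n = len(parent)
    weighted = (len(left) / n) * impurity(left) + (len(right) / n) * impurity(right)
    return impurity(parent) - weighted


# Usage: a split that perfectly separates two equally frequent classes.
parent = [0.0, 0.0, 1.0, 1.0]
print(gini(parent), entropy(parent))              # 0.5 1.0
print(info_gain(parent, [0.0, 0.0], [1.0, 1.0]))  # 0.5
```

Both criteria reach zero on a pure node. Gini avoids the logarithm, which is one reason it is a common default.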
Code example

```python
>>> # Assumes a Spark 1.x PySpark shell where sqlContext already exists.
>>> # In Spark 2.0+, import Vectors from pyspark.ml.linalg and use spark.createDataFrame.
>>> from pyspark.mllib.linalg import Vectors
>>> from pyspark.ml.feature import StringIndexer
>>> from pyspark.ml.classification import DecisionTreeClassifier
>>> df = sqlContext.createDataFrame([
...     (1.0, Vectors.dense(1.0)),
...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> # Index the label column so the classifier gets label metadata.
>>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
>>> si_model = stringIndexer.fit(df)
>>> td = si_model.transform(df)
>>> dt = DecisionTreeClassifier(maxDepth=2, labelCol="indexed")
>>> model = dt.fit(td)
>>> model.numNodes
3
>>> model.depth
1
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> result = model.transform(test0).head()
>>> result.prediction
0.0
>>> result.probability
DenseVector([1.0, 0.0])
>>> result.rawPrediction
DenseVector([1.0, 0.0])
>>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
1.0
```
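The doctest reports depth 1 and 3 nodes even though maxDepth=2. A single split already separates the two training rows, so no further split adds information. Spark counts depth from 0, so a tree of depth d has at most 2^(d+1) - 1 nodes. Since every internal node has exactly two children, nodes = 2 * internal + 1. The hypothetical helpers below just check that arithmetic; they are not part of the PySpark API.

```python
def max_nodes(depth: int) -> int:
    """Most nodes a binary tree of this depth can have.

    Depth is counted from 0: depth 0 is a single leaf,
    depth 1 is one split plus two leaves.
    """
    if depth < 0:
        raise ValueError("depth must be >= 0")
    return 2 ** (depth + 1) - 1


def num_leaves(num_nodes: int) -> int:
    """Leaves in a tree whose internal nodes all have exactly two children.

    Such a tree has num_nodes = 2 * internal + 1, so leaves = internal + 1.
    """
    if num_nodes < 1 or num_nodes % 2 == 0:
        raise ValueError("expected a positive, odd node count")
    internal = (num_nodes - 1) // 2
    return internal + 1


print(max_nodes(1), num_leaves(3))  # 3 2
print(max_nodes(2))                 # 7
```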
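Decision tree model
class pyspark.ml.classification.DecisionTreeClassificationModel(java_model)
The model is obtained by fitting a DecisionTreeClassifier on training data.
Its main member function is transform, which predicts labels for the test samples and returns the result as a DataFrame: transform(dataset, params=None)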
Transforms the input dataset with optional parameters.
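Parameters: | dataset: input dataset, an instance of pyspark.sql.DataFrame; params: an optional param map that overrides embedded params |
---|---|
Returns: | transformed dataset |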
New in version 1.3.0.
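Conceptually, transform adds the output columns to each row by walking the fitted tree from the root to a leaf. The sketch below does this for plain Python lists. Several things are assumptions: the `Node` layout is hypothetical and is not Spark's internal class, the "feature <= threshold goes left" rule is assumed for continuous splits, and the threshold 0.5 was chosen only to mimic the depth-1, 3-node model from the doctest above. Spark runs the equivalent logic on DataFrames in a distributed way.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence


@dataclass
class Node:
    """Hypothetical tree node; a leaf has counts, an internal node has a split."""
    counts: Optional[List[float]] = None   # leaf: per-class training label counts
    feature: int = -1                      # internal: index of the split feature
    threshold: float = 0.0                 # internal: value <= threshold goes left
    left: Optional["Node"] = None
    right: Optional["Node"] = None


def predict_row(root: Node, features: Sequence[float]) -> dict:
    """Walk to a leaf and build rawPrediction, probability and prediction."""
    node = root
    while node.counts is None:
        node = node.left if features[node.feature] <= node.threshold else node.right
    total = sum(node.counts)
    probability = [c / total for c in node.counts]
    prediction = float(max(range(len(node.counts)), key=lambda i: node.counts[i]))
    return {"rawPrediction": node.counts, "probability": probability, "prediction": prediction}


def transform(root: Node, rows: Sequence[Sequence[float]]) -> List[dict]:
    """Row-by-row stand-in for model.transform: one output record per input row."""
    return [predict_row(root, row) for row in rows]


# Depth-1, 3-node tree shaped like the doctest model (threshold is assumed).
example_tree = Node(
    feature=0,
    threshold=0.5,
    left=Node(counts=[1.0, 0.0]),
    right=Node(counts=[0.0, 1.0]),
)

for record in transform(example_tree, [[-1.0], [1.0]]):
    print(record["prediction"], record["probability"])
```

With these assumptions, the feature value -1.0 lands in the left leaf (prediction 0.0) and 1.0 lands in the right leaf (prediction 1.0), which matches the doctest.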