机器学习gridsearchcv(网格搜索)和kfold validation(k折验证)

14,614次阅读
没有评论

网格搜索算法K折交叉验证法机器学习入门的时候遇到的重要的概念。

网格搜索算法是一种通过遍历给定的参数组合来优化模型表现的方法。

决策树为例,当我们确定了要使用决策树算法的时候,为了能够更好地拟合和预测,我们需要调整它的参数。在决策树算法中,我们通常选择的参数是决策树的最大深度

于是我们会给出一系列的最大深度的值,比如 {‘max_depth’: [1,2,3,4,5]},我们会尽可能包含最优最大深度。

不过,我们如何知道哪一个最大深度的模型是最好的呢?我们需要一种可靠的评分方法,对每个最大深度的决策树模型都进行评分,这其中非常经典的一种方法就是交叉验证,下面我们就以K折交叉验证为例,详细介绍它的算法过程。

首先我们先看一下数据集是如何分割的。我们拿到的原始数据集首先会按照一定的比例划分成训练集和测试集。比如下图,以8:2分割的数据集:

机器学习gridsearchcv(网格搜索)和kfold
训练集用来训练我们的模型,它的作用就像我们平时做的练习题;测试集用来评估我们训练好的模型表现如何,它的作用像我们做的高考题,这是要绝对保密不能提前被模型看到的。

因此,在K折交叉验证中,我们用到的数据是训练集中的所有数据。我们将训练集的所有数据平均划分成K份(通常选择K=10),取第K份作为验证集,它的作用就像我们用来估计高考分数的模拟题,余下的K-1份作为交叉验证的训练集。

对于我们最开始选择的决策树的5个最大深度 ,以 max_depth=1 为例,我们先用第2-10份数据作为训练集训练模型,用第1份数据作为验证集对这次训练的模型进行评分,得到第一个分数;然后重新构建一个 max_depth=1 的决策树,用第1和3-10份数据作为训练集训练模型,用第2份数据作为验证集对这次训练的模型进行评分,得到第二个分数……以此类推,最后构建一个 max_depth=1 的决策树用第1-9份数据作为训练集训练模型,用第10份数据作为验证集对这次训练的模型进行评分,得到第十个分数。于是对于 max_depth=1 的决策树模型,我们训练了10次,验证了10次,得到了10个验证分数,然后计算这10个验证分数的平均分数,就是 max_depth=1 的决策树模型的最终验证分数。

机器学习gridsearchcv(网格搜索)和kfold
对于 max_depth = 2,3,4,5 时,分别进行和 max_depth=1 相同的交叉验证过程,得到它们的最终验证分数。然后我们就可以对这5个最大深度的决策树的最终验证分数进行比较,分数最高的那一个就是最优最大深度,我们利用最优参数在全部训练集上训练一个新的模型,整个模型就是最优模型

下面提供一个简单的利用决策树预测乳腺癌的例子:

<span class="kn">from</span> <span class="nn">sklearn.model_selection</span> <span class="kn">import</span> <span class="n">GridSearchCV</span><span class="p">,</span> <span class="n">KFold</span><span class="p">,</span> <span class="n">train_test_split</span>
<span class="kn">from</span> <span class="nn">sklearn.metrics</span> <span class="kn">import</span> <span class="n">make_scorer</span><span class="p">,</span> <span class="n">accuracy_score</span>
<span class="kn">from</span> <span class="nn">sklearn.tree</span> <span class="kn">import</span> <span class="n">DecisionTreeClassifier</span>
<span class="kn">from</span> <span class="nn">sklearn.datasets</span> <span class="kn">import</span> <span class="n">load_breast_cancer</span>

<span class="n">data</span> <span class="o">=</span> <span class="n">load_breast_cancer</span><span class="p">()</span>

<span class="n">X_train</span><span class="p">,</span> <span class="n">X_test</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">y_test</span> <span class="o">=</span> <span class="n">train_test_split</span><span class="p">(</span>
    <span class="n">data</span><span class="p">[</span><span class="s1">'data'</span><span class="p">],</span> <span class="n">data</span><span class="p">[</span><span class="s1">'target'</span><span class="p">],</span> <span class="n">train_size</span><span class="o">=</span><span class="mf">0.8</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>

<span class="n">regressor</span> <span class="o">=</span> <span class="n">DecisionTreeClassifier</span><span class="p">(</span><span class="n">random_state</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">parameters</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'max_depth'</span><span class="p">:</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">6</span><span class="p">)}</span>
<span class="n">scoring_fnc</span> <span class="o">=</span> <span class="n">make_scorer</span><span class="p">(</span><span class="n">accuracy_score</span><span class="p">)</span>
<span class="n">kfold</span> <span class="o">=</span> <span class="n">KFold</span><span class="p">(</span><span class="n">n_splits</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>

<span class="n">grid</span> <span class="o">=</span> <span class="n">GridSearchCV</span><span class="p">(</span><span class="n">regressor</span><span class="p">,</span> <span class="n">parameters</span><span class="p">,</span> <span class="n">scoring_fnc</span><span class="p">,</span> <span class="n">cv</span><span class="o">=</span><span class="n">kfold</span><span class="p">)</span>
<span class="n">grid</span> <span class="o">=</span> <span class="n">grid</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">X_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">)</span>
<span class="n">reg</span> <span class="o">=</span> <span class="n">grid</span><span class="o">.</span><span class="n">best_estimator_</span>

<span class="k">print</span><span class="p">(</span><span class="s1">'best score: </span><span class="si">%f</span><span class="s1">'</span><span class="o">%</span><span class="n">grid</span><span class="o">.</span><span class="n">best_score_</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s1">'best parameters:'</span><span class="p">)</span>
<span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">parameters</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
    <span class="k">print</span><span class="p">(</span><span class="s1">'</span><span class="si">%s</span><span class="s1">: </span><span class="si">%d</span><span class="s1">'</span><span class="o">%</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">reg</span><span class="o">.</span><span class="n">get_params</span><span class="p">()[</span><span class="n">key</span><span class="p">]))</span>

<span class="k">print</span><span class="p">(</span><span class="s1">'test score: </span><span class="si">%f</span><span class="s1">'</span><span class="o">%</span><span class="n">reg</span><span class="o">.</span><span class="n">score</span><span class="p">(</span><span class="n">X_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">))</span>

<span class="kn">import</span> <span class="nn">pandas</span> <span class="kn">as</span> <span class="nn">pd</span>
<span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">(</span><span class="n">grid</span><span class="o">.</span><span class="n">cv_results_</span><span class="p">)</span><span class="o">.</span><span class="n">T</span>

直接用决策树得到的分数大约是92%,经过网格搜索优化以后,我们可以在测试集得到95.6%的准确率:

best score: 0.938462
best parameters:
max_depth: 4
test score: 0.956140

转载自https://zhuanlan.zhihu.com/p/25637642

admin
版权声明:本站原创文章,由admin2017-12-21发表,共计2033字。
转载提示:除特殊说明外本站文章皆由CC-4.0协议发布,转载请注明出处。
评论(没有评论)