added explanation about the GridSearchCV best_score_ attribute
rasbt committed Jan 20, 2016
1 parent a8cf7b7 commit 5c301d7
Showing 2 changed files with 329 additions and 16 deletions.
169 changes: 156 additions & 13 deletions code/bonus/svm_iris_pipeline_and_gridsearch.ipynb
@@ -42,10 +42,10 @@
"output_type": "stream",
"text": [
"Sebastian Raschka \n",
"Last updated: 11/30/2015 \n",
"Last updated: 01/20/2016 \n",
"\n",
"CPython 3.5.0\n",
"IPython 4.0.0\n",
"CPython 3.5.1\n",
"IPython 4.0.1\n",
"\n",
"numpy 1.10.1\n",
"pandas 0.17.1\n",
@@ -77,7 +77,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"[Parallel(n_jobs=-1)]: Done 40 out of 40 | elapsed: 0.1s finished\n"
"[Parallel(n_jobs=-1)]: Done 40 out of 40 | elapsed: 0.2s finished\n"
]
},
{
@@ -89,7 +89,7 @@
" max_iter=-1, probability=False, random_state=None, shrinking=True,\n",
" tol=0.001, verbose=False))]),\n",
" fit_params={}, iid=True, n_jobs=-1,\n",
" param_grid=[{'svc__C': [1, 10, 100, 1000], 'svc__kernel': ['rbf'], 'svc__gamma': [0.001, 0.0001]}],\n",
" param_grid=[{'svc__kernel': ['rbf'], 'svc__C': [1, 10, 100, 1000], 'svc__gamma': [0.001, 0.0001]}],\n",
" pre_dispatch='2*n_jobs', refit=True, scoring='accuracy', verbose=1)"
]
},
@@ -143,7 +143,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 3,
"metadata": {
"collapsed": false
},
@@ -153,7 +153,7 @@
"output_type": "stream",
"text": [
"Best GS Score 0.96\n",
"best GS Params {'svc__C': 100, 'svc__kernel': 'rbf', 'svc__gamma': 0.001}\n",
"best GS Params {'svc__kernel': 'rbf', 'svc__C': 100, 'svc__gamma': 0.001}\n",
"\n",
"Train Accuracy: 0.97\n",
"\n",
@@ -162,7 +162,6 @@
}
],
"source": [
"\n",
"print('Best GS Score %.2f' % gs.best_score_)\n",
"print('best GS Params %s' % gs.best_params_)\n",
"\n",
@@ -179,13 +178,157 @@
]
},
{
"cell_type": "code",
"execution_count": null,
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
"source": [
"### A Note about `GridSearchCV`'s `best_score_` attribute"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Please note that `gs.best_score_` is the average k-fold cross-validation score. That is, if we have a `GridSearchCV` object with 5-fold cross-validation (like the one above), the `best_score_` attribute returns the average score of the best model over the 5 folds. To illustrate this with an example:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"array([ 0.6, 0.4, 0.6, 0.2, 0.6])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.cross_validation import StratifiedKFold, cross_val_score\n",
"from sklearn.linear_model import LogisticRegression\n",
"import numpy as np\n",
"\n",
"np.random.seed(0)\n",
"np.set_printoptions(precision=6)\n",
"y = [np.random.randint(3) for i in range(25)]\n",
"X = (y + np.random.randn(25)).reshape(-1, 1)\n",
"\n",
"cv5_idx = list(StratifiedKFold(y, n_folds=5, shuffle=False, random_state=0))\n",
"cross_val_score(LogisticRegression(random_state=123), X, y, cv=cv5_idx)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"By executing the code above, we created a simple data set of random integers that represent our class labels, together with a single noisy feature derived from them. Next, we fed the indices of the 5 cross-validation folds (`cv5_idx`) to the `cross_val_score` scorer, which returned 5 accuracy scores, one for each of the 5 test folds. \n",
"\n",
"Next, let us use the `GridSearchCV` object and feed it the same 5 cross-validation sets (via the pre-generated `cv5_idx` indices):"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 5 folds for each of 1 candidates, totalling 5 fits\n",
"[CV] ................................................................\n",
"[CV] ....................................... , score=0.600000 - 0.0s\n",
"[CV] ................................................................\n",
"[CV] ....................................... , score=0.400000 - 0.0s\n",
"[CV] ................................................................\n",
"[CV] ....................................... , score=0.600000 - 0.0s\n",
"[CV] ................................................................\n",
"[CV] ....................................... , score=0.200000 - 0.0s\n",
"[CV] ................................................................\n",
"[CV] ....................................... , score=0.600000 - 0.0s\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Parallel(n_jobs=1)]: Done 5 out of 5 | elapsed: 0.0s finished\n"
]
}
],
"source": [
"from sklearn.grid_search import GridSearchCV\n",
"gs = GridSearchCV(LogisticRegression(), {}, cv=cv5_idx, verbose=3).fit(X, y) "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As we can see, the scores for the 5 folds are exactly the same as the ones from `cross_val_score` earlier. \n",
"Now, the `best_score_` attribute of the `GridSearchCV` object, which becomes available after fitting, returns the average accuracy score of the best model:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"0.47999999999999998"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"gs.best_score_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As we can see, the result above is consistent with the average score computed via `cross_val_score`:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"0.47999999999999998"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cross_val_score(LogisticRegression(), X, y, cv=cv5_idx).mean()"
]
}
],
"metadata": {
@@ -204,7 +347,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.0"
"version": "3.5.1"
}
},
"nbformat": 4,
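
The notebook cells added in this commit import from `sklearn.cross_validation` and `sklearn.grid_search`, which later scikit-learn releases replaced with `sklearn.model_selection`. The sketch below repeats the same check (that `best_score_` equals the mean of the per-fold scores) against the newer module. It is not part of the commit: the balanced 30-sample toy data and the single-value grid `{'C': [1.0]}` are assumptions chosen for illustration, so the scores will differ from the notebook's output.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, StratifiedKFold, cross_val_score

# Hypothetical toy data: 3 balanced classes and one noisy feature.
rng = np.random.RandomState(0)
y = np.repeat([0, 1, 2], 10)
X = (y + rng.randn(30)).reshape(-1, 1)

# Pre-generate the fold indices once so both calls see identical splits.
cv5_idx = list(StratifiedKFold(n_splits=5, shuffle=False).split(X, y))

# Per-fold accuracies from cross_val_score.
scores = cross_val_score(LogisticRegression(), X, y, cv=cv5_idx)

# A grid with a single candidate (C=1.0 is LogisticRegression's default),
# so the "best" model is the same model scored above.
gs = GridSearchCV(LogisticRegression(), {'C': [1.0]}, cv=cv5_idx,
                  scoring='accuracy').fit(X, y)

# Per-fold test scores of the best candidate, read from cv_results_.
fold_scores = np.array([gs.cv_results_['split%d_test_score' % i][gs.best_index_]
                        for i in range(5)])

print('cross_val_score folds:', scores)
print('GridSearchCV folds:   ', fold_scores)
print('mean of folds:        ', scores.mean())
print('gs.best_score_:       ', gs.best_score_)
```

Passing the same list of `(train, test)` index pairs to both calls is what makes the comparison meaningful; with separately constructed splitters the folds, and therefore the scores, could differ.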