Introduction to Scikit-learn
Chapter 8: Cross-validation
Cross-validation
When you train your models, there is a risk of overfitting to your test dataset. For example, you might keep tweaking your parameters until you achieve optimal performance on your test set.
Is that really a good thing? Just because your model achieves superior performance on one test set does not mean that it will perform as well on another. You might simply have hit the jackpot with your tweaking on this particular test set, and the model may perform miserably on others. It is better for your model to be robust enough to generalise to new, unseen data (remember the bias-variance trade-off!).
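To see how much a single test set can mislead, here is a minimal sketch that evaluates the same model on five different random train/test splits. It assumes the iris dataset and a k-nearest-neighbours classifier as stand-ins for your own data and model. The test accuracy you get can vary from one split to the next.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# Assumption: iris is used only as a convenient example dataset
x, y = load_iris(return_X_y=True)

scores = []
for seed in range(5):
    # A different random_state gives a different train/test split
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=0.2, random_state=seed
    )
    model = KNeighborsClassifier(n_neighbors=5)
    model.fit(x_train, y_train)
    # Accuracy on this particular test set
    scores.append(model.score(x_test, y_test))

print(scores)
print(f"min={min(scores):.3f}, max={max(scores):.3f}")
```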
You will learn more about this topic in Week 4 of the Introduction to Machine Learning course when discussing Machine Learning Evaluation, where you will be presented with several ways to deal with this issue.
One of these methods is called K-fold cross-validation. To put it simply:
- You divide your dataset into K non-overlapping subsets.
- You will then perform K separate experiments:
    - In each experiment, you keep one of the K subsets as your test data, and use the remaining K-1 subsets for training.
- You will end up with K scores, one for each of your K subsets.
In summary, you are essentially testing on K different test datasets to check how well your model generalises, rather than relying on a single test set.
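The procedure above can be written out by hand in a few lines. This is a minimal illustration, not scikit-learn's own implementation. `manual_k_fold_scores` is a hypothetical helper, and the iris data and k-NN classifier are assumptions chosen only to make the sketch runnable.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier


def manual_k_fold_scores(model, x, y, k=5, seed=0):
    """Hypothetical helper: K-fold cross-validation written by hand."""
    rng = np.random.default_rng(seed)
    indices = rng.permutation(len(x))      # shuffle the sample indices once
    folds = np.array_split(indices, k)     # K non-overlapping subsets
    scores = []
    for i in range(k):
        test_idx = folds[i]                # fold i is the test set
        # The remaining K-1 folds form the training set
        train_idx = np.concatenate([folds[j] for j in range(k) if j != i])
        model.fit(x[train_idx], y[train_idx])   # refit from scratch each time
        scores.append(model.score(x[test_idx], y[test_idx]))
    return scores                          # one score per fold


x, y = load_iris(return_X_y=True)
scores = manual_k_fold_scores(KNeighborsClassifier(), x, y, k=5)
print(len(scores), scores)
```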
As you might guess, scikit-learn gives you a function to perform cross-validation, so you do not have to implement it yourself. Imagine having to figure out how to split your dataset evenly but randomly, making sure that the subsets do not overlap, and picking the correct subset inside multiple nested loops! Real fun! (I have actually done this myself, back when scikit-learn did not quite exist yet.)
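If you only need the splitting step, scikit-learn's `KFold` splitter handles the shuffling and the non-overlapping partitioning for you. Here is a small sketch on a hypothetical toy array of 10 samples:

```python
import numpy as np
from sklearn.model_selection import KFold

x = np.arange(20).reshape(10, 2)  # hypothetical toy data: 10 samples, 2 features

kfold = KFold(n_splits=5, shuffle=True, random_state=42)
all_test_indices = []
for fold, (train_idx, test_idx) in enumerate(kfold.split(x)):
    # Within a fold, the training and test indices never overlap
    assert set(train_idx).isdisjoint(test_idx)
    all_test_indices.extend(test_idx.tolist())
    print(f"fold {fold}: test={test_idx}")

# Across folds, every sample lands in exactly one test set
print(sorted(all_test_indices))
```

Each sample appears in exactly one test fold. That guarantee is the tedious part to get right by hand.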
Let’s say we want to perform 5-fold cross-validation with our pipeline from the previous page. Here is how you do it with scikit-learn.
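Below is a minimal sketch of 5-fold cross-validation using `sklearn.model_selection.cross_validate`. The pipeline (a `StandardScaler` followed by a `KNeighborsClassifier`) and the iris dataset are assumptions standing in for the pipeline and data from the previous page. Substitute your own.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_validate
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

x, y = load_iris(return_X_y=True)

# Hypothetical stand-in for the pipeline from the previous page
pipeline = Pipeline([
    ("scaler", StandardScaler()),
    ("knn", KNeighborsClassifier(n_neighbors=5)),
])

# cv=5: the pipeline is fitted and scored 5 times, once per fold
results_dict = cross_validate(pipeline, x, y, cv=5)
print(results_dict["test_score"])
```

With an integer `cv` and a classifier, `cross_validate` uses stratified K-fold splitting without shuffling by default. Each fold therefore keeps roughly the same class proportions, and the splits are the same across calls.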
The `pipeline` you pass in for cross-validation can also be any classifier (like our `knn_classifier` or `dt_classifier` from earlier). `results_dict` holds a `dict` of the results, containing some statistics and also the test score for each “fold”.
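The next sketch illustrates both points. It passes a plain decision tree instead of a pipeline and then summarises `results_dict`. `DecisionTreeClassifier(random_state=0)` is a hypothetical stand-in for the `dt_classifier` from earlier, and the iris data is again an assumption.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_validate
from sklearn.tree import DecisionTreeClassifier

x, y = load_iris(return_X_y=True)

# Hypothetical stand-in for dt_classifier from an earlier chapter
dt_classifier = DecisionTreeClassifier(random_state=0)
results_dict = cross_validate(dt_classifier, x, y, cv=5)

# Keys include 'fit_time', 'score_time' and 'test_score' (one entry per fold)
for key, values in results_dict.items():
    print(key, values)

test_scores = results_dict["test_score"]
print(f"mean accuracy = {test_scores.mean():.3f} (std {test_scores.std():.3f})")
```

Report the mean together with the standard deviation of the K test scores. The mean gives typical performance, and the standard deviation shows how much that performance varies across folds.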