Cross Validation in Machine Learning

Last Updated : 10 May, 2025

Cross-validation is a technique used to check how well a machine learning model performs on unseen data. It splits the data into several parts, trains the model on some parts, and tests it on the remaining part, repeating this process multiple times. Finally, the results from each validation step are averaged to produce a more reliable estimate of the model's performance.

The main purpose of cross-validation is to detect and guard against overfitting. If you want to make sure your machine learning model is not just memorizing the training data but can generalize to real-world data, cross-validation is a commonly used technique.

Types of Cross-Validation

There are several types of cross-validation techniques:

1. Holdout Validation

In Holdout Validation we train on one part of the dataset and test on the rest. In the example here, 50% of the data is used for training and the remaining 50% for testing (other ratios such as 70/30 or 80/20 are also common). It is a simple and quick way to evaluate a model. The major drawback is that, with only 50% of the data used for training, the held-out 50% may contain important information that the model never sees, which can lead to higher bias.
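
As a rough illustration of the 50/50 holdout split described above, the following minimal sketch uses scikit-learn's `train_test_split`. The Iris dataset, the linear `SVC` model, the `random_state` value, and the `stratify=y` option are assumptions chosen for this example, not part of the method itself.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)

# Hold out 50% of the samples for testing, as in the description above.
# stratify=y keeps the class proportions equal in both halves (an assumption
# added for this sketch; plain holdout does not require it).
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.5, random_state=42, stratify=y
)

# Train once on the training half and evaluate once on the held-out half.
model = SVC(kernel="linear")
model.fit(X_train, y_train)
holdout_accuracy = model.score(X_test, y_test)
print(f"Holdout accuracy: {holdout_accuracy * 100:.2f}%")
```

Because the model is fit only once, this is fast, but the reported accuracy depends on which samples happened to land in the test half.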

2. LOOCV (Leave One Out Cross Validation)

In this method we train on the whole dataset except for a single data point, test on that point, and repeat the process for every data point. In other words, the model is trained on n-1 samples and tested on the one omitted sample, n times in total. This approach has both advantages and disadvantages:

  • An advantage of this method is that almost every data point is used for training in each iteration, so the estimate has low bias.
  • The major drawback is higher variance in the test results, because each test is made against a single data point. If that data point is an outlier, the result can vary widely.
  • Another drawback is long execution time, since the model is retrained once for every data point in the dataset.
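
The procedure above can be sketched with scikit-learn's `LeaveOneOut` splitter. The Iris dataset and the linear `SVC` model are assumptions for illustration only.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import LeaveOneOut, cross_val_score
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)

# LeaveOneOut yields one split per sample: n-1 samples to train, 1 to test.
loo = LeaveOneOut()
loo_scores = cross_val_score(SVC(kernel="linear"), X, y, cv=loo)

# Each score is the accuracy on a single sample, so it is either 0.0 or 1.0.
print(f"Number of model fits: {len(loo_scores)}")
print(f"LOOCV mean accuracy: {loo_scores.mean() * 100:.2f}%")
```

The number of model fits equals the number of samples, which is why LOOCV becomes slow on larger datasets.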

3. Stratified Cross-Validation

Stratified cross-validation is a technique used in machine learning to ensure that each fold of the cross-validation process maintains the same class distribution as the entire dataset. This is particularly important when dealing with imbalanced datasets where certain classes may be underrepresented. In this method:

  • The dataset is divided into k folds while maintaining the proportion of classes in each fold.
  • During each iteration, one fold is used for testing and the remaining folds are used for training.
  • The process is repeated k times with each fold serving as the test set exactly once.

Stratified Cross-Validation is essential when dealing with classification problems where maintaining the balance of class distribution is crucial for the model to generalize well to unseen data.
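
To see what "maintaining the class distribution" means in practice, the sketch below counts the classes in each test fold produced by scikit-learn's `StratifiedKFold`. The Iris dataset (50 samples per class), the choice of 5 folds, and the `random_state` value are assumptions for this example.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import StratifiedKFold

X, y = load_iris(return_X_y=True)

# Stratified splitting needs the labels y to balance classes across folds.
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

fold_class_counts = []
for fold, (train_idx, test_idx) in enumerate(skf.split(X, y), start=1):
    # Count how many samples of each of the 3 classes fall in this test fold.
    counts = np.bincount(y[test_idx], minlength=3)
    fold_class_counts.append(counts.tolist())
    print(f"Fold {fold}: test-set class counts = {counts.tolist()}")
```

With 50 samples per class and 5 folds, every test fold receives 10 samples of each class, mirroring the balance of the full dataset.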

4. K-Fold Cross Validation

In K-Fold Cross Validation we split the dataset into k subsets known as folds. We train on k-1 of the folds and hold out the remaining fold to evaluate the trained model. We iterate k times, reserving a different fold for testing each time.

Note: A value of k = 10 is commonly suggested (k = 5 is also widely used). A lower value of k moves the procedure toward simple holdout validation, while a higher value of k approaches the LOOCV method.
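
The effect of k can be explored with a small sketch like the one below, which runs k-fold cross-validation for a few values of k. The values 2, 5 and 10, the Iris dataset, and the linear `SVC` model are assumptions chosen for illustration; the exact scores will differ for other data and models.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import KFold, cross_val_score
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)

results = {}
for k in (2, 5, 10):
    kf = KFold(n_splits=k, shuffle=True, random_state=42)
    # One score per fold: the model is fit k times.
    scores = cross_val_score(SVC(kernel="linear"), X, y, cv=kf)
    results[k] = scores
    print(
        f"k={k:>2}: fits={len(scores)}, "
        f"test fold size={len(X) // k}, "
        f"mean accuracy={scores.mean() * 100:.2f}%"
    )
```

Larger k means more model fits and smaller test folds, which is the trade-off the note above refers to.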

Example of K Fold Cross Validation

The table below shows an example of the training subsets and evaluation subsets generated in k-fold cross-validation with k = 5. Here we have a total of 25 instances, indexed 0 to 24. In the first iteration we use the first 20 percent of the data for evaluation and the remaining 80 percent for training, i.e. [0-4] for testing and [5-24] for training. In the second iteration we use the second subset of 20 percent for evaluation and the remaining four subsets for training, i.e. [5-9] for testing and [0-4 and 10-24] for training, and so on.

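| Iteration | Training Set Observations | Testing Set Observations |
|-----------|---------------------------|--------------------------|
| 1         | [5-24]                    | [0-4]                    |
| 2         | [0-4, 10-24]              | [5-9]                    |
| 3         | [0-9, 15-24]              | [10-14]                  |
| 4         | [0-14, 20-24]             | [15-19]                  |
| 5         | [0-19]                    | [20-24]                  |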

Each iteration uses different subsets for testing and training, ensuring that all data points are used for both training and testing.
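
The split pattern in this example can be reproduced with scikit-learn's `KFold`. The sketch below assumes 25 placeholder observations indexed 0 to 24 and disables shuffling so that the folds are consecutive blocks.

```python
import numpy as np
from sklearn.model_selection import KFold

# 25 placeholder observations, indexed 0..24.
data = np.arange(25)

# shuffle=False keeps the folds as consecutive blocks of 5 indices.
kf = KFold(n_splits=5, shuffle=False)

splits = []
for iteration, (train_idx, test_idx) in enumerate(kf.split(data), start=1):
    splits.append((train_idx.tolist(), test_idx.tolist()))
    print(
        f"Iteration {iteration}: "
        f"test = [{test_idx[0]}-{test_idx[-1]}], "
        f"train size = {len(train_idx)}"
    )
```

Each iteration tests on 5 observations and trains on the other 20, and every observation appears in a test set exactly once.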

Comparison between K-Fold Cross-Validation and Hold Out Method

K-Fold Cross-Validation and the Hold-Out Method are both widely used techniques and are sometimes confused with each other, so here is a quick comparison:

| Feature | K-Fold Cross-Validation | Hold-Out Method |
|---------|-------------------------|-----------------|
| Definition | The dataset is divided into 'k' subsets (folds). Each fold gets a turn to be the test set while the others are used for training. | The dataset is split into two sets: one for training and one for testing. |
| Training Sets | The model is trained 'k' times, each time on a different training subset. | The model is trained once on the training set. |
| Testing Sets | The model is tested 'k' times, each time on a different test subset. | The model is tested once on the test set. |
| Bias | Less biased due to multiple splits and testing. | Can have higher bias due to a single split. |
| Variance | Lower variance, as it tests on multiple splits. | Higher variance, as results depend on the single split. |
| Computation Cost | High, as the model is trained and tested 'k' times. | Low, as the model is trained and tested only once. |
| Use in Model Selection | Better for tuning and evaluating model performance due to reduced bias. | Less reliable for model selection, as it might give inconsistent results. |
| Data Utilization | The entire dataset is used for both training and testing. | Only a portion of the data is used for testing, so some data is not used for validation. |
| Suitability for Small Datasets | Preferred for small datasets, as it maximizes data usage. | Less ideal for small datasets, as a significant portion is held out for testing. |
| Risk of Overfitting | Less prone to overfitting due to multiple training and testing cycles. | Higher risk of overfitting as the model is trained on one set. |
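
One way to get a feel for the variance row of this comparison is to repeat both procedures with different random seeds and look at how much the reported accuracy moves. The sketch below is only an illustration under assumed settings (Iris data, linear `SVC`, an 80/20 hold-out split, 5 folds, 10 seeds); the spread of single hold-out scores is often larger than the spread of k-fold means, but that is not guaranteed for every dataset.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import KFold, cross_val_score, train_test_split
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)

holdout_scores = []
kfold_means = []
for seed in range(10):
    # Hold-out: one 80/20 split, one fit, one score.
    X_tr, X_te, y_tr, y_te = train_test_split(
        X, y, test_size=0.2, random_state=seed
    )
    holdout_scores.append(SVC(kernel="linear").fit(X_tr, y_tr).score(X_te, y_te))

    # K-fold: five fits, averaged into a single estimate.
    kf = KFold(n_splits=5, shuffle=True, random_state=seed)
    kfold_means.append(cross_val_score(SVC(kernel="linear"), X, y, cv=kf).mean())

print(f"Hold-out std across seeds:    {np.std(holdout_scores):.4f}")
print(f"K-fold mean std across seeds: {np.std(kfold_means):.4f}")
```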

Advantages and Disadvantages of Cross Validation

Advantages:

  1. Overcoming Overfitting: Cross validation helps to detect overfitting by providing a more robust estimate of the model's performance on unseen data.
  2. Model Selection: Cross validation is used to compare different models and select the one that performs the best on average.
  3. Hyperparameter tuning: Cross validation is used to optimize the hyperparameters of a model, such as the regularization parameter, by selecting the values that result in the best average performance on the validation folds.
  4. Data Efficient: It allows all the available data to be used for both training and validation, making it more data-efficient than a single train/validation split.
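
The hyperparameter tuning use case can be sketched with scikit-learn's `GridSearchCV`, which runs cross-validation for every candidate value. The candidate values for the regularization parameter `C`, the Iris dataset, and the linear `SVC` model are assumptions picked for this example.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, KFold
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)

# Hypothetical candidate values for the SVC regularization parameter C.
param_grid = {"C": [0.01, 0.1, 1, 10]}

# Each candidate is scored with 5-fold cross-validation on the same folds.
kf = KFold(n_splits=5, shuffle=True, random_state=42)
search = GridSearchCV(SVC(kernel="linear"), param_grid, cv=kf)
search.fit(X, y)

print(f"Best parameters: {search.best_params_}")
print(f"Best mean CV accuracy: {search.best_score_ * 100:.2f}%")
```

Note that the cost multiplies: four candidate values with five folds means twenty model fits, plus one final refit on all the data.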

Disadvantages:

  1. Computationally Expensive: It can be computationally expensive especially when the number of folds is large or when the model is complex and requires a long time to train.
  2. Time-Consuming: It can be time-consuming especially when there are many hyperparameters to tune or when multiple models need to be compared.
  3. Bias-Variance Tradeoff: The choice of the number of folds can affect the bias-variance tradeoff, i.e., too few folds may result in high bias while too many folds may result in high variance.

Python implementation for k fold cross-validation

Step 1: Importing necessary libraries

We will import the required functions and classes from scikit-learn.

Python
from sklearn.model_selection import cross_val_score, KFold
from sklearn.svm import SVC
from sklearn.datasets import load_iris

Step 2: Loading the dataset

Let's use the Iris dataset, a built-in multi-class classification dataset.

Python
iris = load_iris()
X, y = iris.data, iris.target

Step 3: Creating SVM classifier

SVC is a Support Vector Classification model from scikit-learn.

Python
svm_classifier = SVC(kernel='linear')

Step 4: Defining the number of folds for cross-validation

Here we will be using 5 folds.

Python
num_folds = 5
kf = KFold(n_splits=num_folds, shuffle=True, random_state=42)

Step 5: Performing k-fold cross-validation

Python
cross_val_results = cross_val_score(svm_classifier, X, y, cv=kf)
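
To make this step less of a black box, the sketch below writes out a manual loop that is roughly equivalent to what `cross_val_score` does for a classifier with the default scoring: clone the estimator, fit it on the training folds, and score it on the held-out fold. It is a simplified illustration, not scikit-learn's actual implementation, and the variable names are chosen for this example.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import load_iris
from sklearn.model_selection import KFold, cross_val_score
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)
estimator = SVC(kernel="linear")
kf = KFold(n_splits=5, shuffle=True, random_state=42)

manual_scores = []
for train_idx, test_idx in kf.split(X):
    # Use a fresh, unfitted copy so no state leaks between folds.
    model = clone(estimator)
    model.fit(X[train_idx], y[train_idx])
    # For classifiers, score() returns accuracy on the held-out fold.
    manual_scores.append(model.score(X[test_idx], y[test_idx]))
manual_scores = np.array(manual_scores)

# The library call should give the same per-fold accuracies.
library_scores = cross_val_score(estimator, X, y, cv=kf)
print("Manual: ", np.round(manual_scores, 4))
print("Library:", np.round(library_scores, 4))
```

Because `random_state` is fixed, the splitter produces the same folds on both passes, so the two sets of scores can be compared directly.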

Step 6: Evaluation metrics

Python
print("Cross-Validation Results (Accuracy):")
for i, result in enumerate(cross_val_results, 1):
    print(f"  Fold {i}: {result * 100:.2f}%")

print(f"Mean Accuracy: {cross_val_results.mean() * 100:.2f}%")

Output:

[Image: Cross validation accuracy]

The output shows the accuracy score from each of the 5 folds in the K-fold cross-validation process. The mean accuracy, the average of these individual scores, is approximately 97.33%, which summarizes the model's overall performance across all the folds.

