Cross Validation in Machine Learning
Last Updated: 10 May, 2025
Cross-validation is a technique used to check how well a machine learning model performs on unseen data. It splits the data into several parts, trains the model on some parts and tests it on the remaining part, and repeats this process multiple times. Finally, the results from each validation step are averaged to produce a more reliable estimate of the model's performance.
The main purpose of cross-validation is to detect overfitting. If you want to make sure your machine learning model is not just memorizing the training data but can generalize to real-world data, cross-validation is a commonly used technique.
Types of Cross-Validation
There are several types of cross validation techniques which are as follows:
1. Holdout Validation
In Holdout Validation, we train on one part of the dataset and test on the rest. In the example here, 50% of the dataset is used for training and the remaining 50% for testing (70/30 and 80/20 splits are also common). It is a simple and quick way to evaluate a model. The major drawback is that the model never sees the held-out 50% during training, so it may miss important information contained in that half, which can lead to higher bias.
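A minimal sketch of holdout validation with scikit-learn's `train_test_split` follows. The iris dataset, the linear SVC, the use of `stratify`, and the `random_state` value are assumptions chosen for illustration, not part of the method itself.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)

# Hold out half of the data for testing (the 50/50 split described above).
# stratify=y keeps class proportions equal in both halves (an assumption for
# this sketch); random_state fixes the shuffle so the split is reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.5, stratify=y, random_state=0
)

model = SVC(kernel="linear")
model.fit(X_train, y_train)           # train once on the training half
holdout_accuracy = model.score(X_test, y_test)  # test once on the held-out half
print(f"Holdout accuracy: {holdout_accuracy:.3f}")
```

Changing `random_state` changes which samples land in each half, so the reported accuracy can move from run to run.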
2. LOOCV (Leave One Out Cross Validation)
In this method, we train on the whole dataset except for a single data point, test on that point, and repeat for every data point. In other words, the model is trained on n-1 samples and tested on the one omitted sample, n times in total. This method has both advantages and disadvantages:
- An advantage is that nearly all data points are used for training in every iteration, so the estimate has low bias.
- The major drawback is higher variance in the performance estimate, because each test uses a single data point. If that point is an outlier, the result for that iteration can be badly skewed.
- Another drawback is execution time, since the model must be trained once per data point.
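LOOCV can be sketched with scikit-learn's `LeaveOneOut` splitter. The dataset and classifier below are illustrative assumptions; the point is that the number of fits equals the number of samples.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import LeaveOneOut, cross_val_score
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)

loo = LeaveOneOut()
# One fit per sample: each score is 1.0 (correct) or 0.0 (wrong)
# for the single held-out point.
loo_scores = cross_val_score(SVC(kernel="linear"), X, y, cv=loo)

print(f"Number of fits: {len(loo_scores)}")
print(f"LOOCV accuracy: {loo_scores.mean():.3f}")
```

On a dataset this small the loop finishes quickly, but the cost grows linearly with the number of samples, which is the execution-time drawback noted above.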
3. Stratified Cross-Validation
Stratified cross-validation ensures that each fold of the cross-validation process maintains the same class distribution as the entire dataset. This is particularly important when dealing with imbalanced datasets where certain classes may be underrepresented. In this method:
- The dataset is divided into k folds while maintaining the proportion of classes in each fold.
- During each iteration, one fold is used for testing and the remaining folds are used for training.
- The process is repeated k times with each fold serving as the test set exactly once.
Stratified Cross-Validation is essential when dealing with classification problems where maintaining the balance of class distribution is crucial for the model to generalize well to unseen data.
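The effect of stratification can be checked with scikit-learn's `StratifiedKFold`. The toy labels below (90 samples of one class, 10 of another) are an assumption made up for this sketch.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Toy imbalanced labels (illustrative assumption): 90 of class 0, 10 of class 1.
y = np.array([0] * 90 + [1] * 10)
X = np.zeros((len(y), 1))  # features do not influence how the folds are formed

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
fold_class_counts = []
for fold, (train_idx, test_idx) in enumerate(skf.split(X, y), start=1):
    # Count how many samples of each class landed in this test fold.
    counts = np.bincount(y[test_idx], minlength=2)
    fold_class_counts.append(counts.tolist())
    print(f"Fold {fold}: class 0 = {counts[0]}, class 1 = {counts[1]}")
```

Every test fold keeps the 9:1 class ratio of the full label set. A plain shuffled `KFold` gives no such guarantee, and a fold could end up with very few minority samples.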
4. K-Fold Cross Validation
In K-Fold Cross Validation, we split the dataset into k subsets known as folds. We then train on k-1 of the subsets and leave the remaining one for evaluating the trained model. We iterate k times, with a different subset reserved for testing each time.
Note: k = 10 is a commonly suggested value, and k = 5 is also widely used. A very low k moves towards simple holdout validation, while a very high k approaches the LOOCV method.
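To make the mechanics concrete, here is a from-scratch sketch of how the k train/test index splits can be generated. The helper name `k_fold_indices` is hypothetical and written only for this illustration; in practice scikit-learn's `KFold` does this job.

```python
import numpy as np

def k_fold_indices(n_samples, k):
    """Yield (train_idx, test_idx) pairs; each fold is the test set exactly once.

    Hypothetical helper for illustration, not a library function.
    """
    indices = np.arange(n_samples)
    folds = np.array_split(indices, k)  # k folds of nearly equal size
    for i in range(k):
        test_idx = folds[i]
        # Training data is every fold except the i-th one (k-1 folds).
        train_idx = np.concatenate([folds[j] for j in range(k) if j != i])
        yield train_idx, test_idx

for train_idx, test_idx in k_fold_indices(n_samples=10, k=5):
    print("train:", train_idx, "test:", test_idx)
```

This sketch does not shuffle the data. Shuffling before splitting is usually advisable when the rows are ordered, for example sorted by class.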
Example of K Fold Cross Validation
The table below shows an example of the training subsets and evaluation subsets generated in k-fold cross-validation. Here we have 25 instances in total, indexed 0 to 24, and k = 5. In the first iteration we use the first 20 percent of the data for evaluation and the remaining 80 percent for training, that is [0-4] for testing and [5-24] for training. In the second iteration we use the second subset of 20 percent for evaluation and the remaining four subsets for training, that is [5-9] for testing and [0-4, 10-24] for training, and so on.
Iteration | Training Set Observations | Testing Set Observations |
---|---|---|
1 | [5-24] | [0-4] |
2 | [0-4, 10-24] | [5-9] |
3 | [0-9, 15-24] | [10-14] |
4 | [0-14, 20-24] | [15-19] |
5 | [0-19] | [20-24] |
Each iteration uses different subsets for testing and training, ensuring that all data points are used for both training and testing.
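The same index ranges can be reproduced with scikit-learn's `KFold`, assuming the 25 instances are indexed 0 to 24 and no shuffling is applied.

```python
import numpy as np
from sklearn.model_selection import KFold

data = np.arange(25)  # 25 instances, indexed 0 to 24 as in the table

# shuffle=False (the default) keeps each fold as a contiguous block.
kf = KFold(n_splits=5, shuffle=False)
splits = []
for iteration, (train_idx, test_idx) in enumerate(kf.split(data), start=1):
    splits.append((train_idx.tolist(), test_idx.tolist()))
    print(f"Iteration {iteration}: test = [{test_idx[0]}-{test_idx[-1]}], "
          f"train size = {len(train_idx)}")
```

With `shuffle=True` the folds would no longer be contiguous ranges, but each instance would still appear in exactly one test fold.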
Comparison between K-Fold Cross-Validation and Hold Out Method
K-Fold Cross-Validation and the Hold-Out Method are both widely used techniques and are sometimes confused, so here is a quick comparison between them:
Feature | K-Fold Cross-Validation | Hold-Out Method |
---|---|---|
Definition | The dataset is divided into 'k' subsets (folds). Each fold gets a turn to be the test set while the others are used for training. | The dataset is split into two sets: one for training and one for testing. |
Training Sets | The model is trained 'k' times, each time on a different training subset. | The model is trained once on the training set. |
Testing Sets | The model is tested 'k' times, each time on a different test subset. | The model is tested once on the test set. |
Bias | Less biased due to multiple splits and testing. | Can have higher bias due to a single split. |
Variance | Lower variance, as it tests on multiple splits. | Higher variance, as results depend on the single split. |
Computation Cost | High, as the model is trained and tested 'k' times. | Low, as the model is trained and tested only once. |
Use in Model Selection | Better for tuning and evaluating model performance due to reduced bias. | Less reliable for model selection, as it might give inconsistent results. |
Data Utilization | The entire dataset is used for both training and testing. | Only a portion of the data is used for testing, so some data is not used for validation. |
Suitability for Small Datasets | Preferred for small datasets, as it maximizes data usage. | Less ideal for small datasets, as a significant portion is held out for testing. |
Risk of Overfitting | Less prone to overfitting due to multiple training and testing cycles. | Higher risk of overfitting as the model is trained on one set. |
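The variance row of the comparison can be explored empirically. The sketch below repeats a hold-out split with ten different seeds and sets the spread of scores next to a single 5-fold run. The dataset, model, split size, and seeds are all assumptions for illustration, and the numbers it prints will depend on them.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import KFold, cross_val_score, train_test_split
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)

# Hold-out: one split per seed, so each score depends on the split we drew.
holdout_scores = []
for seed in range(10):
    X_tr, X_te, y_tr, y_te = train_test_split(
        X, y, test_size=0.2, random_state=seed
    )
    holdout_scores.append(SVC(kernel="linear").fit(X_tr, y_tr).score(X_te, y_te))
holdout_scores = np.array(holdout_scores)

# K-fold: every sample is tested exactly once and the fold scores are averaged.
kfold_scores = cross_val_score(
    SVC(kernel="linear"), X, y,
    cv=KFold(n_splits=5, shuffle=True, random_state=0),
)

print(f"Hold-out over 10 seeds: min={holdout_scores.min():.3f}, "
      f"max={holdout_scores.max():.3f}")
print(f"5-fold mean accuracy:   {kfold_scores.mean():.3f}")
```

The gap between the minimum and maximum hold-out scores gives a rough sense of how much a single split can mislead.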
Advantages and Disadvantages of Cross Validation
Advantages:
- Overcoming Overfitting: Cross validation helps to prevent overfitting by providing a more robust estimate of the model's performance on unseen data.
- Model Selection: Cross validation is used to compare different models and select the one that performs the best on average.
- Hyperparameter tuning: Cross validation is used to optimize the hyperparameters of a model, such as the regularization parameter, by selecting the values that give the best performance on the validation folds.
- Data Efficient: It allows all the available data to be used for both training and validation, making it more data-efficient than a single train/test split.
Disadvantages:
- Computationally Expensive: It can be computationally expensive especially when the number of folds is large or when the model is complex and requires a long time to train.
- Time-Consuming: It can be time-consuming especially when there are many hyperparameters to tune or when multiple models need to be compared.
- Bias-Variance Tradeoff: The choice of the number of folds affects the bias-variance tradeoff of the performance estimate, i.e. too few folds may result in high bias, while too many folds may result in high variance.
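The hyperparameter tuning use mentioned above can be sketched with scikit-learn's `GridSearchCV`, which runs cross-validation for each candidate value. The candidate values of `C`, the dataset, and the classifier are arbitrary assumptions for this illustration.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)

# Candidate values for the regularization parameter C
# (chosen arbitrarily for this sketch).
param_grid = {"C": [0.01, 0.1, 1, 10]}

# cv=5 runs 5-fold cross-validation for every candidate value.
search = GridSearchCV(SVC(kernel="linear"), param_grid, cv=5)
search.fit(X, y)

print("Best C:", search.best_params_["C"])
print(f"Best mean CV accuracy: {search.best_score_:.3f}")
```

This also illustrates the cost listed under disadvantages: four candidates with five folds each means twenty fits, plus one final refit on the full data.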
Python implementation for k fold cross-validation
Step 1: Importing necessary libraries
We will import the required functions and classes from scikit-learn.
Python
from sklearn.model_selection import cross_val_score, KFold
from sklearn.svm import SVC
from sklearn.datasets import load_iris
Step 2: Loading the dataset
Let's use the iris dataset, a built-in multi-class classification dataset.
Python
iris = load_iris()
X, y = iris.data, iris.target
Step 3: Creating SVM classifier
SVC is a Support Vector Classification model from scikit-learn.
Python
svm_classifier = SVC(kernel='linear')
Step 4: Defining the number of folds for cross-validation
Here we will be using 5 folds.
Python
num_folds = 5
kf = KFold(n_splits=num_folds, shuffle=True, random_state=42)
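Step 5: Performing k-fold cross-validation
cross_val_score trains and evaluates the classifier once per fold and returns the accuracy of each fold.
Python
cross_val_results = cross_val_score(svm_classifier, X, y, cv=kf)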
Step 6: Evaluation metrics
Python
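print("Cross-Validation Results (Accuracy):")
# Print the accuracy of each fold, numbering folds from 1
for i, result in enumerate(cross_val_results, 1):
    print(f" Fold {i}: {result * 100:.2f}%")
# Average the per-fold scores into a single estimate
print(f'Mean Accuracy: {cross_val_results.mean() * 100:.2f}%')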
Output:
The output shows the accuracy score from each of the 5 folds in the k-fold cross-validation process. The mean accuracy is the average of these individual scores, approximately 97.33%, indicating the model's overall performance across all the folds.