K-Fold Cross Validation in Machine Learning
Last Updated: 22 Apr, 2025
K-Fold Cross Validation is a statistical technique for measuring the performance of a machine learning model by dividing the dataset into K subsets (folds) of equal size. The model is trained on K − 1 folds and tested on the remaining fold. This process is repeated K times, with each fold used as the test set exactly once, and the results are averaged over all K iterations to give a robust estimate of the model's generalization ability.
[Figure: K-Fold train and test split]
What is Cross Validation?
Cross-validation serves multiple purposes:
- Avoids Overfitting: Checks that the model does not merely perform well on the training data but also generalizes to unseen data.
- Provides Robust Evaluation: Averages results over multiple iterations, reducing bias and variance in the performance metrics.
- Efficient Use of Data: Maximizes the utilization of the dataset, especially when the data size is limited.
K-Fold Cross Validation
The process can be broken into the following steps:
1. Split Data into K Folds: Divide the dataset into K subsets or folds.
2. Train-Test Iterations: For each fold:
- Use K − 1 folds for training the model.
- Use the remaining fold as the test set to evaluate the model.
3. Aggregate Results: Calculate the performance metric (e.g., accuracy, precision, recall, etc.) for each fold and average the results.
Let D be the dataset, split into K disjoint folds F_1, …, F_K. For each fold k, the training set is:
D_{\text{train}}^{(k)} = D \setminus F_k
and the test set is:
D_{\text{test}}^{(k)} = F_k
If M_k denotes the model trained on D_{\text{train}}^{(k)} and \text{Metric} is the chosen evaluation metric, the overall performance is the average over all folds:
\text{Performance} = \frac{1}{K} \sum_{k=1}^{K} \text{Metric}(M_k, F_k)
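To connect the formula to code, here is a minimal from-scratch sketch (the variable names such as fold_accuracies are illustrative, and a simple logistic regression stands in for the model M_k): it builds the folds, trains on D \setminus F_k, and evaluates on F_k.
Python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

X, y = load_iris(return_X_y=True)

n_splits = 5
rng = np.random.default_rng(42)
indices = rng.permutation(len(X))          # shuffle once before splitting
folds = np.array_split(indices, n_splits)  # the folds F_1, ..., F_K

fold_accuracies = []
for k, test_idx in enumerate(folds):
    # Training set D \ F_k: every fold except the k-th
    train_idx = np.concatenate([f for i, f in enumerate(folds) if i != k])
    model = LogisticRegression(max_iter=1000)
    model.fit(X[train_idx], y[train_idx])
    # Test set F_k: the held-out fold
    fold_accuracies.append(accuracy_score(y[test_idx], model.predict(X[test_idx])))

# Performance = (1/K) * sum of per-fold metrics
print(f"Average accuracy: {np.mean(fold_accuracies):.3f}")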
Choosing the Value of K
The choice of K affects the trade-off between bias and variance:
- Small K (e.g., K = 2 or K = 5): Faster computation, but each model is trained on a smaller share of the data, so the performance estimate tends to be more biased (often pessimistic).
- Large K (e.g., K = 10 or K = n, where n is the size of the dataset): Lower bias, since each model trains on nearly all the data, but higher computational cost. K = n corresponds to Leave-One-Out Cross Validation (LOOCV).
A standard choice is 10-fold cross validation, which offers a good trade-off between bias, variance and computation in most situations.
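To see the trade-off empirically, the following sketch (the candidate values of K are illustrative) compares the mean and spread of the accuracy estimate for a few values of K on the Iris dataset:
Python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold, cross_val_score

X, y = load_iris(return_X_y=True)
model = RandomForestClassifier(random_state=42)

for k in (2, 5, 10):
    kf = KFold(n_splits=k, shuffle=True, random_state=42)
    scores = cross_val_score(model, X, y, cv=kf, scoring='accuracy')
    # Larger K means more training data per fold but more model fits
    print(f"K={k:>2}: mean accuracy={scores.mean():.3f}, std={scores.std():.3f}")
For the K = n case, scikit-learn also provides a dedicated LeaveOneOut splitter.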
Variants of K-Fold Cross Validation
Several variants adapt the basic procedure to different kinds of data:
- Stratified K-Fold: Each fold preserves the class proportions of the full dataset, which matters for imbalanced classification problems.
- Leave-One-Out (LOOCV): K equals the number of samples, so every observation serves once as a single-point test set.
- Repeated K-Fold: The whole procedure is repeated with different random splits and the results are averaged for a more stable estimate.
- Time Series Cross Validation: For time-series data, where the order of observations is crucial, the training set consists of all data points up to time t and the test set contains the points after t, so the model is never trained on the future.
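As a minimal sketch of how two of these variants are used in scikit-learn (the split counts below are arbitrary), StratifiedKFold preserves class ratios within each fold while TimeSeriesSplit only ever trains on past observations:
Python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import StratifiedKFold, TimeSeriesSplit

X, y = load_iris(return_X_y=True)

# Stratified K-Fold: every test fold mirrors the overall class distribution
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
for train_idx, test_idx in skf.split(X, y):
    print("test-fold class counts:", np.bincount(y[test_idx]))

# Time Series Split: the training window grows and the test set always lies
# strictly after it, so the model never sees the future
tscv = TimeSeriesSplit(n_splits=4)
for train_idx, test_idx in tscv.split(X):
    print(f"train indices 0..{train_idx[-1]}, test indices {test_idx[0]}..{test_idx[-1]}")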
Implementation of K-Fold Cross Validation
Here’s a Python example of how to implement K-Fold Cross Validation using the scikit-learn library:
Python
from sklearn.model_selection import KFold, cross_val_score
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
import numpy as np
# Load the Iris dataset
data = load_iris()
X, y = data.data, data.target
# K-Fold Cross Validation
k = 5
kf = KFold(n_splits=k, shuffle=True, random_state=42)
# Initialize the RandomForestClassifier model
model = RandomForestClassifier(random_state=42)
# Perform Cross Validation
scores = cross_val_score(model, X, y, cv=kf, scoring='accuracy')
print(f"Accuracy for each fold: {scores}")
average_accuracy = np.mean(scores)
print(f"Average Accuracy: {average_accuracy:.2f}")
Output:
Accuracy for each fold: [0.96666667 0.96666667 0.93333333 0.96666667]
Average Accuracy: 0.97
Explanation: This code runs 5-fold cross validation with a RandomForestClassifier on the Iris dataset. For each fold, the model is trained on the other four folds and its accuracy is measured on the held-out fold. The accuracy of each fold is printed, followed by the average accuracy over all folds.
Benefits of K-Fold Cross Validation
- Efficient Data Usage: Every data point is used for both training and testing.
- Reliable Estimates: Averaging over several folds reduces the risk that a single lucky or unlucky train-test split gives a misleadingly optimistic or pessimistic estimate.
- Applicability: Works well with small datasets or when data collection is expensive.
Limitations of K-Fold Cross Validation
- Computational Cost: Re-training the model K times can be time-consuming, especially for large datasets or complex models.
- Data Leakage Risk: Preprocessing steps (e.g., scaling, encoding) must be fitted on the training folds only, or information from the test fold leaks into training; see the pipeline sketch after this list.
- Not Ideal for Time-Series Data: Sequential dependencies in time-series data may render K-Fold inappropriate without modifications like time-based splits.
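A common way to avoid the leakage problem above is to wrap preprocessing and model into a single scikit-learn Pipeline, so that scaling statistics are re-fitted on each training fold only; a minimal sketch:
Python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold, cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X, y = load_iris(return_X_y=True)

# Because the scaler sits inside the pipeline, cross_val_score re-fits it on
# each training split, so test-fold statistics never leak into training
pipeline = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000))
kf = KFold(n_splits=5, shuffle=True, random_state=42)
scores = cross_val_score(pipeline, X, y, cv=kf, scoring='accuracy')
print(f"Leakage-safe average accuracy: {scores.mean():.2f}")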