How to do nested cross-validation with LASSO in caret or tidymodels?
Last Updated: 24 Apr, 2025
Nested cross-validation is a robust technique used for hyperparameter tuning and model selection. When working with complex models like LASSO (Least Absolute Shrinkage and Selection Operator), it becomes essential to understand how to implement nested cross-validation efficiently. In this article, we'll explore the concept of nested cross-validation and how to implement it with LASSO using popular R packages, Caret and Tidymodels.
Understanding Nested Cross-Validation
Nested cross-validation is a technique for evaluating and tuning machine learning models that helps prevent overfitting and provides a more realistic estimate of a model's performance on unseen data. It consists of two levels of cross-validation:
- Outer Loop: This loop divides the dataset into training and testing sets. It helps estimate the model's performance on independent data splits.
- Inner Loop: Inside each outer fold, another cross-validation loop is used to select the best hyperparameters for the model.
Nested cross-validation is particularly useful when you have a limited dataset or need to optimize hyperparameters to ensure model generalization.
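To make the two loops concrete, here is a small self-contained sketch on synthetic data. It is an illustration only, not tied to caret or tidymodels, and it assumes the glmnet package is installed (glmnet is also the engine used by the examples later in this article). The outer loop estimates performance, while cv.glmnet() serves as the inner tuning loop on each outer training set.
R
# Nested CV in miniature: outer loop estimates performance,
# inner loop (cv.glmnet) picks the LASSO penalty on each outer training set
library(glmnet)
set.seed(42)
x <- matrix(rnorm(200 * 20), ncol = 20)
y <- x[, 1] - 2 * x[, 2] + rnorm(200)      # only two predictors truly matter

outer_fold_id <- sample(rep(1:5, length.out = nrow(x)))   # assign rows to 5 outer folds

outer_rmse <- sapply(1:5, function(k) {
  train_x <- x[outer_fold_id != k, ]; train_y <- y[outer_fold_id != k]
  test_x  <- x[outer_fold_id == k, ]; test_y  <- y[outer_fold_id == k]
  inner <- cv.glmnet(train_x, train_y, alpha = 1, nfolds = 5)   # inner loop: tune lambda
  preds <- predict(inner, newx = test_x, s = "lambda.min")      # tuned model on held-out fold
  sqrt(mean((preds - test_y)^2))                                # outer-fold RMSE
})

mean(outer_rmse)   # performance estimate from the outer loop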
LASSO Regression
LASSO regression is a regularization method, often preferred over ordinary regression when more accurate predictions are needed. It relies on shrinkage: coefficient estimates are pulled towards zero (towards the mean). Because the LASSO penalty encourages simple, sparse models with fewer non-zero parameters, it is well suited to data with a high degree of multicollinearity, or to situations where you want to automate parts of model selection such as variable selection and parameter elimination.
LASSO regression uses the L1 regularization technique. Because it performs feature selection automatically, it is particularly useful when there are many features.
Here’s a step-by-step explanation of how LASSO regression works:
- Linear Regression Model: LASSO regression starts from the ordinary linear regression model, which assumes a linear relationship between the independent variables (features) and the dependent variable (target).
- L1 Regularization: LASSO regression adds a penalty term based on the absolute values of the coefficients. The L1 regularization term is the sum of the absolute coefficient values multiplied by a tuning parameter (lambda).
- Objective Function: The goal of LASSO regression is to find the coefficient values that minimize the sum of the squared differences between the predicted and actual values, plus the L1 regularization term.
- Shrinking Coefficients: The L1 regularization term shrinks coefficients towards zero, and when lambda is large enough some coefficients are driven to exactly zero. Variables with zero coefficients are effectively removed from the model, which is what makes LASSO useful for feature selection.
- Tuning Parameter: The choice of the regularization parameter lambda is critical. Larger values push more coefficients towards zero; smaller values weaken the regularization, allowing more variables to keep non-zero coefficients.
- Model Fitting: The coefficients are estimated by minimizing the objective function with an optimization algorithm. Coordinate descent is commonly used: each coefficient is updated in turn while the others are held fixed.
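A tiny illustration of the shrinkage behaviour (a sketch, assuming the glmnet package is installed): as the penalty grows, more coefficients are set exactly to zero.
R
library(glmnet)
set.seed(1)
x <- matrix(rnorm(100 * 10), ncol = 10)
y <- 2 * x[, 1] - x[, 2] + rnorm(100)    # only the first two predictors matter
fit <- glmnet(x, y, alpha = 1)           # alpha = 1 => pure LASSO

# Count non-zero coefficients (excluding the intercept) at a small and a large penalty
sum(coef(fit, s = 0.01)[-1] != 0)
sum(coef(fit, s = 0.50)[-1] != 0)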
Why Use LASSO?
LASSO is a linear regression technique that adds a penalty term to the linear regression cost function. This penalty encourages the model to shrink some coefficients to exactly zero, effectively performing feature selection. LASSO is valuable when dealing with datasets with many features or when you suspect that some features are irrelevant.
Prerequisites
Before diving into nested cross-validation, make sure you have R installed along with the Caret and Tidymodels packages. You can install them using the following commands:
R
install.packages("caret")
install.packages("tidymodels")
install.packages("mlbench")
- caret: caret lets you train many types of algorithms through a single train() function.
- tidymodels: a collection of packages for modeling and statistical analysis that share the underlying design philosophy, grammar, and data structures of the tidyverse.
- mlbench: a collection of artificial and real-world machine learning benchmark problems, including several data sets from the UCI repository.
Loading Libraries
R
library(caret)
library(tidymodels)
library(mlbench)
Load the dataset
R
# Load a built-in dataset from the mlbench package
data(Sonar)
The data report the patterns obtained by bouncing sonar signals at various angles and under various conditions. There are 208 patterns in all, 111 obtained by bouncing sonar signals off a metal cylinder and 97 obtained by bouncing signals off rocks. Each pattern is a set of 60 numbers (variables) taking values between 0 and 1.
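A quick, optional check confirms the size of the data and the two class labels:
R
# Inspect the dimensions and class distribution of the Sonar data
dim(Sonar)          # 208 observations, 60 predictors plus the Class column
table(Sonar$Class)  # M = metal cylinder (111), R = rock (97)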
Implementing Nested Cross-Validation with Caret
Here's how you can perform nested cross-validation with LASSO using Caret:
R
# Define control parameters for the tuning (inner) cross-validation
ctrl <- trainControl(
  method = "cv",
  number = 5,
  summaryFunction = twoClassSummary,
  classProbs = TRUE,
  search = "grid"
)
# Define a hyperparameter grid for LASSO (alpha = 1)
grid <- expand.grid(
  alpha = 1,
  lambda = seq(0.001, 1, length = 10)
)
# Tune the LASSO model with cross-validation (the inner loop of the nested procedure)
set.seed(123)
model <- train(
  Class ~ .,
  data = Sonar,
  method = "glmnet",
  trControl = ctrl,
  tuneGrid = grid
)
# Print the best hyperparameters
print(model$bestTune)
Output:
alpha lambda
1 1 0.001
- Loading Data: Initially, you load the "Sonar" dataset sourced from the "mlbench" package. This dataset serves as the foundation for a classification task.
- Control Parameters for the Tuning CV: You establish control parameters for the cross-validation used during tuning, specifying method = "cv" with 5 folds. The twoClassSummary function computes performance metrics for binary classification and requires class probabilities, hence classProbs = TRUE. Grid search is chosen as the search method.
- Hyperparameter Grid for LASSO: You define a hyperparameter grid to fine-tune the LASSO model. The alpha value is set to 1, signifying pure LASSO regularization, while lambda ranges from 0.001 to 1 in 10 evenly spaced values, regulating the degree of coefficient shrinkage.
- Tune the Model: Using caret's train function, you run the tuning cross-validation. The formula "Class ~ ." uses all available variables as features, and method = "glmnet" fits a penalized (LASSO) logistic regression. You supply the pre-specified control parameters (trControl) and the hyperparameter grid (tuneGrid).
- Printing the Best Hyperparameters: Finally, you display the optimal alpha and lambda values found during tuning.
This train() call carries out the inner loop of the nested procedure: it searches the hyperparameter grid and selects the combination that yields the best classification performance. To estimate how the tuned model generalizes to unseen data, wrap it in an outer resampling loop, as in the sketch below.
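The following is a minimal sketch of one way to add the outer loop, assuming the ctrl and grid objects defined above. The use of createFolds() for the outer folds and plain accuracy as the outer score are illustrative choices, not the only possibilities.
R
# Outer loop: 5 outer folds; within each, train() re-runs the inner tuning CV on the
# outer training portion, and the tuned model is scored on the held-out outer fold
set.seed(123)
outer_folds <- createFolds(Sonar$Class, k = 5, returnTrain = TRUE)

outer_acc <- sapply(outer_folds, function(train_idx) {
  fit <- train(
    Class ~ .,
    data = Sonar[train_idx, ],
    method = "glmnet",
    trControl = ctrl,
    tuneGrid = grid,
    metric = "ROC"
  )
  preds <- predict(fit, newdata = Sonar[-train_idx, ])
  mean(preds == Sonar$Class[-train_idx])  # outer-fold accuracy
})

mean(outer_acc)  # nested-CV estimate of generalization performance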
Nested Cross-Validation with Tidymodels on mtcars Dataset
Tidymodels is another popular package for modeling and machine learning. Here's how to perform nested cross-validation with LASSO using Tidymodels:
Loading Libraries
We start by loading the necessary libraries, primarily tidymodels, which provides a framework for modeling and machine learning.
R
# Load necessary libraries
library(tidymodels)
Loading the Dataset
R
# mtcars ships with base R; load it explicitly
data(mtcars)
- Load Dataset: The code loads the mtcars dataset, which is built into R. This dataset contains information about various car models, including attributes like miles per gallon (mpg), horsepower, and more.
Summary of the Dataset
R
# Summary statistics for each column
summary(mtcars)
Output:
mpg cyl disp hp drat wt qsec
Min. :10.40 Min. :4.000 Min. : 71.1 Min. : 52.0 Min. :2.760 Min. :1.513 Min. :14.50
1st Qu.:15.43 1st Qu.:4.000 1st Qu.:120.8 1st Qu.: 96.5 1st Qu.:3.080 1st Qu.:2.581 1st Qu.:16.89
Median :19.20 Median :6.000 Median :196.3 Median :123.0 Median :3.695 Median :3.325 Median :17.71
Mean :20.09 Mean :6.188 Mean :230.7 Mean :146.7 Mean :3.597 Mean :3.217 Mean :17.85
3rd Qu.:22.80 3rd Qu.:8.000 3rd Qu.:326.0 3rd Qu.:180.0 3rd Qu.:3.920 3rd Qu.:3.610 3rd Qu.:18.90
Max. :33.90 Max. :8.000 Max. :472.0 Max. :335.0 Max. :4.930 Max. :5.424 Max. :22.90
vs am gear carb
Min. :0.0000 Min. :0.0000 Min. :3.000 Min. :1.000
1st Qu.:0.0000 1st Qu.:0.0000 1st Qu.:3.000 1st Qu.:2.000
Median :0.0000 Median :0.0000 Median :4.000 Median :2.000
Mean :0.4375 Mean :0.4062 Mean :3.688 Mean :2.812
3rd Qu.:1.0000 3rd Qu.:1.0000 3rd Qu.:4.000 3rd Qu.:4.000
Max. :1.0000 Max. :1.0000 Max. :5.000 Max. :8.000
Pre-processing the dataset
R
# Define the target variable (here, mpg)
target_var <- "mpg"
# Define a recipe for preprocessing
preprocess <- recipe(as.formula(paste(target_var, "~ .")), data = mtcars) %>%
  step_normalize(all_predictors())
- Define Target Variable: target_var stores the name of the column to predict. In this example it is "mpg", which represents fuel efficiency in miles per gallon; substitute your own target column for a different dataset.
- Preprocessing Recipe: A recipe for preprocessing is defined using recipe. It specifies how the data should be transformed before modeling. Here the numeric predictors are normalized with step_normalize, which scales them to a mean of 0 and a standard deviation of 1 (a quick check of the result is shown below).
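If you want to verify what the recipe does, an optional check (a sketch using recipes' prep() and bake()) can be run; the normalized predictors should have a mean of roughly 0 and a standard deviation of 1.
R
# Estimate the normalization parameters and apply them to the training data
preprocess %>%
  prep() %>%
  bake(new_data = NULL) %>%
  summary()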
Creating Lasso object and workflow
R
# Create a model specification for Lasso regression
lasso <- linear_reg(penalty = tune(), mixture = 1) %>%
  set_engine("glmnet")
# Create a workflow that includes the recipe and model specification
lasso_workflow <- workflow() %>%
  add_recipe(preprocess) %>%
  add_model(lasso)
- Model Specification: A model specification for Lasso regression is created using linear_reg. Setting mixture = 1 selects the pure LASSO (L1) penalty, and penalty = tune() marks the regularization strength as a hyperparameter to be tuned.
- Workflow: A workflow is set up to combine the preprocessing recipe and model specification. This allows for a streamlined and consistent process for training and evaluating models.
Implementing Nested Cross-Validation with Tidymodels
R
# Define a grid of hyperparameters to search
hypergrid <- expand.grid(
  penalty = seq(0.001, 1, length.out = 10)
)
# Set up a 10-fold cross-validation resampling scheme
cv <- vfold_cv(mtcars, v = 10)
# Tune the Lasso model using cross-validation
lasso_t <- tune_grid(
  object = lasso_workflow,
  resamples = cv,
  grid = hypergrid,
  metrics = metric_set(rmse, mae, rsq)
)
# Get the best Lasso model
lasso_best <- select_best(lasso_t, metric = "rmse")
# Print the best Lasso model
lasso_best
Output:
A tibble: 1 × 2
penalty .config
<dbl> <chr>
0.667 Preprocessor1_Model07
- Hyperparameter Grid: A grid of hyperparameters is defined. Hyperparameters are settings of the model that can be tuned to find the best combination. In this case, it's a grid of penalty values for Lasso regularization.
- Cross-Validation: Cross-validation is set up using vfold_cv with a 10-fold strategy, which repeatedly splits the data into training and assessment subsets.
- Tune the Model: The tune_grid function searches the hyperparameter grid and evaluates each candidate penalty with metrics such as RMSE (Root Mean Squared Error), MAE (Mean Absolute Error), and R-squared. Within a nested scheme, this is the inner loop that selects the best hyperparameters (a sketch of wrapping it in an outer loop follows this list).
- Select Best Model: The select_best function identifies the best Lasso model based on the specified metric (here, RMSE).
- Print the Best Model: The code prints the selected penalty value and the identifier of the winning configuration.
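The tune_grid() call above is the inner tuning loop. One way to complete the nested procedure in tidymodels is to use rsample's nested_cv() and repeat the tuning inside each outer fold. The sketch below assumes the lasso_workflow and hypergrid objects defined earlier; the 5-fold-outer, 3-fold-inner layout is an illustrative choice.
R
set.seed(123)
# Outer loop for performance estimation, inner loop for tuning
nested <- nested_cv(mtcars,
                    outside = vfold_cv(v = 5),
                    inside  = vfold_cv(v = 3))

outer_rmse <- purrr::map2_dbl(nested$splits, nested$inner_resamples,
  function(outer_split, inner_folds) {
    # Inner loop: tune the penalty using only the outer training data
    tuned <- tune_grid(lasso_workflow, resamples = inner_folds,
                       grid = hypergrid, metrics = metric_set(rmse))
    best <- select_best(tuned, metric = "rmse")

    # Refit on the outer training data with the chosen penalty,
    # then score on the held-out outer assessment set
    final_fit <- finalize_workflow(lasso_workflow, best) %>%
      fit(data = analysis(outer_split))
    preds <- predict(final_fit, new_data = assessment(outer_split))
    rmse_vec(truth = assessment(outer_split)$mpg, estimate = preds$.pred)
  })

mean(outer_rmse)   # nested-CV estimate of generalization error
Once this estimate is in hand, the final model is typically obtained by tuning once more on the full dataset and refitting with finalize_workflow() and fit().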
Conclusion
Implementing nested cross-validation with LASSO in R using Caret or Tidymodels can help you build robust and accurate predictive models. It ensures that your model generalizes well to unseen data and helps you select the best hyperparameters for your LASSO model.
By following the steps outlined in this article, you can confidently apply nested cross-validation to your machine learning projects, ultimately improving the reliability and performance of your models.