
predict() function in R

Last Updated : 15 Apr, 2025

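The predict() function in R generates predictions from a fitted model. Given a model object and new values of the predictor variables, it returns the predicted values of the response (dependent) variable. predict() is a generic function, so the method that actually runs depends on the class of the model: predict.lm() for lm() fits, predict.glm() for glm() fits, predict.rpart() for rpart() trees, and so on. Depending on the method and its arguments, it can also return standard errors, confidence or prediction intervals, or, when no new data is supplied, the fitted values for the data used to fit the model. Residuals are not returned by predict() itself; use residuals() for those.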
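The following minimal sketch, using R's built-in mtcars data purely for illustration (it is not one of the article's examples), checks two of these points: calling predict() without newdata gives the same values as fitted(), and residuals can be recovered as observed minus predicted values.

```r
# Illustrative linear model on the built-in mtcars data
fit <- lm(mpg ~ wt, data = mtcars)

# Without newdata, predict() returns the fitted values for the training rows
fv <- predict(fit)
all.equal(fv, fitted(fit))      # TRUE

# predict() does not return residuals, but they are observed minus predicted
res <- mtcars$mpg - fv
all.equal(res, residuals(fit))  # TRUE
```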

Syntax

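predict(object, newdata, type, se.fit = FALSE, ...)

  • object: The fitted model object (e.g., the result of lm(), glm() or rpart()). Its class determines which predict() method is used.
  • newdata: A data frame with the predictor values for which predictions are required. Its column names must match the variables in the model formula. If omitted, predictions are made for the data used to fit the model.
  • type: The type of prediction to return. Allowed values depend on the method: for glm() models, "link" (the default, on the linear-predictor scale) or "response" (on the scale of the response, e.g., probabilities); for rpart() classification trees, "prob" (the default) or "class".
  • se.fit: If TRUE, standard errors of the predictions are returned along with them (supported by the lm and glm methods).
  • ...: Additional method-specific arguments, such as interval = "confidence" for lm() models.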
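As a sketch of the se.fit argument (again using the built-in mtcars data for illustration), the code below also passes interval through ... to predict.lm(), which then returns a matrix with fit, lwr and upr columns. The interval and level arguments belong to predict.lm(); other methods may not support them.

```r
# Illustrative model on the built-in mtcars data (not from the examples below)
fit <- lm(mpg ~ wt, data = mtcars)
new_cars <- data.frame(wt = c(2.5, 3.5))

# se.fit = TRUE returns a list (with components fit and se.fit, among others)
p <- predict(fit, newdata = new_cars, se.fit = TRUE)
print(p$fit)
print(p$se.fit)

# interval is a predict.lm() argument passed through ...;
# each result is a matrix with columns fit, lwr and upr
conf_int <- predict(fit, newdata = new_cars, interval = "confidence", level = 0.95)
pred_int <- predict(fit, newdata = new_cars, interval = "prediction", level = 0.95)
print(conf_int)
print(pred_int)
```

Prediction intervals come out wider than confidence intervals because they also account for the scatter of individual observations around the regression line, not just the uncertainty in the fitted line itself.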

Here is how we can use predict() in different contexts:

Example 1: Predicting with a Linear Model

Let's start by fitting a simple linear regression model using the lm() function and then making predictions using predict().

R
# Sample data
set.seed(123)
d <- data.frame(
  x = rnorm(100, mean = 10, sd = 5),
  y = rnorm(100, mean = 50, sd = 10)
)
# Fit the linear model
model <- lm(y ~ x, data = d)

# New x values to predict for
n_d <- data.frame(x = c(11, 12, 13))

# Predict y for the new x values
predictions <- predict(model, newdata = n_d)
print(predictions)

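In this case, predict() returns a named numeric vector with the predicted values of y for the new x values (11, 12 and 13). Because x and y are generated independently here, the fitted slope is close to zero and the three predictions are nearly equal.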

Output:

1 2 3
48.86703 48.76208 48.65714
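To see what predict() computes for this model, the minimal sketch below rebuilds Example 1 and recomputes the predictions by hand from coef(): for a simple linear model, each prediction is intercept + slope * x.

```r
# Rebuild the model from Example 1
set.seed(123)
d <- data.frame(
  x = rnorm(100, mean = 10, sd = 5),
  y = rnorm(100, mean = 50, sd = 10)
)
model <- lm(y ~ x, data = d)
n_d <- data.frame(x = c(11, 12, 13))
predictions <- predict(model, newdata = n_d)

# Each prediction is intercept + slope * x
b <- coef(model)
manual <- b[["(Intercept)"]] + b[["x"]] * n_d$x
all.equal(unname(predictions), manual)  # TRUE
```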

Example 2: Predicting with a Logistic Regression Model

Let's now fit a logistic regression model with the glm() function and make predictions using predict().

R
set.seed(123)
d_logit <- data.frame(
  age = rnorm(100, mean = 30, sd = 5),
  bought = sample(0:1, 100, replace = TRUE)
)
# Fit a logistic regression model
model_logit <- glm(bought ~ age, data = d_logit, family = binomial)

# New ages to predict for
n_d_logit <- data.frame(age = c(28, 32, 35))

pre_logit <- predict(model_logit, newdata = n_d_logit, type = "response")
print(pre_logit)

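Here, type = "response" tells predict() to return the predicted probabilities that bought equals 1. Without it, predict() for a glm() model returns values on the link scale (log-odds for a binomial model).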

Output:

1 2 3
0.4470355 0.5169052 0.5690249
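The relationship between the two prediction scales can be checked with a short sketch that rebuilds Example 2: the default type = "link" gives log-odds, and applying the inverse logit function plogis() to them reproduces the type = "response" probabilities.

```r
# Rebuild the model from Example 2
set.seed(123)
d_logit <- data.frame(
  age = rnorm(100, mean = 30, sd = 5),
  bought = sample(0:1, 100, replace = TRUE)
)
model_logit <- glm(bought ~ age, data = d_logit, family = binomial)
n_d_logit <- data.frame(age = c(28, 32, 35))

# Default type = "link": predictions on the log-odds scale
link <- predict(model_logit, newdata = n_d_logit)
# type = "response": predicted probabilities
prob <- predict(model_logit, newdata = n_d_logit, type = "response")

# The inverse logit maps log-odds back to probabilities
all.equal(unname(plogis(link)), unname(prob))  # TRUE
```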

Example 3: Predicting Class Labels with a Decision Tree

In this example, we will use the rpart() function to create a decision tree and predict class labels for new data.

R
# rpart ships with standard R installations; run install.packages("rpart") only if it is missing
library(rpart)

data_iris <- iris
# Fit a decision tree model
model_tree <- rpart(Species ~ Sepal.Length + Sepal.Width, data = data_iris)

new_data_tree <- data.frame(Sepal.Length = c(5.1, 6.3, 7.2), Sepal.Width = c(3.5, 3.3, 3.1))

predictions_tree <- predict(model_tree, newdata = new_data_tree, type = "class")
print(predictions_tree)

Here, type = "class" makes predict() return the predicted class label for each row (setosa, versicolor or virginica) as a factor.

Output:

1 2 3
setosa virginica virginica
Levels: setosa versicolor virginica
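For classification trees, predict.rpart() can also return class probabilities instead of labels. The minimal sketch below rebuilds Example 3 with type = "prob"; according to the rpart documentation, this is also the default when type is omitted for a classification tree. The exact probabilities depend on the fitted tree, so none are shown here.

```r
library(rpart)  # rpart ships with standard R installations

# Rebuild the tree from Example 3
model_tree <- rpart(Species ~ Sepal.Length + Sepal.Width, data = iris)
new_data_tree <- data.frame(
  Sepal.Length = c(5.1, 6.3, 7.2),
  Sepal.Width  = c(3.5, 3.3, 3.1)
)

# type = "prob" returns one row per observation and one column per class
probs <- predict(model_tree, newdata = new_data_tree, type = "prob")
print(round(probs, 3))

# Each row holds the class proportions in the leaf the observation falls into
print(rowSums(probs))
```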
