A Guide to the predict() Function in R

Deepanshu Bhalla

In this article, we will show you how to make predictions in R with different machine learning models: decision trees, random forests, logistic regression, support vector machines, conditional inference trees, and gradient boosting trees.

The following is a list of predict functions for these models in R. Each example scores a validation dataset with a previously trained model, and most of them return predicted probabilities rather than class labels.
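None of the snippets below show where validation_data comes from. As a minimal sketch, assuming a simple random 70/30 holdout split of the built-in iris data (the dataset, the seed, train_data and the 70% share are illustrative choices, not part of the original article), it could be created in base R like this:

```r
# Illustrative 70/30 holdout split of the built-in iris data (assumed setup)
set.seed(123)                              # make the random split reproducible
n_train = round(0.7 * nrow(iris))          # 105 of the 150 rows go to training
train_idx = sample(nrow(iris), n_train)    # random row indices for training
train_data = iris[train_idx, ]             # rows used to fit the model
validation_data = iris[-train_idx, ]       # held-out rows passed to predict()
```

The model is fitted on train_data only, so the predictions on validation_data reflect performance on rows the model has not seen.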

predict Function: Decision Tree

The predict() function in the rpart package generates predictions from a previously built decision tree model on the validation dataset.

library(rpart)
tree.pred = predict(tree.model, validation_data, type = "prob")
  • tree.model: This is the trained decision tree model.
  • validation_data: This is the validation dataset on which you want to make predictions.
  • type = "prob": returns a matrix of class probabilities (one column per class) instead of just the class labels. It applies to classification trees only.
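To see what type = "prob" returns, here is a minimal, self-contained sketch on the built-in iris data. It assumes a classification tree fitted with rpart(..., method = "class"); the split, seed and formula are illustrative. Each row of the probability matrix sums to 1, and the label from type = "class" is a class with the highest probability in that row.

```r
library(rpart)

# Illustrative train/validation split of iris (assumed setup)
set.seed(123)
train_idx = sample(nrow(iris), round(0.7 * nrow(iris)))
train_data = iris[train_idx, ]
validation_data = iris[-train_idx, ]

# Classification tree predicting Species from all other columns
tree.model = rpart(Species ~ ., data = train_data, method = "class")

# Matrix of class probabilities: one row per validation case, one column per class
tree.pred = predict(tree.model, validation_data, type = "prob")
head(tree.pred)

# Predicted class labels, for comparison with the probabilities
tree.class = predict(tree.model, validation_data, type = "class")
table(Predicted = tree.class, Actual = validation_data$Species)
```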

predict Function: Random Forest

The predict() function in the randomForest package generates predictions from a previously built random forest model on the validation dataset.

library(randomForest)
rf.pred = predict(rf.model, validation_data, type = "prob")
  • rf.model: This is the trained random forest model.
  • validation_data: This is the validation dataset on which you want to make predictions.
  • type = "prob": returns a matrix of class probabilities (the proportion of trees voting for each class) instead of just the class labels.

predict Function: Logistic Regression

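The predict() function in base R (stats package) generates predictions from a previously built logistic regression model, i.e. a model fitted with glm(..., family = binomial), on the validation dataset. No additional package is needed.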

log.pred = predict(log.model, validation_data, type = "response")
  • log.model: This is the trained logistic regression model.
  • validation_data: This is the validation dataset on which you want to make predictions.
  • type = "response": returns predicted probabilities. Without it, predict() returns values on the log-odds scale (the default, type = "link").
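The link between the two scales can be checked directly: for a binomial glm with the default logit link, the type = "response" prediction is the inverse logit of the type = "link" prediction. A minimal sketch, assuming the built-in mtcars data with the 0/1 transmission indicator am as the outcome; the split, the formula am ~ wt and the 0.5 cut-off used to turn probabilities into labels are illustrative choices.

```r
# Illustrative split of mtcars; am is a 0/1 outcome (1 = manual transmission)
set.seed(42)
train_idx = sample(nrow(mtcars), 22)
train_data = mtcars[train_idx, ]
validation_data = mtcars[-train_idx, ]

# Logistic regression: glm() with a binomial family, no extra package required
log.model = glm(am ~ wt, data = train_data, family = binomial)

log.link = predict(log.model, validation_data)                     # log-odds (default type = "link")
log.pred = predict(log.model, validation_data, type = "response")  # probabilities

# The two scales are related by the inverse logit, plogis(x) = 1 / (1 + exp(-x))
all.equal(log.pred, plogis(log.link))

# Turn probabilities into 0/1 labels with an assumed 0.5 threshold
log.class = ifelse(log.pred > 0.5, 1, 0)
mean(log.class == validation_data$am)   # share of validation rows classified correctly
```

The 0.5 threshold is only a common default; a different cut-off may suit unbalanced classes better.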

predict Function: Support Vector Machine

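The predict() function in the e1071 package generates predictions from a previously built support vector machine model on the validation dataset.

library(e1071)
# svm.model must be fitted with svm(..., probability = TRUE)
svm.pred = predict(svm.model, validation_data, probability = TRUE)
svm.prob = attr(svm.pred, "probabilities")  # matrix of class probabilities
  • svm.model: This is the trained support vector machine model. It must have been fitted with probability = TRUE; otherwise no probabilities are available at prediction time.
  • validation_data: This is the validation dataset on which you want to make predictions.
  • probability = TRUE: The returned values are still class labels; the predicted probabilities are attached to them as the "probabilities" attribute, which attr() extracts.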

predict Function: Conditional Inference Tree / Forest

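The predict() function in the party package generates predictions from a previously built conditional inference tree (ctree()) or conditional inference forest (cforest()) model on the validation dataset.

library(party)
ct.pred = predict(ct.model, newdata = validation_data)
  • ct.model: This is the trained conditional inference tree (or forest) model.
  • validation_data: This is the validation dataset on which you want to make predictions. Pass it as newdata = explicitly, because for a cforest() model the second positional argument of predict() is OOB, not the new data.
  • By default, predict() returns the predicted class for a categorical response rather than probabilities; adding type = "prob" returns a list of conditional class probabilities, one element per observation.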

predict Function: Gradient Boosting Tree

The predict() function in the gbm package generates predictions from a previously built gradient boosting tree model on the validation dataset.

library(gbm)
gbm.pred = predict(gbm.model, newdata = validation_data, type = "response", n.trees = 500)
  • gbm.model: This is the trained gradient boosting tree model.
  • validation_data: This is the validation dataset on which you want to make predictions.
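  • type = "response": for a binary (distribution = "bernoulli") model, returns predicted probabilities instead of the default log-odds scale (type = "link").
  • n.trees = 500: The number of trees used to make the prediction. It should not exceed the number of trees in the fitted model; gbm.perf() can help choose a value.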
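In the caret package, every model trained with train() shares a single predict interface regardless of the underlying algorithm, e.g. predict(log.mod, val, type = "prob"), which returns class probabilities for the validation data val.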
About Author:
Deepanshu Bhalla

Deepanshu founded ListenData with a simple objective - Make analytics easy to understand and follow. He has over 10 years of experience in data science. During his tenure, he worked with global clients in various domains like Banking, Insurance, Private Equity, Telecom and HR.
