Introductory Guide to Decision Trees and Random Forests

Arshiya Ansari, Chirag Kulkarni, Ronak Rijhwani, Surbhi Singh

Introduction

Before we dive into the machine learning applications of decision trees, I want to demonstrate that the concept of a decision tree is not foreign to you. In fact, you probably mentally construct a number of decision trees every day of your life. Let’s consider an example. You’re on a first date and trying to gauge whether or not you should ask her on a second date to your favorite restaurant the following weekend. Well, it probably depends on whether or not your date enjoyed the first date. If the answer is no, the second date isn’t happening (sorry to break it to you). If your date did indeed enjoy the first date, whether or not you end up on a second date might depend on if your date is available the following weekend. If your date is available, you should ask her on a second date to your favorite restaurant next weekend; otherwise, you should not.

While this is a simplistic example, it is a great representation of a decision tree algorithm — a model of sequential and hierarchical decisions that produce an ultimate outcome.

Motivation

Decision trees are used to visualize and represent decision-making processes, such as the one in the example above. In more common use cases, machine learning techniques similarly leverage these decision tree algorithms. Machine learning applications of decision trees will be the focus of this article.

In general, three primary techniques are employed to extract knowledge from data:

  1. Classification — splitting up data into predefined classes and classifying new data
  2. Regression — producing predictions for future data based on past data
  3. Clustering — finding potential categorizations of the data

While classification and regression are considered supervised learning methods because you provide labeled data for the algorithms, clustering attempts to find classes to put the data in without any labels. Decision trees can be utilized for each of these three exercises, but this article will explore classification and regression trees.

Decision trees as a machine learning model for classification and regression can produce high accuracy while maintaining interpretability. Further, the clarity of information representation makes decision trees unique in the context of other machine learning solutions. As shown in the initial example, the hierarchical structure of the information learned by a decision tree makes it easily understandable. We’ll introduce a few other advantages as well as drawbacks of decision trees.

Advantages:

  • Intuitive and easily understandable
  • Data preparation and pre-processing for decision trees is often light
  • Normalization and scaling of data is not necessary
  • Missing values do not usually impact the process of building a decision tree

Disadvantages:

  • Calculations can take up more memory and time than simpler models
  • A small change in the data can significantly change the tree structure, hurting reproducibility
  • Greater space and time complexity leads to longer training times
  • A single decision tree is often insufficient on its own; for instance, a random forest usually requires many decision trees

Decision Tree

A decision tree is a powerful machine learning technique that can be used for both classification (for example, whether an email is spam or not) and regression (predicting a numerical value) problems. The reason decision trees are such a remarkable tool for machine learning is that they mimic the way humans think. As the name implies, a decision tree is a tree-like flowchart where each internal node represents a feature in the model, each branch represents a decision that is made, and each leaf represents an outcome (categorical or quantitative).

Source: https://www.numpyninja.com/post/decision-trees-example-in-machine-learning

With decision trees, we attempt to answer an overarching question; in the introduction, that question was whether to ask for a second date. Edges branching out from a node represent the answers to that node's question, and the leaves represent the final outputs.

Starting from the root node, a feature is evaluated and one of the two nodes (branches) is selected, based on an attribute selection measurement. This procedure is repeated until a final leaf is reached, which normally represents the answer to the problem.

There are various attribute selection measurements, and in this article, we will highlight the most popular: Information Gain, Entropy, and Gini Index.

Information Gain

Information gain is an attribute selection measure that is popularly used for selecting optimal splits in a decision tree. It measures how well a given feature separates the data according to the target classification. For example, the image below represents three different nodes, each composed of data points from only two classes (yellow and blue).

Source: https://technovert.com/what-is-decision-tree/

Node A has the most even mix of yellow and blue values, making it the most impure sample, while Node C has only blue values, making it the purest sample. So, to accurately split node A we would require a lot more information than we would need to split node C. Thus, more impure nodes require more information to make a sound decision, and purer nodes require less. In the context of our decision trees, a feature with high information gain splits the data into children whose class distributions are relatively uneven (purer), while a feature with low information gain splits the data into children whose class distributions are still relatively even. A relatively even distribution does not bring us closer to a decision. But to actually calculate the information gain, we must first calculate the entropy of the dataset.

Entropy

Entropy measures the impurity or randomness of a set of examples, so it is an essential component in calculating information gain. For a two-class problem, entropy ranges from 0 to 1. For a node like node A, where there is an even split between yellow and blue values, the entropy is 1, whereas the entropy of node C, where all the values are blue, is 0. An uneven split of values, such as in node B, gives an entropy between 0 and 1.

Entropy is calculated as E = -p log2(p) - q log2(q), where p is the probability of success and q is the probability of failure in the node (the two classes). The split that produces the lowest entropy is chosen as the next split in the decision tree.

Now that we know how to calculate the entropy, we can use it to calculate the information gain, which is the difference between the entropy before a split and the weighted average entropy of the children after the split.
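To make these ideas concrete, here is a minimal sketch of entropy and information gain as plain R functions. The names entropy and info_gain are illustrative helpers of our own; this is not how rpart computes these quantities internally.

entropy <- function(y) {
  p <- table(y) / length(y)
  p <- p[p > 0]                      # drop empty classes to avoid log2(0)
  -sum(p * log2(p))
}

# Information gain of splitting the labels y by a logical condition split:
# entropy of the parent minus the weighted average entropy of the two children.
info_gain <- function(y, split) {
  n <- length(y)
  entropy(y) - (sum(split) / n) * entropy(y[split]) -
               (sum(!split) / n) * entropy(y[!split])
}

# Example with the iris data: how much does splitting on Petal.Length < 2.45
# tell us about Species?
info_gain(iris$Species, iris$Petal.Length < 2.45)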

Gini Index

Last but not least is another very common attribute selection measure, known as the Gini index (or Gini impurity). Most commonly, classification and regression trees (CART) use the Gini index to create splits. The Gini index is the probability that a randomly chosen item from a node would be classified incorrectly if it were labeled at random according to the node's class distribution. It is computed as 1 minus the probability that two items selected at random from the node belong to the same class (the sum of the squared class proportions). For a pure node, any two randomly selected items must be in the same class, so that probability is 1 and the Gini index is 0; the lower the Gini index, the more homogeneous the node. The attribute whose split yields the minimum Gini index is chosen as the splitting attribute.
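To see what this looks like numerically, here is a quick sketch of the Gini index as a plain R function (gini_index is just an illustrative helper, not part of any package):

gini_index <- function(y) {
  p <- table(y) / length(y)
  1 - sum(p^2)      # 0 for a pure node, larger for more mixed nodes
}

gini_index(rep("blue", 10))                      # pure node: 0
gini_index(c(rep("blue", 5), rep("yellow", 5)))  # maximally mixed two-class node: 0.5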

Decision Tree Pseudocode

If we have a question at hand that needs to be answered, we can model the decision-making process with a tree by following these steps:

  1. First, we will start at the root node, which will be instantiated with all of the training instances. We will also set a current pointer and use it to traverse through our tree based on the decisions we make.
  2. For each feature (going back to our date example, this would be if the first date was bad or enjoyable): (1) Split the data at the current node by the value of the attribute and (2) Compute the information gain or entropy or Gini index from the splitting.
  3. Identify the attribute that results in the smallest Gini index or the highest information gain. If the best information gain is 0, the node is pure; tag it as a leaf node and continue.
  4. Split the data according to the best attribute we identified. Denote each split as a child node of the current node.
  5. For each child node that we just created: (1) If the child node is ‘pure’, then we know to tag it as a leaf and return and (2) If not, we can set the child node as the current node and repeat the process from step 2.

For an in-depth explanation on the pseudocode as well as more information, I referenced this website from Carnegie Mellon: https://www.cs.cmu.edu/~bhiksha/courses/10-601/decisiontrees/
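To make the pseudocode above concrete, here is a stripped-down sketch of the greedy splitting procedure written from scratch in R. It is not the rpart algorithm: it only handles numeric features, uses Gini impurity, makes binary splits, and the helper names (gini_impurity, best_split, grow_tree) are made up for this illustration.

gini_impurity <- function(y) {
  p <- table(y) / length(y)
  1 - sum(p^2)
}

# Try every feature and every observed cut point; keep the split that reduces
# the weighted Gini impurity of the children the most.
best_split <- function(df, target) {
  parent <- gini_impurity(df[[target]])
  best <- list(gain = 0)
  for (f in setdiff(names(df), target)) {
    for (cut in unique(df[[f]])) {
      left <- df[[f]] < cut
      if (!any(left) || all(left)) next
      child <- (sum(left) * gini_impurity(df[[target]][left]) +
                sum(!left) * gini_impurity(df[[target]][!left])) / nrow(df)
      if (parent - child > best$gain) {
        best <- list(feature = f, cut = cut, gain = parent - child)
      }
    }
  }
  best
}

grow_tree <- function(df, target, depth = 0, max_depth = 3) {
  majority <- names(which.max(table(df[[target]])))
  # Stop at a leaf if the node is pure or the depth limit is reached
  if (length(unique(df[[target]])) == 1 || depth >= max_depth) {
    return(list(leaf = TRUE, prediction = majority))
  }
  s <- best_split(df, target)
  if (is.null(s$feature)) return(list(leaf = TRUE, prediction = majority))
  left <- df[[s$feature]] < s$cut
  list(leaf = FALSE, feature = s$feature, cut = s$cut,
       left  = grow_tree(df[left, ],  target, depth + 1, max_depth),
       right = grow_tree(df[!left, ], target, depth + 1, max_depth))
}

# Example: grow a small tree on the iris data with Species as the target
tree <- grow_tree(iris, "Species")
str(tree, max.level = 2)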

Now let's see what this looks like using R's rpart package!

Decision Tree Tutorial

First, we have to make sure our packages are installed and loaded (imported). We have two packages we want to load: datasets and rpart. "datasets" gives us access to a bunch of datasets that are useful for practicing machine learning techniques; here, we will use it to access the iris dataset. "rpart" is a great package used by data scientists to create classification and regression trees (CART).

install.packages('rpart')
install.packages("datasets")
library(datasets)
library(rpart)

Now that we have loaded in the necessary packages, we need to read in the data and see what it looks like from a high level.

data(iris)
attach(iris)
summary(iris)
## Sepal.Length    Sepal.Width     Petal.Length    Petal.Width
## Min.   :4.300   Min.   :2.000   Min.   :1.000   Min.   :0.100
## 1st Qu.:5.100   1st Qu.:2.800   1st Qu.:1.600   1st Qu.:0.300
## Median :5.800   Median :3.000   Median :4.350   Median :1.300
## Mean   :5.843   Mean   :3.057   Mean   :3.758   Mean   :1.199
## 3rd Qu.:6.400   3rd Qu.:3.300   3rd Qu.:5.100   3rd Qu.:1.800
## Max.   :7.900   Max.   :4.400   Max.   :6.900   Max.   :2.500
##       Species
##  setosa    :50
##  versicolor:50
##  virginica :50

For the purposes of this exercise, we will treat Species as the target variable to predict, or response variable, and all the other variables as “predictors”.

From the summary description shown above, the dataset looks to be in good shape, with the variables formatted correctly and no missing values. Note that with real data problems, cleaning is a huge part of the process, and tree and random forest models can be particularly finicky about the format of the data they are modeling on.

Since we can skip the data cleaning process, let’s head straight to growing our tree using rpart.

fit <- rpart(Species~., data=iris, method="class")

rpart has a couple of different arguments we can use in the function call. For this exercise, we’ll just focus on the basic ones: formula, data, and method.

We specify formula with “Species~.”, indicating that Species is our response variable and everything else should be used to split the response (branches).

data just specifies what dataframe we want to use.

method is used to indicate the type of response and split method we should use. This can be either “anova”, “poisson”, “class”, or “exp”. Since our response variable is categorical, we set the method as “class”. If we wanted to run a regression on the data with a numerical variable, we would set the method to “anova”.

We'll get to controlling the tree's complexity shortly. First, let's see what the tree looks like.

par(xpd = NA)   # allow the text labels to extend beyond the plot region
plot(fit)
text(fit)

Our tree makes two splits. The first segments the data based on whether the petal length is less than 2.45; this split gives the most information gain, meaning it does the best job of separating the classes. The second split segments the data based on whether the petal width is less than 1.75. In rpart's plots, if the condition at a fork is true, you go left.
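We can also use the fitted tree to classify a new observation with predict(). The flower below is made up purely for illustration; its small petal length should send it down the left branch of the first split (Petal.Length < 2.45), i.e. to setosa.

new_flower <- data.frame(Sepal.Length = 5.0, Sepal.Width = 3.5,
                         Petal.Length = 1.4, Petal.Width = 0.2)
predict(fit, new_flower, type = "class")   # predicted class: setosa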

rpart does us a favor by including native cross-validation in its models. We can access those results by using the following command:

print(fit$cptable)
##     CP nsplit rel error xerror       xstd
## 1 0.50      0      1.00   1.14 0.05230679
## 2 0.44      1      0.50   0.61 0.06016090
## 3 0.01      2      0.06   0.08 0.02751969

This table shows the complexity parameter (CP) and error values over a number of different splits. We can see that we achieve the smallest cross-validation error (xerror) with 2 splits, so that seems like a good number of splits to use for our tree.
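Rather than reading the CP off the table by eye, a common pattern is to pick it programmatically from the row with the smallest xerror. A small convenience sketch, using the fit object from above:

best_cp <- fit$cptable[which.min(fit$cptable[, "xerror"]), "CP"]
best_fit <- prune(fit, cp = best_cp)   # prune to the best cross-validated size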

Let's say though, for the sake of this exercise, that we just wanted 1 split. While this may not make sense in this context, we need to remind ourselves that having more splits can often "overfit" the model to the training data, so there is always a balance to strike between bias and variance (model flexibility).

We can "prune" the tree back to a desired number of splits by choosing the corresponding complexity parameter (cp). We see that the CP associated with 1 split is 0.44, so we use that value in our call to prune.

pruned_fit <- prune(fit, cp=0.44)
par(xpd = NA)   # allow the text labels to extend beyond the plot region
plot(pruned_fit)
text(pruned_fit)

We can see that this probably is not the best tree given that we’ll never even see the virginica species here! However, this is a good tool to use in larger datasets and with larger trees.

Now that we have the basics of classification trees figured out, let’s move on to random forests!

Random Forests

Now that we understand decision trees, we already have the building blocks of the random forest model. Random forests are a group of decision trees that can be used as an ensemble. In machine learning, an ensemble model is one that combines several base models to produce one prediction, often performing better than any of the individual models. This makes sense when you think about the wisdom of the crowds: If you wanted the answer to a question, would you rather ask just one person or ask 100 different people and choose the answer which most people voted on? The second is an example of majority voting, but this can often be further improved using probabilities and averaging the predicted class probabilities.

So how does the random forest do so much better than a decision tree if we use the same training data to build all of these trees? We don't. Random forests use a method called bagging to create a diverse set of trees. Back to our example: if you wanted the answer to a question, would you prefer to ask a lot of people with similar backgrounds and education, or a diverse group of individuals who each bring a different set of knowledge? The latter group would seem to encompass a broader scope of information and help us answer more questions correctly. This idea is the basis for using boosting and bagging when creating models.

Bagging

Bagging is short for bootstrap aggregating. A bootstrap sample is a random sample of the dataset taken with replacement. The data points that are not included in a given bootstrap sample are used to estimate the out-of-bag (OOB) error, which is handy for choosing tuning parameters. For any given bootstrap sample, there is roughly a 37% chance that a particular observation is not included, since the probability of being left out, (1 - 1/n)^n, approaches e^(-1) ≈ 0.368 as n grows.
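If you want to convince yourself of the 37% figure, a quick simulation will do it (the result is approximate by nature):

# Fraction of observations left out of one bootstrap sample of size n
n <- 10000
mean(!(seq_len(n) %in% sample(n, replace = TRUE)))   # roughly 0.37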

Because decision trees are very sensitive to changes in the data, bootstrap samples result in very different tree structures. This makes sense because if even one split near the top of the tree is different, the rest of the tree will most likely be different as well. This method reduces the overall variance of the model by producing individual models that have low correlation with one another. Usually, the more bootstrap samples that are used, the lower the variance, until a plateau is reached.
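Here is a rough sketch of bagging by hand, using rpart trees on the iris data and averaging their predicted class probabilities. This is only meant to illustrate the idea; the randomForest package we use later does considerably more under the hood.

library(rpart)

set.seed(42)
n_trees <- 25
trees <- lapply(seq_len(n_trees), function(i) {
  boot_idx <- sample(nrow(iris), replace = TRUE)                 # bootstrap sample
  rpart(Species ~ ., data = iris[boot_idx, ], method = "class")
})

# Average the predicted class probabilities across trees, then pick the most
# probable class for each observation.
probs <- Reduce(`+`, lapply(trees, predict, newdata = iris)) / n_trees
bagged_pred <- colnames(probs)[max.col(probs)]
table(bagged_pred, iris$Species)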

Building a Diverse Random Forest

Random forests are essentially a bagging model with three additional tuning parameters:

  1. Number of features selected at each split of the tree (m). By randomly choosing m different features to evaluate for each split, we are making the trees even more diverse and uncorrelated.
  2. Depth of the tree. We can either choose the minimum number of observations in the leaf nodes or set the depth of the tree.
  3. Number of trees in the forest. In general, the more trees you add, the better the model will be. However, at some point, adding more trees will only be minimally better and we must consider the computational costs and time it takes to run a random forest with too many trees.

To make a prediction, the random forest averages the predictions from each decision tree for regression tasks. For classification, each tree can either output a class prediction, combined by majority vote, or output class probabilities that are averaged to determine the predicted class (usually preferred).

Now let's try to build a random forest in R!

Random Forest Tutorial

Now that we know a bit about what a random forest is and how it works, let's put it into action with code. There are a ton of different packages available in R for creating random forest models, such as randomForest, xgboost, randomForestSRC, and Rborist, just to name a few. randomForest is one of the most popular and has the gentlest learning curve, so we'll use that package in this exercise.

install.packages("randomForest")
library(randomForest)

We’ll use the same iris dataset as we did with the decision tree, so let’s load that in as well.

data(iris)
attach(iris)

Since we already have done some basic exploratory data analysis on the dataset, let’s jump straight to growing our random forest.

The randomForest call has a ton of different arguments that can be found in its manual by typing "?randomForest". We'll highlight just three additional ones that we did not cover in the tree section: ntree, mtry, and importance.

ntree specifies the number of trees to grow in our random forest. The more trees, the more different permutations of variables and data in our model. There is a balance between computational capacity, model performance, and overfitting that must be considered when deciding a value for ntree.

mtry is the number of variables randomly sampled as candidates at each split. In our data set, we only have 4 predictors. Typically, the default value for mtry is sqrt(p) where p is the number of predictors for a classification problem. If we were using regression, we would use p/3.

importance is a logical (TRUE/FALSE) argument that lets us specify whether the importance of the predictors should be assessed. This enables some really cool tools that we'll explore later on.

rf <- randomForest(data=iris, Species~., ntree=100, mtry=2, importance=TRUE)

We could also specify a test set in the function call with xtest, which would have randomForest make predictions on it as it trains. We could then access those predictions via rf$test$predicted. Since we are relying on the out-of-bag estimates here, we won't use that for now.
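As an aside, the usual way to score new data with a fitted forest is predict(); for example (using a few training rows purely for illustration):

predict(rf, newdata = head(iris))                  # predicted classes
predict(rf, newdata = head(iris), type = "prob")   # averaged class probabilities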

Let's take a look at our model.

rf
## Call:
##  randomForest(formula = Species ~ ., data = iris, ntree = 100, mtry = 2, importance = TRUE)
##                Type of random forest: classification
##                      Number of trees: 100
## No. of variables tried at each split: 2
##
##         OOB estimate of error rate: 5.33%
## Confusion matrix:
##            setosa versicolor virginica class.error
## setosa         50          0         0        0.00
## versicolor      0         47         3        0.06
## virginica       0          5        45        0.10

We see the confusion matrix above that shows us how many observations in each class we correctly identified. Our model is great with setosa, having no class error there. It performs about evenly with versicolor and virginica. Overall, we have an excellent model! How did the randomForest split these classes though? We can find that out by taking a look at the variable importance.

varImpPlot(rf)

We can see how the randomForest ranks the different predictors in discerning power using both Accuracy and Gini. Separating on Petal Width leads to the highest mean decrease in accuracy across all trees, whereas separating on Petal Length leads to the highest mean decrease in Gini across all trees. Put simply, those are the best variables to separate the classes on (virginica, setosa, versicolor).
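If you want the numbers behind the plot, the importance() function from the same package returns the measures as a matrix:

importance(rf)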

Housing Data Exercise

Now it’s time for you to try an exercise applying a Random Forest model to a practical application with housing data.

Let’s first load in the rsample package for easier data splitting, randomForest for the model, and AmesHousing for our dataset.

library(rsample)
library(randomForest)
library(AmesHousing)

We'll create a training set (75%) and a testing set (25%) from the AmesHousing::make_ames() data and use set.seed for reproducibility.

set.seed(007)
ames_split <- initial_split(AmesHousing::make_ames(), prop = .75)
ames_train <- training(ames_split)
ames_test <- testing(ames_split)

Now try creating a random forest model using randomForest with a maximum of 100 nodes per tree. Plot the model to illustrate the error rate as we average across more trees; you should see the error fall and then plateau as trees are added.

What number of trees produces the lowest MSE? What’s the RMSE of the optimal random forest? See if you can get this answer:

## [1] 245
## [1] 34454.48
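If you get stuck, here is one possible sketch of the first part. It assumes Sale_Price is the response column produced by make_ames(); your exact numbers may differ if your environment or seed handling differs.

ames_rf <- randomForest(Sale_Price ~ ., data = ames_train, maxnodes = 100)
plot(ames_rf)                                 # OOB error versus number of trees

which.min(ames_rf$mse)                        # forest size with the lowest OOB MSE
sqrt(ames_rf$mse[which.min(ames_rf$mse)])     # RMSE of that forest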

What happens to the model when you set the minimum number of nodes to 100 instead of the maximum? Plot the new model and compare its error curve with the first one.

Would you predict the new RMSE is lower or higher? Why? The new RMSE is:

## [1] 24747.39

This all only scratches the surface of random forests — there is so much to learn and uncover. Hopefully, this gets you started on a path to using some of the coolest and most versatile tools available to data scientists!
