Implement Decision Trees in Python with Scikit-learn

Hannah Davis

Published 5 years ago
Updated 3 years ago

We’ll learn about decision trees, which scikit-learn builds with an optimized version of the CART (classification and regression trees) algorithm, and use them to explore a dataset of breast cancer tumors.

We'll also see how to visualize a decision tree using graphviz.

Instructor: [00:00] From sklearn, we'll import datasets and our metrics. We'll import the train_test_split function. From sklearn, we'll also import tree for our decision tree. We'll also import graphviz. We'll be working with the breast cancer dataset, which we load into a breast_cancer variable with datasets.load_breast_cancer().
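The imports and dataset loading described above might look like the following sketch. It assumes a recent scikit-learn, where train_test_split lives in sklearn.model_selection, and it treats the graphviz import as optional so that the modeling steps still run when the package is missing.

```python
from sklearn import datasets, metrics, tree
from sklearn.model_selection import train_test_split

# graphviz is a separate package (pip install graphviz); it is only needed for
# the visualization step, so a missing install is reported instead of raised.
try:
    import graphviz
except ImportError:
    graphviz = None
    print("graphviz is not installed; the visualization step will be unavailable")

# load_breast_cancer() returns a Bunch: a dict-like object with attribute access
breast_cancer = datasets.load_breast_cancer()
print(breast_cancer.data.shape)  # (569, 30)
```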

[00:36] Let's print some information about this dataset. We can print the keys, and we'll print the feature names and the target names. We can see we have 30 features describing each tumor. Our two target classes are malignant and benign.
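A minimal sketch of inspecting the dataset as described. The comments list only a few of the keys and feature names; the full lists are longer.

```python
from sklearn import datasets

breast_cancer = datasets.load_breast_cancer()

print(breast_cancer.keys())          # includes data, target, feature_names, target_names
print(breast_cancer.feature_names)   # 30 measurements, e.g. 'mean radius', 'worst perimeter'
print(breast_cancer.target_names)    # ['malignant' 'benign']
```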

[01:10] We'll assign X to be breast_cancer.data and y to be breast_cancer.target. Then we'll make our training and test data with X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.15, random_state=33). The test size is the fraction of the data that goes into the test dataset, here 15 percent, and the random state of 33 makes the shuffled split repeatable.
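A sketch of this step, assuming the dataset is held in a variable named breast_cancer, which is the name the lesson uses later when passing feature names to the visualization.

```python
from sklearn import datasets
from sklearn.model_selection import train_test_split

breast_cancer = datasets.load_breast_cancer()
X = breast_cancer.data    # one row of 30 measurements per tumor
y = breast_cancer.target  # 0 = malignant, 1 = benign

# Hold out 15% of the rows for testing; random_state=33 makes the shuffle repeatable
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.15, random_state=33
)
```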

[01:46] We can print X_train.shape, y_train.shape, X_test.shape, and y_test.shape. We can see that our training data is 483 data points with 30 features each, and our test data is 86 data points with 30 features each. We have 483 training labels and 86 test labels.
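The 483/86 split follows from the 569 rows in the dataset: when test_size is a fraction, train_test_split rounds the test count up and gives the rest to training. This sketch prints the shapes and reproduces that arithmetic.

```python
import math

from sklearn import datasets
from sklearn.model_selection import train_test_split

breast_cancer = datasets.load_breast_cancer()
X, y = breast_cancer.data, breast_cancer.target
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.15, random_state=33
)

print(X_train.shape, y_train.shape)  # (483, 30) (483,)
print(X_test.shape, y_test.shape)    # (86, 30) (86,)

# A fractional test_size is rounded up: ceil(0.15 * 569) = ceil(85.35) = 86
n_test = math.ceil(0.15 * len(X))
print(n_test, len(X) - n_test)  # 86 483
```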

[02:23] From here, we can say model = tree.DecisionTreeClassifier(). Then we call model.fit and pass in our training data, X_train and y_train. We can make some predictions with model.predict, passing in our X_test data. If we print those predictions, we see a whole bunch of labels, each either 0 (malignant) or 1 (benign).
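A sketch of fitting and predicting. The random_state=0 passed to the classifier is an assumption added here so the tree is reproducible; the lesson uses the default constructor, where ties between equally good splits can be broken differently from run to run.

```python
from sklearn import datasets, tree
from sklearn.model_selection import train_test_split

breast_cancer = datasets.load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
    breast_cancer.data, breast_cancer.target, test_size=0.15, random_state=33
)

# random_state=0 is an assumption for repeatable results, not part of the lesson
model = tree.DecisionTreeClassifier(random_state=0)
model.fit(X_train, y_train)

predictions = model.predict(X_test)
print(predictions)  # an array of 0s (malignant) and 1s (benign)

# Indexing target_names with the predictions maps 0/1 back to class names
print(breast_cancer.target_names[predictions][:5])
```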

[02:55] As usual, we can print model.score with our X_test and y_test variables, which gives the mean accuracy on the test set. We can print our classification report with the true labels and our predictions, and our confusion matrix with the true labels and our predictions. We can see the model got most right but misclassified a couple in each class.
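A sketch of the evaluation step. Passing target_names to classification_report is an addition not mentioned in the lesson; it only swaps the 0/1 row labels for the class names. No scores are shown because they depend on the fitted tree.

```python
from sklearn import datasets, metrics, tree
from sklearn.model_selection import train_test_split

breast_cancer = datasets.load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
    breast_cancer.data, breast_cancer.target, test_size=0.15, random_state=33
)
# random_state=0 is an assumption for repeatable results
model = tree.DecisionTreeClassifier(random_state=0).fit(X_train, y_train)
predictions = model.predict(X_test)

# For classifiers, score() is the mean accuracy on the given data
print(model.score(X_test, y_test))

# Per-class precision, recall, and F1 (true labels first, then predictions)
print(metrics.classification_report(
    y_test, predictions, target_names=list(breast_cancer.target_names)
))

# Rows are true classes and columns are predicted classes: [malignant, benign]
cm = metrics.confusion_matrix(y_test, predictions)
print(cm)
```

The off-diagonal entries of the confusion matrix are the misclassifications in each class.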

[03:33] But where decision trees really shine is in helping us figure out which variables are most important. To see this, we're going to visualize the decision tree. We'll do this with graph_data = tree.export_graphviz, passing in our model.
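Besides reading the picture, a fitted tree also exposes a numeric ranking of variables through its feature_importances_ attribute. The lesson does not use it; this sketch is an optional complement. Each value is the feature's total sample-weighted impurity decrease across the splits that use it, normalized so all values sum to 1.

```python
import numpy as np
from sklearn import datasets, tree
from sklearn.model_selection import train_test_split

breast_cancer = datasets.load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
    breast_cancer.data, breast_cancer.target, test_size=0.15, random_state=33
)
# random_state=0 is an assumption for repeatable results
model = tree.DecisionTreeClassifier(random_state=0).fit(X_train, y_train)

# Impurity-based importances, one per feature, normalized to sum to 1
importances = model.feature_importances_

# Print the five most important features, highest first
for i in np.argsort(importances)[::-1][:5]:
    print(f"{breast_cancer.feature_names[i]}: {importances[i]:.3f}")
```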

[03:55] We set out_file=None, so the DOT description of the tree is returned as a string instead of written to a file, and feature_names=breast_cancer.feature_names. Then we say graph = graphviz.Source(graph_data) and graph.render("breast_cancer", view=True). This immediately renders our decision tree to a PDF and opens it.
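A sketch of the export and render steps. The rendering is wrapped in a hypothetical helper, render_tree, so the export runs even where graphviz (the Python package plus the dot program) is not installed; the helper name and its defaults are illustrative, not part of the lesson.

```python
from sklearn import datasets, tree
from sklearn.model_selection import train_test_split

breast_cancer = datasets.load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
    breast_cancer.data, breast_cancer.target, test_size=0.15, random_state=33
)
# random_state=0 is an assumption for repeatable results
model = tree.DecisionTreeClassifier(random_state=0).fit(X_train, y_train)

# out_file=None makes export_graphviz return the DOT source as a string
graph_data = tree.export_graphviz(
    model, out_file=None, feature_names=list(breast_cancer.feature_names)
)


def render_tree(dot_source, filename="breast_cancer", view=True):
    """Hypothetical helper: write <filename> (DOT) and <filename>.pdf, then open it.

    Needs both the graphviz Python package and the Graphviz binaries (dot).
    """
    import graphviz

    graph = graphviz.Source(dot_source)       # output format defaults to PDF
    return graph.render(filename, view=view)  # returns the rendered file's path


print(graph_data.splitlines()[0])  # the first line of the DOT source
```

Calling render_tree(graph_data) then corresponds to graph.render("breast_cancer", view=True) in the lesson.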

[04:26] A decision tree works by evaluating possible splits on each of our variables according to some metric. One popular metric is information gain. Another is Gini impurity, the gini value shown in each box here, which is scikit-learn's default. At each step the tree picks the split that best separates the classes, so the features most helpful for predicting these classes tend to appear toward the root of the decision tree.
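Both metrics can be computed by hand from class counts. In this sketch, gini, entropy, and impurity_decrease are hypothetical helper names: Gini impurity is 1 minus the sum of squared class proportions, and information gain is the parent's entropy minus the sample-weighted entropy of its children. The last lines check the hand-computed root Gini against the value scikit-learn stores for the fitted tree.

```python
import numpy as np
from sklearn import datasets, tree
from sklearn.model_selection import train_test_split


def gini(counts):
    """Gini impurity 1 - sum_k p_k**2, where p_k are the class proportions."""
    p = np.asarray(counts, dtype=float) / np.sum(counts)
    return float(1.0 - np.sum(p ** 2))


def entropy(counts):
    """Shannon entropy in bits; classes with zero samples contribute nothing."""
    p = np.asarray(counts, dtype=float) / np.sum(counts)
    p = p[p > 0]
    return float(np.sum(p * np.log2(1.0 / p)))


def impurity_decrease(parent, left, right, impurity=gini):
    """Parent impurity minus the sample-weighted impurity of the two children.

    With impurity=entropy this is the information gain of the split.
    """
    n = np.sum(parent)
    children = (np.sum(left) * impurity(left) + np.sum(right) * impurity(right)) / n
    return impurity(parent) - children


print(gini([0, 12]))   # 0.0, a pure node
print(gini([50, 50]))  # 0.5, the maximum for two classes
print(impurity_decrease([50, 50], [50, 0], [0, 50], impurity=entropy))  # 1.0

# Cross-check: scikit-learn's root impurity equals gini() of the training class counts
breast_cancer = datasets.load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
    breast_cancer.data, breast_cancer.target, test_size=0.15, random_state=33
)
model = tree.DecisionTreeClassifier(random_state=0).fit(X_train, y_train)
print(np.isclose(model.tree_.impurity[0], gini(np.bincount(y_train))))  # True
```

Passing criterion="entropy" to DecisionTreeClassifier switches the tree from Gini impurity to the entropy-based criterion.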

[04:49] We can include one more argument here, filled=True. The colors show the majority class for each box, and a stronger color means the box is more dominated by that class.
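A sketch of the export with filled=True. The class_names argument is an addition not used in the lesson; it labels each box with its majority class name (malignant or benign) so the colors are easier to read.

```python
from sklearn import datasets, tree
from sklearn.model_selection import train_test_split

breast_cancer = datasets.load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
    breast_cancer.data, breast_cancer.target, test_size=0.15, random_state=33
)
# random_state=0 is an assumption for repeatable results
model = tree.DecisionTreeClassifier(random_state=0).fit(X_train, y_train)

graph_data = tree.export_graphviz(
    model,
    out_file=None,
    feature_names=list(breast_cancer.feature_names),
    class_names=list(breast_cancer.target_names),  # assumption: not in the lesson
    filled=True,  # fill each box with the color of its majority class
)
```

The resulting string can be passed to graphviz.Source and rendered exactly as before.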

[05:00] We can see that the box at the root of the decision tree checks whether worst perimeter is less than or equal to 105.95. It covers all 483 samples from our training data, and value = [184, 299] shows that 184 of those samples are malignant and 299 are benign. Samples for which the condition is True go down the left branch, and the rest go down the right.
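The same root-node information can be read from the fitted model's tree_ attribute instead of the picture. This sketch prints the root's test, its sample count, the class breakdown shown as value, and how many samples go down each branch. Because of the assumed random_state=0, the exact feature and threshold may not match the lesson's tree.

```python
import numpy as np
from sklearn import datasets, tree
from sklearn.model_selection import train_test_split

breast_cancer = datasets.load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
    breast_cancer.data, breast_cancer.target, test_size=0.15, random_state=33
)
# random_state=0 is an assumption for repeatable results
model = tree.DecisionTreeClassifier(random_state=0).fit(X_train, y_train)

t = model.tree_
root = 0
left, right = t.children_left[root], t.children_right[root]

print(f"{breast_cancer.feature_names[t.feature[root]]} <= {t.threshold[root]:.2f}")
print("samples at root:", t.n_node_samples[root])  # 483, the whole training set
# 'value' in the graph is the class breakdown: [malignant, benign]
print("class counts at root:", np.bincount(y_train))
# In the exported graph, the left edge is labeled True and the right edge False
print("True (left) branch:", t.n_node_samples[left])
print("False (right) branch:", t.n_node_samples[right])
```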

[05:24] When the gini value reaches 0, all samples in that box belong to a single class, as in the case of this box here, where 0 of the samples are malignant and 12 are benign.
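Pure boxes like that one can also be found programmatically. This sketch lists leaves whose Gini impurity is 0 and recounts their training samples per class with model.apply, which returns the leaf each row lands in. Recounting from y_train avoids depending on how tree_.value is stored, which has differed between scikit-learn versions.

```python
import numpy as np
from sklearn import datasets, tree
from sklearn.model_selection import train_test_split

breast_cancer = datasets.load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
    breast_cancer.data, breast_cancer.target, test_size=0.15, random_state=33
)
# random_state=0 is an assumption for repeatable results
model = tree.DecisionTreeClassifier(random_state=0).fit(X_train, y_train)

t = model.tree_
leaf_of_sample = model.apply(X_train)  # index of the leaf each training row lands in
is_leaf = t.children_left == -1        # leaves have no children
pure = np.where(is_leaf & np.isclose(t.impurity, 0.0))[0]

print(f"{len(pure)} of {is_leaf.sum()} leaves have gini = 0")
for node in pure[:5]:
    counts = np.bincount(y_train[leaf_of_sample == node], minlength=2)
    print(f"node {node}: {counts[0]} malignant, {counts[1]} benign")
```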

David Arias
~ 4 years ago

I'm not able to import 'graphviz' to my jupyter notebook. I'm getting an error message that the module "graphviz" doesn't exist.

I've tried the following:

import graphviz
from graphviz import Digraph
from graphviz import *
import pygraphviz

All of them have failed.

I checked the graphviz docs and it is supposed to be compatible with jupyter notebook, and they use the same import command you do. Any ideas?

Shawn Wang
~ 3 years ago

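Did you install graphviz at all? The "module doesn't exist" error means the Python package is missing: pip install graphviz (or %pip install graphviz inside the notebook, so it goes into the kernel's environment). To actually render the tree you also need the Graphviz binaries, e.g. brew install graphviz on macOS.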
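The module error and the brew command concern two different pieces: the graphviz Python package, which import graphviz needs, and the Graphviz programs such as dot, which rendering needs. A small standard-library check, run in the same notebook kernel, shows which one is missing. The helper name graphviz_status is hypothetical.

```python
import importlib.util
import shutil


def graphviz_status():
    """Hypothetical helper: return (python_package_found, dot_executable_found)."""
    # The Python bindings, installed with: pip install graphviz
    has_package = importlib.util.find_spec("graphviz") is not None
    # The Graphviz programs that draw the graph, e.g. installed with brew install graphviz
    has_dot = shutil.which("dot") is not None
    return has_package, has_dot


has_package, has_dot = graphviz_status()
print("graphviz Python package importable:", has_package)
print("Graphviz 'dot' executable on PATH:", has_dot)
```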