    Implement a Naive Bayes Classifier in Python and Scikit-learn to Categorize Text

    Hannah Davis

    We’ll use this probabilistic classifier to classify text into different news groups.

    There are several types of Naive Bayes classifiers in scikit-learn. We will be using the Multinomial Naive Bayes model, which is appropriate for text classification. More can be found in the scikit-learn documentation.

    We'll also look at how to visualize the confusion matrix using pandas_ml.

    To install pandas_ml, type:

    $ pip install pandas_ml

    into your terminal, or install it with your installer of choice.




    Instructor: 00:00 From sklearn, we'll import our datasets. We'll import metrics. From sklearn.feature_extraction.text, we'll import the TfidfVectorizer, which will help make our text understandable to the model. Then from sklearn.naive_bayes, we'll import the multinomial Naive Bayes.

    00:36 We'll also import matplotlib.pyplot as plt. If you'd like to visualize the confusion matrix at the end of this, from pandas_ml import ConfusionMatrix.

    00:54 We're going to be working with the newsgroups dataset. We access this a little differently. newsgroups_train will be datasets.fetch_20newsgroups, and we'll pass in the argument subset='train', and newsgroups_test = datasets.fetch_20newsgroups(subset='test'). The data has already been split into training and test datasets for us. Let's explore this dataset a little bit.

    01:37 We see the common keys. Let's now print a couple of items of the data and a couple of target labels. Let's print our category names, or target names. We can see that each data point is a bunch of text. We have three category labels in our target. The target names include these 20 categories of news text. We have baseball, for sale, motorcycles, religious talk, political talk, etc.
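    Since the lesson's code is not shown here, the following is only a rough sketch of the loading and exploring steps described above. The helper names load_newsgroups and describe are hypothetical, chosen for illustration, and the first call downloads the dataset, so nothing is fetched until you call the function yourself.

```python
from sklearn import datasets


# load_newsgroups and describe are hypothetical helper names used for
# illustration; they are not part of scikit-learn or of the lesson's code.
def load_newsgroups():
    """Fetch the pre-split 20 newsgroups data (downloads on first call)."""
    train = datasets.fetch_20newsgroups(subset="train")
    test = datasets.fetch_20newsgroups(subset="test")
    return train, test


def describe(bunch, n_items=2):
    """Print the keys, a few documents with their labels, and the category names."""
    print(list(bunch.keys()))
    for text, label in zip(bunch.data[:n_items], bunch.target[:n_items]):
        # target holds integer indices into target_names
        print(bunch.target_names[label], "->", text[:80].replace("\n", " "))
    print(bunch.target_names)
```

    Calling describe on the training split returned by load_newsgroups should print the keys, a couple of truncated posts with their category, and the 20 category names.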

    02:18 The next thing we need to do is vectorize our text, and this means turn it from words into a model-understandable vector of features represented by numbers. To do this, we can say vectorizer = TfidfVectorizer. Tfidf stands for term frequency-inverse document frequency. It's a metric commonly used for analyzing text. We'll be using this as the lens through which to examine our text.
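    To make the term a little more concrete, here is a small worked sketch of the textbook TF-IDF formula in plain Python. The three-document corpus is invented for illustration. Note that scikit-learn's TfidfVectorizer uses a smoothed IDF and L2-normalizes each row by default, so its numbers will differ from these.

```python
import math
from collections import Counter

# Toy corpus, invented for illustration only.
docs = [
    "the pitcher threw the ball",
    "the engine of the motorcycle",
    "the pitcher hit a home run",
]

tokenized = [doc.split() for doc in docs]
n_docs = len(tokenized)

# Document frequency: in how many documents does each term occur?
doc_freq = Counter(term for tokens in tokenized for term in set(tokens))


def tfidf(tokens):
    """Textbook TF-IDF: (term count / doc length) * log(n_docs / doc_freq)."""
    counts = Counter(tokens)
    return {
        term: (count / len(tokens)) * math.log(n_docs / doc_freq[term])
        for term, count in counts.items()
    }


weights = [tfidf(tokens) for tokens in tokenized]

# "the" occurs in every document, so its IDF is log(3 / 3) = 0.
print(weights[0]["the"])
# "ball" occurs in only one document, so it gets a positive weight.
print(weights[0]["ball"])
```

    The point of the weighting is visible even at this scale: words that appear everywhere contribute nothing, while words that are specific to a document stand out.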

    02:49 We'll say X_train = vectorizer.fit_transform(...). X_test will be vectorizer.transform(...). Our y_train will be our ... Our y_test will be ... From there, we can say model = MultinomialNB.

    03:31 We can fit the model with our X training data and our y training data. Then our predictions will be model.predict with our X_test data. We can print our model.score and metrics.classification_report with our accurate labels and our predictions. For 20 classes, that's not bad.
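    The vectorize, fit, predict, and score steps can be sketched end to end on a tiny in-memory corpus. This is a minimal illustration, not the lesson's code: the texts and labels below are invented stand-ins for the newsgroups data, so the scores say nothing about the real dataset.

```python
from sklearn import metrics
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB

# Toy stand-in for the newsgroups data, invented for illustration.
train_texts = [
    "the pitcher threw a fastball in the ninth inning",
    "home run wins the baseball game",
    "new motorcycle engine and exhaust for sale",
    "riding my bike with a new helmet and tires",
]
train_labels = [0, 0, 1, 1]  # 0 = baseball, 1 = motorcycles
test_texts = ["the baseball pitcher", "motorcycle helmet"]
test_labels = [0, 1]

vectorizer = TfidfVectorizer()
# Learn the vocabulary and IDF weights from the training text only...
X_train = vectorizer.fit_transform(train_texts)
# ...then reuse them for the test text so both share the same columns.
X_test = vectorizer.transform(test_texts)

model = MultinomialNB()
model.fit(X_train, train_labels)
predictions = model.predict(X_test)

print(model.score(X_test, test_labels))  # mean accuracy on the test set
print(metrics.classification_report(test_labels, predictions))
```

    Using fit_transform on the training text and only transform on the test text keeps the test set from influencing the vocabulary or the IDF weights.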

    04:09 One last thing we can do is visualize our confusion matrix. For this many classes, let's make the visualization a little bit better. We'll use pandas_ml to do that.

    04:19 First, we're going to make a variable called labels, which is a list of our newsgroups_train.target_names, so those are the category names, and then our confusion matrix, cm = ConfusionMatrix. We pass in our accurate labels, our predictions, and the label names. We say cm.plot and, finally, show the plot.

    04:55 This confusion matrix helps us a lot. We can see that in general the model does a decent job predicting most of these categories. It also makes it very easy to spot areas where there was confusion.

    05:07 We can see that a lot of posts from other categories were predicted as religion.christian. Looking at the categories, it kind of makes sense. The atheism category probably has a big overlap, and medicine a tiny bit. General, miscellaneous religious talk will also have a lot of overlap. This is a great tool to see what you can focus on and optimize next.
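    The same kind of reading can be done numerically. The sketch below swaps pandas_ml for scikit-learn's own metrics.confusion_matrix, and the labels and predictions are hypothetical, invented to mimic the pattern described above rather than taken from the lesson's results.

```python
from sklearn import metrics

# Hypothetical labels and predictions, invented for illustration.
labels = ["alt.atheism", "sci.med", "soc.religion.christian"]
y_true = ["alt.atheism", "alt.atheism", "sci.med", "sci.med",
          "soc.religion.christian", "soc.religion.christian"]
y_pred = ["soc.religion.christian", "alt.atheism", "sci.med",
          "soc.religion.christian", "soc.religion.christian",
          "soc.religion.christian"]

# Rows are the true categories, columns are the predicted categories.
cm = metrics.confusion_matrix(y_true, y_pred, labels=labels)

for name, row in zip(labels, cm):
    print(f"{name:>24}", row)

# Off-diagonal entries in the last column count posts from other
# categories that were predicted as soc.religion.christian.
misfiled = cm[:, 2].sum() - cm[2, 2]
print(misfiled)
```

    A column with large off-diagonal counts is a category the model over-predicts, which is a reasonable place to start when deciding what to optimize next.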