7 Data Visualization
There are many libraries that will allow you to produce impressive data visualizations in Python. Two of the most popular for Data Science are matplotlib and seaborn, which are discussed in this chapter.
The ggplot2 equivalent in Python is the plotnine library, which works in the same way, but lacks the extensive range of extension packages.
Import the packages as follows (pandas is needed to read the data):
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
As an aside, we use the alias sns for seaborn because it was named after the West Wing character Samuel Norman Seaborn. seaborn is built on matplotlib, so you’ll need to have imported it anyway.
We’ll use a simple data set that we already saw when learning the ggplot2 package in R: mtcars.
mtcars = pd.read_csv('../00-data/mtcars.csv')
mtcars.info()
#> <class 'pandas.core.frame.DataFrame'>
#> RangeIndex: 32 entries, 0 to 31
#> Data columns (total 12 columns):
#> # Column Non-Null Count Dtype
#> --- ------ -------------- -----
#> 0 model 32 non-null object
#> 1 mpg 32 non-null float64
#> 2 cyl 32 non-null int64
#> 3 disp 32 non-null float64
#> 4 hp 32 non-null int64
#> 5 drat 32 non-null float64
#> 6 wt 32 non-null float64
#> 7 qsec 32 non-null float64
#> 8 vs 32 non-null int64
#> 9 am 32 non-null int64
#> 10 gear 32 non-null int64
#> 11 carb 32 non-null int64
#> dtypes: float64(5), int64(6), object(1)
#> memory usage: 3.1+ KB
#> None
7.1 The matplotlib Library
Although matplotlib is very popular, it is a low-level package, in that you have to define much of each plot manually. We’ll take a brief look at it, since you’re likely to encounter it often.
7.1.1 Scatter plots
A basic scatter plot follows the form:
plt.scatter(mtcars["wt"], mtcars["mpg"], alpha=0.65)
plt.title('A basic scatter plot')
plt.xlabel('weight')
plt.ylabel('miles per gallon')
plt.show()
We can also assign the output of plt.subplots to:
- fig, a container that holds everything you see on the page, and
- ax, the part of the page that holds the data, i.e. the canvas.
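As a minimal sketch of how these two objects divide the work, here is the single-panel case. The small `cars` data frame is a handful of mtcars rows typed inline as a stand-in for the CSV file, and the `Agg` backend line is only an assumption so that the snippet runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line in a notebook
import matplotlib.pyplot as plt
import pandas as pd

# A few mtcars rows typed inline as a stand-in for the CSV file
cars = pd.DataFrame({
    "wt":  [2.620, 2.320, 3.215, 3.440, 3.460, 3.570],
    "mpg": [21.0, 22.8, 21.4, 18.7, 18.1, 14.3],
})

# Called without arguments, plt.subplots returns one Figure and one Axes
fig, ax = plt.subplots()

# Plotting and axis labels are methods of the Axes ...
ax.scatter(cars["wt"], cars["mpg"], alpha=0.65)
ax.set_xlabel("weight")
ax.set_ylabel("miles per gallon")

# ... while figure-wide settings belong to the Figure
fig.suptitle("A basic scatter plot")

# The Figure keeps track of the Axes it contains
print(fig.axes == [ax])
```

In an interactive session, calling plt.show() afterwards displays the figure.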
# Create a Figure and an Axes with plt.subplots
fig, ax = plt.subplots(1, 3, sharex=True, sharey=True)
print(ax.shape)
#> (3,)
# Make plot
for key, value in enumerate(mtcars["cyl"].unique()):
    ax[key].scatter(mtcars["wt"][mtcars["cyl"] == value], mtcars["mpg"][mtcars["cyl"] == value], alpha=0.65)

# Call the show function
plt.show()
That’s not pretty, but it’s not terrible either. Here’s an example of a simple dodged bar plot from the official matplotlib gallery page:
# Bar chart demo with pairs of bars grouped for easy comparison.
import numpy as np
import matplotlib.pyplot as plt

n_groups = 5

means_men = (20, 35, 30, 35, 27)
std_men = (2, 3, 4, 1, 2)

means_women = (25, 32, 34, 20, 25)
std_women = (3, 5, 2, 3, 3)

fig, ax = plt.subplots()

index = np.arange(n_groups)
bar_width = 0.35

opacity = 0.4
error_config = {'ecolor': '0.3'}

rects1 = plt.bar(index, means_men, bar_width,
                 alpha=opacity,
                 color='b',
                 yerr=std_men,
                 error_kw=error_config,
                 label='Men')

rects2 = plt.bar(index + bar_width, means_women, bar_width,
                 alpha=opacity,
                 color='r',
                 yerr=std_women,
                 error_kw=error_config,
                 label='Women')

plt.xlabel('Group')
plt.ylabel('Scores')
plt.title('Scores by group and gender')
plt.xticks(index + bar_width / 2, ('A', 'B', 'C', 'D', 'E'))
plt.legend()

plt.tight_layout()
plt.show()
There are some great examples to be found in the gallery, but to be fair, for the most part, you just want to get your data plotted and get on with your analysis. This all seems like way too much work! Let’s see if seaborn can offer a solution.
7.2 The seaborn Library
seaborn is built on top of matplotlib and takes care of a lot of the leg work for you!
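Because seaborn sits on top of matplotlib, its axes-level functions hand back a regular matplotlib Axes that can be adjusted with the usual methods. The following is a small sketch of that idea; the `cars` data frame is a few mtcars rows typed inline as a stand-in for the CSV file, and the `Agg` backend line is only there so the snippet runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line in a notebook
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
import pandas as pd
import seaborn as sns

# A few mtcars rows typed inline as a stand-in for the CSV file
cars = pd.DataFrame({
    "wt":  [2.620, 2.320, 3.215, 3.440, 3.460, 3.570],
    "mpg": [21.0, 22.8, 21.4, 18.7, 18.1, 14.3],
})

# seaborn draws the plot and returns the Axes it drew on
ax = sns.scatterplot(x="wt", y="mpg", data=cars)

# The return value is an ordinary matplotlib Axes ...
print(isinstance(ax, Axes))

# ... so matplotlib methods can still be used for fine-tuning
ax.set_title("seaborn plot, matplotlib title")
ax.set_xlabel("weight")
ax.set_ylabel("miles per gallon")
```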
7.2.1 Scatter plots
Let’s revisit our scatter plot:
sns.scatterplot(x="wt", y="mpg", data=mtcars)
# Alternatively, using pandas Series (recent seaborn versions require keyword arguments):
# sns.scatterplot(x=mtcars["wt"], y=mtcars["mpg"])
# prior to v0.9.0
# sns.regplot(x=mtcars["wt"], y=mtcars["mpg"])
We can also add another variable encoded as color. Notice that seaborn also adds the legend for us!
sns.scatterplot(x="wt", y="mpg", hue="cyl", data=mtcars)
But this looks pretty bad. What’s the issue here? Well, the colors are pretty terrible.
mtcars = mtcars.astype({"cyl": "category"})
# sns.scatterplot(x="wt", y="mpg", hue="cyl", data=mtcars)
There are two issues with the colors. First, we’d like to change the colors and make them nicer. Second, we need to use a discrete scale, since this plot gives the impression that there is such a thing as a 5-cylinder car. Ideally, we’d like to plot a categorical variable, which means we’d change the type, as such:
mtcars['cyl'] = mtcars['cyl'].astype(object)
sns.scatterplot(x='wt', y='mpg', hue='cyl', data=mtcars)
This would work if cyl were a character column, but when plotting, seaborn checks whether the type can be coerced to numeric. If it can, seaborn plots it as a numeric variable, which leads to a mismatch with our colors, so it doesn’t work. Annoying! The workaround is to assign the color palette as a list.
mtcars = pd.read_csv('data/mtcars.csv')
sns.scatterplot(x='wt', y='mpg', hue='cyl', palette=["r", "b", "g"], data=mtcars)
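A related option, sketched here as an alternative rather than as the chapter’s method, is to pass palette as a dictionary that maps each level of the hue variable to a color. This makes the pairing between cylinder count and color explicit instead of relying on list order. The `cars` data frame and the `cyl_colors` name are illustrative stand-ins, and the `Agg` backend line is only there so the snippet runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line in a notebook
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# A few mtcars rows typed inline as a stand-in for the CSV file
cars = pd.DataFrame({
    "wt":  [2.620, 2.320, 3.215, 3.440, 3.460, 3.570],
    "mpg": [21.0, 22.8, 21.4, 18.7, 18.1, 14.3],
    "cyl": [6, 4, 6, 8, 6, 8],
})

# Hypothetical mapping from each cyl level to a color
cyl_colors = {4: "#fc9272", 6: "#ef3b2c", 8: "#a50f15"}

# A dict palette makes seaborn treat hue as discrete levels
ax = sns.scatterplot(x="wt", y="mpg", hue="cyl", palette=cyl_colors, data=cars)

# Collect the legend entries to check which levels were drawn
legend_labels = [text.get_text() for text in ax.get_legend().get_texts()]
print(legend_labels)
```

With a dictionary, a level that is missing from the mapping raises an error instead of silently receiving an unexpected color.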
We can also use hex color codes, which make it easier to keep colors consistent.
sns.scatterplot(x='wt', y='mpg', hue='cyl', palette=['#fc9272', '#ef3b2c', '#a50f15'], data=mtcars)
There is a great tutorial on built-in color palettes.
We can add an additional variable using the size:
sns.scatterplot(x='wt', y='mpg', hue='cyl', size='disp', palette=['#fc9272', '#ef3b2c', '#a50f15'], data=mtcars)
# relplot is a figure-level function: it creates its own figure, so no plt.subplots() call is needed
g = sns.relplot(x='wt', y='mpg', hue='cyl', size='disp', col='gear', palette=['#fc9272', '#ef3b2c', '#a50f15'], data=mtcars)
plt.show()
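The col argument splits the plot into one panel per level of a variable. As a rough sketch of what relplot hands back in that case, the example below inspects the returned object: it is a seaborn FacetGrid holding an array of matplotlib Axes, not a single Axes. The `cars` data frame is a few mtcars rows typed inline as a stand-in for the CSV file, and the `Agg` backend line is only there so the snippet runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line in a notebook
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# A few mtcars rows typed inline as a stand-in for the CSV file
cars = pd.DataFrame({
    "wt":   [2.620, 2.320, 3.215, 3.440, 3.460, 3.570],
    "mpg":  [21.0, 22.8, 21.4, 18.7, 18.1, 14.3],
    "cyl":  [6, 4, 6, 8, 6, 8],
    "gear": [4, 4, 3, 3, 3, 3],
})

# relplot builds its own figure with one panel per level of gear
g = sns.relplot(x="wt", y="mpg", hue="cyl", col="gear", data=cars)

# The result is a FacetGrid; its axes attribute is a 2D array of Axes
print(type(g).__name__)
print(g.axes.shape)

# Labels are set on the grid so that they apply to every panel
g.set_axis_labels("weight", "miles per gallon")
```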
7.2.2 Bar charts
How about a bar chart? Here we have just the counts in each group:
sns.countplot(x="cyl", data=mtcars)
# Alternatively, we can access the column directly as a pandas Series:
# sns.countplot(x=mtcars["cyl"])
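To make the bar heights concrete, the same counts can be computed directly with pandas. This is only a cross-check sketch: the `cars` data frame is a few mtcars rows typed inline as a stand-in for the CSV file, and the `Agg` backend line is there so the snippet runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line in a notebook
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# A few mtcars rows typed inline as a stand-in for the CSV file
cars = pd.DataFrame({"cyl": [6, 4, 6, 8, 6, 8, 4]})

# Number of rows per cyl level, sorted by level
counts = cars["cyl"].value_counts().sort_index()
print(counts)

# countplot draws one bar per level with these heights
ax = sns.countplot(x="cyl", data=cars)
```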
But we can also plot the means and error bars:
sns.barplot(x="cyl", y="wt", hue="am", data=mtcars)
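By default, barplot uses the mean as its estimator, so each bar height is the mean of y within one combination of x and hue. The sketch below computes those means with a pandas groupby for comparison. The `cars` data frame is a few mtcars rows typed inline as a stand-in for the CSV file, so some groups contain a single car and their error bars carry no information. The `Agg` backend line is only there so the snippet runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line in a notebook
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# A few mtcars rows typed inline as a stand-in for the CSV file
cars = pd.DataFrame({
    "wt":  [2.620, 2.320, 3.215, 3.440, 3.460, 3.570, 3.190],
    "cyl": [6, 4, 6, 8, 6, 8, 4],
    "am":  [1, 1, 0, 0, 0, 0, 0],
})

# Mean weight for every observed combination of cyl and am
means = cars.groupby(["cyl", "am"])["wt"].mean()
print(means)

# barplot draws these means as bar heights and adds error bars
ax = sns.barplot(x="cyl", y="wt", hue="am", data=cars)
```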