Chapter 6 Data Visualization
There are many libraries that will allow you to produce impressive data visualizations in Python. Two of the most popular for data science are matplotlib and seaborn. Although matplotlib is very popular, it is fairly low-level, in that you have to define much of each plot manually. Let’s take a brief look at it, since you’ll likely see it in wide use.
## Example data for plotting
Here, I’ll use one of my favourite data sets for example plotting: mtcars.
import pandas as pd
mtcars = pd.read_csv('data/mtcars.csv')
print(mtcars.info())
## <class 'pandas.core.frame.DataFrame'>
## RangeIndex: 32 entries, 0 to 31
## Data columns (total 12 columns):
## # Column Non-Null Count Dtype
## --- ------ -------------- -----
## 0 model 32 non-null object
## 1 mpg 32 non-null float64
## 2 cyl 32 non-null int64
## 3 disp 32 non-null float64
## 4 hp 32 non-null int64
## 5 drat 32 non-null float64
## 6 wt 32 non-null float64
## 7 qsec 32 non-null float64
## 8 vs 32 non-null int64
## 9 am 32 non-null int64
## 10 gear 32 non-null int64
## 11 carb 32 non-null int64
## dtypes: float64(5), int64(6), object(1)
## memory usage: 3.1+ KB
## None
6.1 Plotting packages
Import the packages as follows:
import seaborn as sns
import matplotlib.pyplot as plt
As an aside, we use the alias sns for seaborn because the package was named after The West Wing character Samuel Norman Seaborn. seaborn is built on matplotlib, so you’ll need matplotlib installed anyway, and you’ll usually import matplotlib.pyplot alongside it.
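Because seaborn is built on matplotlib, its axes-level functions such as sns.scatterplot draw onto, and return, an ordinary matplotlib Axes, so you can keep customizing a seaborn plot with matplotlib methods. A minimal sketch, assuming a small made-up DataFrame in place of the real mtcars file:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical toy data standing in for mtcars (made-up values)
toy = pd.DataFrame({"wt": [2.2, 2.8, 3.4, 3.9], "mpg": [30, 24, 18, 14]})

# seaborn draws onto a matplotlib Axes and hands it back
ax = sns.scatterplot(x="wt", y="mpg", data=toy)

# ...so ordinary matplotlib methods keep working on the result
ax.set_title("Drawn by seaborn, titled by matplotlib")
ax.set_xlabel("weight")
# In an interactive session, call plt.show() to display the figure
```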
6.2 matplotlib
6.2.1 Scatter plots
A basic scatter plot follows the form:
plt.scatter(mtcars["wt"], mtcars["mpg"], alpha=0.65)
plt.title('A basic scatter plot')
plt.xlabel('weight')
plt.ylabel('miles per gallon')
plt.show()
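The same plot can be written in matplotlib’s object-oriented style, calling methods on an explicit Axes instead of the plt functions; note that the labelling methods gain a set_ prefix. A minimal sketch, assuming hypothetical toy values in place of mtcars:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; remove it to see plots in a window
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical toy data (made-up values, not the real mtcars rows)
toy = pd.DataFrame({"wt": [2.2, 2.8, 3.4, 3.9], "mpg": [30, 24, 18, 14]})

fig, ax = plt.subplots()
ax.scatter(toy["wt"], toy["mpg"], alpha=0.65)
# Axes methods use set_ prefixes where pyplot uses bare names
ax.set_title("A basic scatter plot")
ax.set_xlabel("weight")
ax.set_ylabel("miles per gallon")
```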
We can also assign the output of plt.subplots() to two variables: fig, a container that holds everything you see on the page, and ax, the part of the page that holds the data, i.e. the canvas. When you ask for a grid of subplots, ax is an array with one Axes per panel.
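Exactly what plt.subplots returns as ax depends on the grid you request: a single Axes for the default 1×1 grid, a 1-D NumPy array for a single row or column, and a 2-D array otherwise, while squeeze=False always gives a 2-D array. A quick check:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt

fig1, ax1 = plt.subplots()                      # a single Axes object
fig2, ax2 = plt.subplots(1, 3)                  # 1-D array of 3 Axes
fig3, ax3 = plt.subplots(2, 2)                  # 2-D array of Axes
fig4, ax4 = plt.subplots(1, 3, squeeze=False)   # always 2-D, even for one row

print(ax2.shape, ax3.shape, ax4.shape)  # (3,) (2, 2) (1, 3)
```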
# Create a Figure and a 1 x 3 grid of Axes with plt.subplots
fig, ax = plt.subplots(1, 3, sharex=True, sharey=True)
print(ax.shape)
## (3,)
# Make one scatter plot per cylinder group
for key, value in enumerate(mtcars["cyl"].unique()):
    ax[key].scatter(mtcars["wt"][mtcars["cyl"] == value], mtcars["mpg"][mtcars["cyl"] == value], alpha=0.65)
# Call the show function
plt.show()
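The loop above leaves the three panels unlabelled, so it’s hard to tell which cylinder group is which. One way to fix that, sketched here with hypothetical toy data, is to iterate over a pandas groupby, which yields each cyl value together with its rows, and title each panel:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical toy data (made-up values standing in for mtcars)
toy = pd.DataFrame({
    "wt":  [2.2, 2.5, 2.8, 3.0, 3.2, 3.6, 3.9, 4.2],
    "mpg": [30, 27, 24, 21, 19, 16, 14, 13],
    "cyl": [4, 4, 4, 6, 6, 8, 8, 8],
})

groups = toy.groupby("cyl")  # groups come out sorted by cyl: 4, 6, 8
fig, ax = plt.subplots(1, len(groups), sharex=True, sharey=True)

# Pair each Axes with one (cyl value, sub-DataFrame) group
for axis, (cyl, group) in zip(ax, groups):
    axis.scatter(group["wt"], group["mpg"], alpha=0.65)
    axis.set_title(f"{cyl} cylinders")
ax[0].set_ylabel("miles per gallon")
```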
That’s not pretty, but it’s not terrible either. Here’s an example of a simple dodged (side-by-side) bar plot from the official matplotlib gallery page:
# Bar chart demo with pairs of bars grouped for easy comparison.
import numpy as np
import matplotlib.pyplot as plt

n_groups = 5

means_men = (20, 35, 30, 35, 27)
std_men = (2, 3, 4, 1, 2)

means_women = (25, 32, 34, 20, 25)
std_women = (3, 5, 2, 3, 3)

fig, ax = plt.subplots()

index = np.arange(n_groups)
bar_width = 0.35

opacity = 0.4
error_config = {'ecolor': '0.3'}

rects1 = plt.bar(index, means_men, bar_width,
                 alpha=opacity,
                 color='b',
                 yerr=std_men,
                 error_kw=error_config,
                 label='Men')

rects2 = plt.bar(index + bar_width, means_women, bar_width,
                 alpha=opacity,
                 color='r',
                 yerr=std_women,
                 error_kw=error_config,
                 label='Women')

plt.xlabel('Group')
plt.ylabel('Scores')
plt.title('Scores by group and gender')
plt.xticks(index + bar_width / 2, ('A', 'B', 'C', 'D', 'E'))
## ([<matplotlib.axis.XTick object at 0x7f97fa3811d0>, <matplotlib.axis.XTick object at 0x7f97fa3b1da0>, <matplotlib.axis.XTick object at 0x7f97fa3b19e8>, <matplotlib.axis.XTick object at 0x7f97fa4179b0>, <matplotlib.axis.XTick object at 0x7f97fa40ec50>], [Text(0.175, 0, 'A'), Text(1.175, 0, 'B'), Text(2.175, 0, 'C'), Text(3.175, 0, 'D'), Text(4.175, 0, 'E')])
plt.legend()
plt.tight_layout()
plt.show()
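The dodging in that demo is just arithmetic on x positions: the men’s bars are centred at index, the women’s at index + bar_width, and each tick label sits halfway between them at index + bar_width / 2. The sketch below redoes this with the object-oriented Axes methods (ax.bar, ax.set_xticks, ax.set_xticklabels) and checks the resulting bar centres; the scores are the same illustrative numbers as in the demo.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

means_men = (20, 35, 30, 35, 27)
means_women = (25, 32, 34, 20, 25)
index = np.arange(len(means_men))  # one slot per group: 0..4
bar_width = 0.35

fig, ax = plt.subplots()
# Each ax.bar call returns a BarContainer of Rectangle patches
men = ax.bar(index, means_men, bar_width, label="Men")
women = ax.bar(index + bar_width, means_women, bar_width, label="Women")
# Put each group label midway between its pair of bars
ax.set_xticks(index + bar_width / 2)
ax.set_xticklabels(["A", "B", "C", "D", "E"])
ax.legend()

# Rectangle.get_x() is the left edge; add half the width to get the centre
women_centres = [r.get_x() + r.get_width() / 2 for r in women]
print(np.round(women_centres, 2))  # [0.35 1.35 2.35 3.35 4.35]
```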
There are some great examples to be found in the gallery, but to be fair, for the most part, you just want to get your data plotted and get on with your analysis. This all seems like way too much work! Let’s see if seaborn can offer a solution.
seaborn is built on top of matplotlib and takes care of a lot of the legwork for you!
6.3.1 Scatter plots
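Let’s revisit our scatter plot:
sns.scatterplot(x="wt", y="mpg", data=mtcars)
# Alternatively, pass pandas Series directly (recent seaborn versions require x= and y= as keywords):
# sns.scatterplot(x=mtcars["wt"], y=mtcars["mpg"])
# Before v0.9.0, which introduced scatterplot, regplot was the usual choice (it also draws a regression fit unless fit_reg=False):
# sns.regplot(x=mtcars["wt"], y=mtcars["mpg"])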
We can also add another variable encoded as color. Notice that seaborn also adds the legend for us!
sns.scatterplot(x="wt", y="mpg", hue="cyl", data=mtcars)
But this looks pretty bad. What’s the issue here? Well, the colours are pretty terrible.
mtcars = mtcars.astype({"cyl": 'category'})
#sns.scatterplot(x="wt", y="mpg", hue="cyl", data=mtcars)
There are two issues with the colors. First, we’d like to change the colors and make them nicer. Second, we need to use a discrete scale, since this plot gives the impression that there is such a thing as a 5-cylinder car. Ideally, we’d like to plot a categorical variable, which means we’d change the type, like so:
mtcars['cyl'] = mtcars['cyl'].astype(object)
sns.scatterplot(x='wt', y='mpg', hue='cyl', data=mtcars)
This would work if cyl held strings, but when plotting, seaborn checks whether the values can be coerced to numbers, and if they can, it treats the variable as numeric and maps it onto a continuous color scale, so we still don’t get the discrete colors we want. Annoying! The workaround is to just assign a color palette as a list.
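An alternative to a hand-written palette, assuming a reasonably recent seaborn (0.11 or later), is to make cyl genuinely non-numeric, for example by converting it to strings or to pandas’ category dtype; seaborn then uses a discrete, qualitative color mapping with one legend entry per level. A sketch with hypothetical toy data:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical toy data (made-up values standing in for mtcars)
toy = pd.DataFrame({
    "wt":  [2.2, 2.5, 3.0, 3.2, 3.9, 4.2],
    "mpg": [30, 27, 21, 19, 14, 13],
    "cyl": [4, 4, 6, 6, 8, 8],
})

# Strings cannot be coerced to numbers, so hue is treated as categorical
toy["cyl_str"] = toy["cyl"].astype(str)
# The pandas category dtype is also treated as categorical (assumes seaborn >= 0.11)
toy["cyl_cat"] = toy["cyl"].astype("category")

fig, ax = plt.subplots()
# Either cyl_str or cyl_cat could be passed as hue here
sns.scatterplot(x="wt", y="mpg", hue="cyl_str", data=toy, ax=ax)
labels = [t.get_text() for t in ax.get_legend().get_texts()]
print(labels)
```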
mtcars = pd.read_csv('data/mtcars.csv')
sns.scatterplot(x='wt', y='mpg', hue='cyl', palette=["r", "b", "g"], data=mtcars)
We can also use hex color codes, which let you specify exact shades and are easy to keep consistent across plots.
sns.scatterplot(x='wt', y='mpg', hue='cyl', palette=['#fc9272', '#ef3b2c', '#a50f15'], data=mtcars)
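A list palette is matched to the hue levels in order, and for numeric values like cyl that order is sorted (4, 6, 8 here), which is easy to get wrong. Passing palette as a dict that maps each level to a color makes the pairing explicit. A sketch, again with hypothetical toy data:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical toy data (made-up values standing in for mtcars)
toy = pd.DataFrame({
    "wt":  [2.2, 2.5, 3.0, 3.2, 3.9, 4.2],
    "mpg": [30, 27, 21, 19, 14, 13],
    "cyl": [4, 4, 6, 6, 8, 8],
})

# Each hue level gets exactly the color named here, regardless of order
cyl_colors = {4: "#fc9272", 6: "#ef3b2c", 8: "#a50f15"}

fig, ax = plt.subplots()
sns.scatterplot(x="wt", y="mpg", hue="cyl", palette=cyl_colors, data=toy, ax=ax)
```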
There is a great tutorial on built-in color palettes.
We can add an additional variable using size:
sns.scatterplot(x='wt', y='mpg', hue='cyl', size='disp', palette=['#fc9272', '#ef3b2c', '#a50f15'], data=mtcars)
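By default seaborn picks the range of marker sizes for you; the sizes=(smallest, largest) argument sets that range explicitly, in matplotlib’s marker-area units (points squared). A minimal sketch with hypothetical toy values:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical toy data (made-up values standing in for mtcars)
toy = pd.DataFrame({
    "wt":   [2.2, 2.5, 3.0, 3.2, 3.9, 4.2],
    "mpg":  [30, 27, 21, 19, 14, 13],
    "cyl":  [4, 4, 6, 6, 8, 8],
    "disp": [90, 110, 160, 200, 350, 400],
})

fig, ax = plt.subplots()
# Smallest disp maps to area 20, largest to area 200
sns.scatterplot(x="wt", y="mpg", hue="cyl", size="disp", sizes=(20, 200),
                palette=["#fc9272", "#ef3b2c", "#a50f15"], data=toy, ax=ax)

# The first collection on the Axes holds the plotted points
marker_areas = ax.collections[0].get_sizes()
print(marker_areas.min(), marker_areas.max())
```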
# relplot is a figure-level function: it creates its own figure and returns a FacetGrid, so no plt.subplots() call is needed
g = sns.relplot(x='wt', y='mpg', hue='cyl', size='disp', col='gear', palette=['#fc9272', '#ef3b2c', '#a50f15'], data=mtcars)
plt.show()
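Because sns.relplot returns a seaborn FacetGrid rather than a matplotlib Axes, you tidy its panels through the grid: g.axes is a 2-D array of Axes, and methods such as set_axis_labels and set_titles apply to every panel at once. A sketch with hypothetical toy data; the title template text is just an illustration:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical toy data (made-up values standing in for mtcars)
toy = pd.DataFrame({
    "wt":   [1.9, 2.2, 2.5, 2.8, 3.0, 3.2, 3.4, 3.6, 3.9, 4.2],
    "mpg":  [32, 30, 27, 24, 21, 19, 18, 16, 14, 13],
    "cyl":  [4, 4, 4, 4, 6, 6, 6, 8, 8, 8],
    "gear": [4, 4, 5, 4, 4, 3, 3, 5, 3, 3],
})

g = sns.relplot(x="wt", y="mpg", hue="cyl", col="gear",
                palette=["#fc9272", "#ef3b2c", "#a50f15"], data=toy)
# Label the outer axes and retitle each column panel
g.set_axis_labels("weight", "miles per gallon")
g.set_titles(col_template="{col_name} gears")
print(g.axes.shape)  # (1, 3)
```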
6.3.2 Bar charts
How about a bar chart? Here we have just the counts in each group:
sns.countplot(x="cyl", data=mtcars)
# Alternatively, we can pass the column directly as a pandas Series:
# sns.countplot(x=mtcars["cyl"])
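countplot is essentially a plotted value_counts. To check the bar heights against pandas, you can read them off the Rectangle patches on the returned Axes. A sketch with hypothetical toy data:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical toy data (made-up cylinder counts)
toy = pd.DataFrame({"cyl": [4, 4, 4, 4, 6, 6, 6, 8, 8]})

# pandas version of the same summary
counts = toy["cyl"].value_counts().sort_index()
print(counts)

fig, ax = plt.subplots()
sns.countplot(x="cyl", data=toy, ax=ax)
# Each bar is a Rectangle patch whose height is the count
heights = [p.get_height() for p in ax.patches]
```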
But we can also plot the group means with error bars:
sns.barplot(x="cyl", y="wt", hue="am", data=mtcars)
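By default, barplot’s bar heights are the mean of y within each x/hue group (the estimator parameter controls this), and the error bars show a bootstrapped 95% confidence interval around that mean. The numbers behind the bars can be reproduced with a pandas groupby; a sketch with hypothetical toy data:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Hypothetical toy data (made-up values standing in for mtcars)
toy = pd.DataFrame({
    "cyl": [4, 4, 4, 6, 6, 6, 8, 8, 8],
    "am":  [1, 1, 0, 1, 0, 0, 1, 0, 0],
    "wt":  [2.2, 2.5, 2.8, 3.0, 3.2, 3.4, 3.6, 3.9, 4.2],
})

# Mean weight per (cyl, am) cell: the bar heights barplot draws
means = toy.groupby(["cyl", "am"])["wt"].mean().unstack("am")
print(means)

fig, ax = plt.subplots()
# estimator=np.mean is the default, written out here for clarity
sns.barplot(x="cyl", y="wt", hue="am", estimator=np.mean, data=toy, ax=ax)
```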
6.3.3 Dot plots
But looking at the raw values is also useful:
# catplot is also figure-level, so it manages its own figure
g = sns.catplot(x="cyl", y="wt", hue="am", kind="swarm", data=mtcars)
plt.show()
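catplot, like relplot, is figure-level and creates its own figure, so it can’t be placed on an Axes from plt.subplots. If you want the swarm on an Axes you control, for example next to the bar chart of means, the axes-level sns.swarmplot and sns.barplot both accept an ax argument. A sketch with hypothetical toy data:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical toy data (made-up values standing in for mtcars)
toy = pd.DataFrame({
    "cyl": [4, 4, 4, 6, 6, 6, 8, 8, 8],
    "am":  [1, 1, 0, 1, 0, 0, 1, 0, 0],
    "wt":  [2.2, 2.5, 2.8, 3.0, 3.2, 3.4, 3.6, 3.9, 4.2],
})

fig, axes = plt.subplots(1, 2, figsize=(8, 3))
# Left panel: every observation; right panel: group means with error bars
sns.swarmplot(x="cyl", y="wt", hue="am", data=toy, ax=axes[0])
sns.barplot(x="cyl", y="wt", hue="am", data=toy, ax=axes[1])
axes[0].set_title("raw values")
axes[1].set_title("means")
```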