Chapter 6 Data Visualization

There are many libraries that will allow you to produce impressive data visualizations in Python. Two of the most popular for data science are matplotlib and seaborn. Although matplotlib is very popular, it is pretty low-level, in that you have to define much of each plot manually. Let’s take a brief look at it, since you’ll likely see it in popular use.

Example data for plotting

Here, I’ll use one of my favourite data sets for example plotting: mtcars.

import pandas as pd

mtcars = pd.read_csv('data/mtcars.csv')
print(mtcars.info())
## <class 'pandas.core.frame.DataFrame'>
## RangeIndex: 32 entries, 0 to 31
## Data columns (total 12 columns):
##  #   Column  Non-Null Count  Dtype  
## ---  ------  --------------  -----  
##  0   model   32 non-null     object 
##  1   mpg     32 non-null     float64
##  2   cyl     32 non-null     int64  
##  3   disp    32 non-null     float64
##  4   hp      32 non-null     int64  
##  5   drat    32 non-null     float64
##  6   wt      32 non-null     float64
##  7   qsec    32 non-null     float64
##  8   vs      32 non-null     int64  
##  9   am      32 non-null     int64  
##  10  gear    32 non-null     int64  
##  11  carb    32 non-null     int64  
## dtypes: float64(5), int64(6), object(1)
## memory usage: 3.1+ KB
## None

6.1 Plotting packages

Import the packages as such:

import seaborn as sns
import matplotlib.pyplot as plt

As an aside, we use the alias sns for seaborn because it was named after the West Wing character Samuel Norman Seaborn. seaborn is built on matplotlib, so you’ll need to have matplotlib imported anyway.

6.2 matplotlib

6.2.1 Scatter plots

A basic scatter plot follows the form:

plt.scatter(mtcars["wt"], mtcars["mpg"], alpha=0.65)
plt.title('A basic scatter plot')
plt.xlabel('weight')
plt.ylabel('miles per gallon')
plt.show()

We can also assign the output of plt.subplots to:

  • fig, a container that holds everything you see on the page, and
  • ax, the part of the page that holds the data, i.e. the canvas.

# Create a Figure and an Axes with plt.subplots
# (with a 1 x 3 grid, ax is an array of three Axes)
fig, ax = plt.subplots(1, 3, sharex=True, sharey=True)
print(ax.shape)
## (3,)

# Make plot
for key, value in enumerate(mtcars["cyl"].unique()):
    ax[key].scatter(mtcars["wt"][mtcars["cyl"] == value], mtcars["mpg"][mtcars["cyl"] == value], alpha=0.65)

# Call the show function
plt.show()
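
One small improvement is to label each panel, so it is clear which cylinder count it shows. The sketch below is only an illustration: it assumes a hand-typed stand-in for mtcars (the cars data frame and its values are made up for the example) and uses the Axes methods set_title, set_xlabel and set_ylabel together with the Figure method suptitle.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so this runs without a display
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical stand-in for mtcars; the values are illustrative only.
cars = pd.DataFrame({
    "wt":  [2.62, 2.32, 3.44, 3.46, 3.57, 3.19],
    "mpg": [21.0, 22.8, 18.7, 18.1, 14.3, 24.4],
    "cyl": [6, 4, 8, 6, 8, 4],
})

levels = sorted(cars["cyl"].unique())  # the three cylinder counts, in ascending order
fig, axes = plt.subplots(1, len(levels), sharex=True, sharey=True)

# One panel per cylinder count, each drawn and labelled through its own Axes
for ax, level in zip(axes, levels):
    subset = cars[cars["cyl"] == level]
    ax.scatter(subset["wt"], subset["mpg"], alpha=0.65)
    ax.set_title(f"{level} cylinders")
    ax.set_xlabel("weight")

axes[0].set_ylabel("miles per gallon")            # y axis is shared, so label it once
fig.suptitle("mpg vs. weight by cylinder count")  # title for the whole Figure
```

In an interactive session you would drop the matplotlib.use line and finish with plt.show().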

That’s not great, but not terrible. Here’s an example of a simple dodged bar plot from the official matplotlib gallery page:

# Bar chart demo with pairs of bars grouped for easy comparison.
import numpy as np
import matplotlib.pyplot as plt


n_groups = 5

means_men = (20, 35, 30, 35, 27)
std_men = (2, 3, 4, 1, 2)

means_women = (25, 32, 34, 20, 25)
std_women = (3, 5, 2, 3, 3)

fig, ax = plt.subplots()

index = np.arange(n_groups)
bar_width = 0.35

opacity = 0.4
error_config = {'ecolor': '0.3'}

rects1 = plt.bar(index, means_men, bar_width,
                 alpha=opacity,
                 color='b',
                 yerr=std_men,
                 error_kw=error_config,
                 label='Men')

rects2 = plt.bar(index + bar_width, means_women, bar_width,
                 alpha=opacity,
                 color='r',
                 yerr=std_women,
                 error_kw=error_config,
                 label='Women')

plt.xlabel('Group')
plt.ylabel('Scores')
plt.title('Scores by group and gender')
plt.xticks(index + bar_width / 2, ('A', 'B', 'C', 'D', 'E'))
plt.legend()

plt.tight_layout()
plt.show()
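
The dodging in this example comes down to simple arithmetic on the bar positions. As a rough sketch (reusing n_groups and bar_width from the example; the names second and ticks are just illustrative), the first set of bars is centred on the integers, the second set is shifted right by one bar width, and the tick labels sit halfway between the two:

```python
import numpy as np

n_groups = 5
bar_width = 0.35

index = np.arange(n_groups)      # first set of bars: centred on 0, 1, 2, ...
second = index + bar_width       # second set: shifted right by one bar width
ticks = index + bar_width / 2    # tick labels: midway between each pair of bars

print(index)
print(second)
print(ticks)
```

This relies on plt.bar centring each bar on its x value, which is the default alignment.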

There are some great examples to be found in the gallery, but to be fair, for the most part, you just want to get your data plotted and get on with your analysis. This all seems like way too much work! Let’s see if seaborn can offer a solution.

6.3 seaborn

seaborn is built on top of matplotlib and takes care of a lot of the legwork for you!

6.3.1 Scatter plots

Let’s revisit our scatter plot:

sns.scatterplot(x="wt", y="mpg", data = mtcars)

# Alternatively, using pandas Series (recent seaborn versions require keyword arguments):
# sns.scatterplot(x=mtcars["wt"], y=mtcars["mpg"])

# prior to v0.9.0
# sns.regplot(mtcars["wt"], mtcars["mpg"])
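
Because seaborn draws onto matplotlib objects, sns.scatterplot returns an ordinary matplotlib Axes, so the labelling methods from the matplotlib section still apply. A minimal sketch, assuming a small hand-typed stand-in for mtcars (the cars data frame and its values are illustrative, not the real data set):

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so this runs without a display
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical stand-in for mtcars; the values are illustrative only.
cars = pd.DataFrame({
    "wt":  [2.62, 2.32, 3.44, 3.46, 3.57, 3.19],
    "mpg": [21.0, 22.8, 18.7, 18.1, 14.3, 24.4],
})

# scatterplot hands back the matplotlib Axes it drew on
ax = sns.scatterplot(x="wt", y="mpg", data=cars)

# ...so titles and axis labels are set with the usual Axes methods
ax.set_title("A basic scatter plot")
ax.set_xlabel("weight")
ax.set_ylabel("miles per gallon")
```

In an interactive session, drop the matplotlib.use line and call plt.show() at the end.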

We can also add another variable encoded as color. Notice that seaborn also adds the legend for us!

sns.scatterplot(x="wt", y="mpg", hue="cyl", data = mtcars)

But this looks pretty bad. What’s the issue here? Well, the colours are pretty terrible.

mtcars = mtcars.astype({"cyl": "category"})
#sns.scatterplot(x="wt", y="mpg", hue="cyl", data = mtcars)

There are two issues with the colors. First, we’d like to change the colors and make them nicer. Second, we need to use a discrete scale, since this plot gives the impression that there is such a thing as a 5-cylinder car. Ideally, we’d like to plot a categorical variable, which means we’d change the type, as such:

mtcars['cyl'] = mtcars['cyl'].astype(object)
sns.scatterplot(x='wt', y='mpg', hue='cyl', data = mtcars)

This would work if cyl were a character column, but when plotting, seaborn checks whether the values can be coerced to numeric; if they can, it plots them as numeric, which leaves a mismatch with our colors, so it doesn’t work. Annoying! The workaround is to just assign a color palette as a list.

mtcars = pd.read_csv('data/mtcars.csv')
sns.scatterplot(x='wt', y='mpg', hue='cyl', palette=["r", "b", "g"], data = mtcars)

We can also use hex color codes, which pin down each color exactly and make it easier to keep colors consistent.

sns.scatterplot(x='wt', y='mpg', hue='cyl', palette=['#fc9272', '#ef3b2c', '#a50f15'], data = mtcars)
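
A list palette assigns colors by position. If you would rather tie each color to a specific cylinder count, the palette argument also accepts a dictionary that maps hue levels to colors. A minimal sketch, again assuming a hand-typed stand-in for mtcars (the cars data frame and its values are illustrative):

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so this runs without a display
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical stand-in for mtcars; the values are illustrative only.
cars = pd.DataFrame({
    "wt":  [2.62, 2.32, 3.44, 3.46, 3.57, 3.19],
    "mpg": [21.0, 22.8, 18.7, 18.1, 14.3, 24.4],
    "cyl": [6, 4, 8, 6, 8, 4],
})

# Map each cylinder count to its own hex color (light to dark red)
cyl_colors = {4: "#fc9272", 6: "#ef3b2c", 8: "#a50f15"}

ax = sns.scatterplot(x="wt", y="mpg", hue="cyl", palette=cyl_colors, data=cars)
```

The dictionary needs an entry for every level that appears in the hue column.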

There is a great tutorial on built-in color palettes.

We can add an additional variable using the size:

sns.scatterplot(x='wt', y='mpg', hue='cyl', size='disp', palette=['#fc9272', '#ef3b2c', '#a50f15'], data = mtcars)

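# relplot is a figure-level function: it creates its own figure (one panel per gear),
# so there is no need to call plt.subplots() first
g = sns.relplot(x='wt', y='mpg', hue='cyl', size='disp', col='gear', palette=['#fc9272', '#ef3b2c', '#a50f15'], data = mtcars)
plt.show()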
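
relplot differs from scatterplot in that it builds a whole figure and returns a seaborn FacetGrid rather than a single Axes. The sketch below illustrates this under the assumption of a hand-typed stand-in for mtcars (the cars data frame and its values are illustrative); the grid's axes attribute holds one matplotlib Axes per panel, and set_axis_labels labels all panels at once.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so this runs without a display
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical stand-in for mtcars; the values are illustrative only.
cars = pd.DataFrame({
    "wt":   [2.32, 1.51, 3.19, 2.62, 3.46, 3.44, 3.17],
    "mpg":  [22.8, 30.4, 24.4, 21.0, 18.1, 18.7, 15.8],
    "cyl":  [4, 4, 4, 6, 6, 8, 8],
    "gear": [4, 5, 4, 4, 3, 3, 5],
})

# One panel per gear value, colored by cylinder count
g = sns.relplot(x="wt", y="mpg", hue="cyl", col="gear",
                palette=["#fc9272", "#ef3b2c", "#a50f15"], data=cars)

g.set_axis_labels("weight", "miles per gallon")  # applies to every panel
print(type(g).__name__, g.axes.shape)            # grid type and panel layout
```

The same applies to the other figure-level functions such as catplot.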

6.3.2 Bar charts

How about a bar chart? Here we have just the counts in each group:

sns.countplot(x="cyl", data=mtcars)
# Alternatively, we can access the column directly as a pandas Series:
# sns.countplot(x=mtcars["cyl"])
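
The bar heights in a count plot are simply the number of rows in each group, which you can check with pandas. A minimal sketch on a hand-typed stand-in column (the cars data frame and its values are illustrative, not the real mtcars counts):

```python
import pandas as pd

# Hypothetical stand-in for the cyl column; the values are illustrative only.
cars = pd.DataFrame({"cyl": [6, 4, 8, 6, 8, 4, 8]})

# Count the rows per cylinder level, ordered by level like the x axis
counts = cars["cyl"].value_counts().sort_index()
print(counts)
```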

But we can also plot the means and error bars:

sns.barplot(x="cyl", y="wt", hue="am", data=mtcars)
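
By default, each bar in sns.barplot shows the mean of y within a group, and the error bar is a bootstrapped 95% confidence interval around that mean. The bar heights can be reproduced with a pandas groupby. This is a sketch on a hand-typed stand-in for mtcars (the cars data frame and its values are illustrative):

```python
import pandas as pd

# Hypothetical stand-in for mtcars; the values are illustrative only.
cars = pd.DataFrame({
    "cyl": [4, 4, 4, 6, 6, 8, 8],
    "am":  [1, 1, 0, 1, 0, 0, 0],
    "wt":  [2.32, 1.51, 3.19, 2.62, 3.46, 3.44, 3.57],
})

# Mean weight for every cyl/am combination: one value per bar
means = cars.groupby(["cyl", "am"])["wt"].mean()
print(means)
```

Combinations that do not occur in the data (here, 8-cylinder cars with am equal to 1) have no entry, and therefore no bar.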

6.3.3 Dot plots

But looking at the raw values is also useful:

# catplot is also figure-level, so it creates its own figure
g = sns.catplot(x="cyl", y="wt", hue="am", kind="swarm", data=mtcars)
plt.show()