Matrix Plots

Matrix plots allow you to plot data as color-encoded matrices and can also be used to indicate clusters within the data (later in the machine learning section we will learn how to formally cluster data).

Let’s begin by exploring seaborn’s heatmap and clustermap:

import seaborn as sns
%matplotlib inline
flights = sns.load_dataset('flights')
tips = sns.load_dataset('tips')
tips.head()

   total_bill   tip     sex smoker  day    time  size
0       16.99  1.01  Female     No  Sun  Dinner     2
1       10.34  1.66    Male     No  Sun  Dinner     3
2       21.01  3.50    Male     No  Sun  Dinner     3
3       23.68  3.31    Male     No  Sun  Dinner     2
4       24.59  3.61  Female     No  Sun  Dinner     4
flights.head()

   year     month  passengers
0  1949   January         112
1  1949  February         118
2  1949     March         132
3  1949     April         129
4  1949       May         121

Heatmap

For a heatmap to work properly, your data should already be in matrix form, so that each cell value relates its row label to its column label. The sns.heatmap function basically just colors the cells in for you. For example:

tips.head()

   total_bill   tip     sex smoker  day    time  size
0       16.99  1.01  Female     No  Sun  Dinner     2
1       10.34  1.66    Male     No  Sun  Dinner     3
2       21.01  3.50    Male     No  Sun  Dinner     3
3       23.68  3.31    Male     No  Sun  Dinner     2
4       24.59  3.61  Female     No  Sun  Dinner     4
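# Matrix form for correlation data.
# numeric_only=True skips the categorical columns (sex, smoker, day, time);
# pandas 2.0+ raises an error without it.
tips.corr(numeric_only=True)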

            total_bill       tip      size
total_bill    1.000000  0.675734  0.598315
tip           0.675734  1.000000  0.489299
size          0.598315  0.489299  1.000000
sns.heatmap(tips.corr(numeric_only=True))

[Figure: heatmap of the tips correlation matrix with the default colormap]

sns.heatmap(tips.corr(numeric_only=True), cmap='coolwarm', annot=True)

[Figure: heatmap of the tips correlation matrix with the coolwarm colormap and annotated cell values]
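
Since correlations always lie between -1 and 1, it can help to pin the color scale to that range so that colors stay comparable from one plot to the next. The following is a minimal sketch, not part of the original notebook: the small DataFrame `df` is a stand-in built from the five rows shown by tips.head(), and the Agg backend line is only there so the snippet also runs outside a notebook.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; not needed inside a notebook
import pandas as pd
import seaborn as sns

# Stand-in for the numeric columns of `tips` (first five rows only)
df = pd.DataFrame({
    "total_bill": [16.99, 10.34, 21.01, 23.68, 24.59],
    "tip": [1.01, 1.66, 3.50, 3.31, 3.61],
    "size": [2, 3, 3, 2, 4],
})

# Square matrix: the same variables label both the rows and the columns
corr = df.corr()

# vmin/vmax fix the color scale to the full range a correlation can take
ax = sns.heatmap(corr, cmap="coolwarm", annot=True, vmin=-1, vmax=1)
```

With only five rows the correlation values themselves are not meaningful; the point is the shape of the input and the fixed color range.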

Or for the flights data:

flights.pivot_table(values='passengers',index='month',columns='year')

year       1949  1950  1951  1952  1953  1954  1955  1956  1957  1958  1959  1960
month
January     112   115   145   171   196   204   242   284   315   340   360   417
February    118   126   150   180   196   188   233   277   301   318   342   391
March       132   141   178   193   236   235   267   317   356   362   406   419
April       129   135   163   181   235   227   269   313   348   348   396   461
May         121   125   172   183   229   234   270   318   355   363   420   472
June        135   149   178   218   243   264   315   374   422   435   472   535
July        148   170   199   230   264   302   364   413   465   491   548   622
August      148   170   199   242   272   293   347   405   467   505   559   606
September   136   158   184   209   237   259   312   355   404   404   463   508
October     119   133   162   191   211   229   274   306   347   359   407   461
November    104   114   146   172   180   203   237   271   305   310   362   390
December    118   140   166   194   201   229   278   306   336   337   405   432
pvflights = flights.pivot_table(values='passengers',index='month',columns='year')
sns.heatmap(pvflights)

[Figure: heatmap of passengers by month (rows) and year (columns)]

sns.heatmap(pvflights,cmap='magma',linecolor='white',linewidths=1)

[Figure: the same heatmap with the magma colormap and white lines separating the cells]
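
To see what the pivot step does on its own, here is a small sketch using four values taken from the flights table. The names `long_df`, `month_order` and `wide` are illustrative only. One assumption worth flagging: with plain string labels pandas would sort the months alphabetically, so the sketch makes the month column an ordered categorical to keep calendar order.

```python
import pandas as pd

# Long form: one row per (year, month) observation
long_df = pd.DataFrame({
    "year": [1949, 1949, 1950, 1950],
    "month": ["January", "February", "January", "February"],
    "passengers": [112, 118, 115, 126],
})

# Ordered categorical so the rows keep calendar order, not alphabetical order
month_order = ["January", "February"]
long_df["month"] = pd.Categorical(
    long_df["month"], categories=month_order, ordered=True
)

# Wide (matrix) form: months as rows, years as columns, passengers as cells
wide = long_df.pivot_table(
    values="passengers", index="month", columns="year", observed=False
)
print(wide)
```

Each cell of `wide` now answers a question about both its row and its column (passengers in a given month of a given year), which is the form sns.heatmap expects.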

Clustermap

The clustermap uses hierarchical clustering to produce a clustered version of the heatmap. For example:

sns.clustermap(pvflights)

[Figure: clustermap of the flights matrix with row and column dendrograms]

Notice how the years and months are no longer in order; instead, they are grouped by similarity in value (passenger count). That means we can begin to infer things from this plot, such as July and August being similar, which makes sense since both are summer travel months.
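
The reordering can be reproduced by hand on a few rows. The sketch below is an illustration of the idea rather than seaborn's internal code: it uses SciPy's `linkage` and `leaves_list` with average linkage and Euclidean distance, which I believe match clustermap's defaults. The four rows are the 1949 to 1951 values for each month from the pivot table above; `labels`, `data` and `reordered` are illustrative names.

```python
import numpy as np
from scipy.cluster.hierarchy import linkage, leaves_list

labels = ["July", "January", "August", "February"]
data = np.array([
    [148, 170, 199],  # July, 1949-1951
    [112, 115, 145],  # January
    [148, 170, 199],  # August
    [118, 126, 150],  # February
])

# Build the hierarchy: repeatedly merge the two closest rows or clusters
link = linkage(data, method="average", metric="euclidean")

# Leaf order of the dendrogram; similar rows end up next to each other
order = leaves_list(link)
reordered = [labels[i] for i in order]
print(reordered)
```

July and August land side by side, as do January and February, even though the input rows were interleaved.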

# Additional options can make the pattern clearer, such as
# standard_scale=1, which rescales each column (year) to the 0-1 range
sns.clustermap(pvflights,cmap='coolwarm',standard_scale=1)

[Figure: clustermap of the flights matrix with the coolwarm colormap and each column scaled to the 0-1 range]
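
As I understand the seaborn documentation, standard_scale=1 subtracts each column's minimum and then divides by the column's maximum, so every year ends up on a 0 to 1 scale and later, busier years no longer dominate the colors. A rough pandas equivalent, using a few values from the pivot table (the names `wide`, `shifted` and `scaled` are illustrative):

```python
import pandas as pd

wide = pd.DataFrame(
    {1949: [112, 118, 132], 1960: [417, 391, 419]},
    index=["January", "February", "March"],
)

# Shift each column so its smallest value becomes 0 ...
shifted = wide - wide.min()
# ... then divide by the largest shifted value so the column tops out at 1
scaled = shifted / shifted.max()
print(scaled)
```

After scaling, the lowest month in each year is 0 and the highest is 1, so the clustering compares the seasonal shape of each year rather than its overall level.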

Greydon Gilmore
Electrophysiologist

My research interests include deep brain stimulation, machine learning and signal processing.