What is matplotlib?
Matplotlib is a visualization library for Python.
As in, if you want to display something in a chart or graph, matplotlib can help you do that programmatically.
Many of the graphics you'll see in machine learning research papers or presentations are made with matplotlib.
Why matplotlib?
Matplotlib is part of the standard Python data stack (pandas, NumPy, matplotlib, Jupyter).
It has terrific integration with many other Python libraries.
pandas uses matplotlib as a backend to help visualize data in DataFrames.
What does this notebook cover?
A central idea in matplotlib is the concept of a "plot" (hence the name).
So we're going to practice making a series of different plots, which is a way to visually represent data.
Since there are basically limitless ways to create a plot, we're going to focus on making and customizing (making them look pretty) a few common types of plots.
Where can I get help?
If you get stuck or think of something you'd like to do which this notebook doesn't cover, don't fear!
The recommended steps you take are:
Try it - Since matplotlib is very friendly, your first step should be to use what you know and try to figure out the answer to your own question (getting it wrong is part of the process). If in doubt, run your code.
Search for it - If trying it on your own doesn't work, since someone else has probably tried to do something similar, try searching for your problem in the following places (either via a search engine or direct):
matplotlib documentation - the best place for learning all of the vast functionality of matplotlib. Bonus: You can see a series of matplotlib cheatsheets on the matplotlib website.
Stack Overflow - this is the developer's Q&A hub. It's full of questions and answers to different problems across a wide range of software development topics, and chances are there's one related to your problem.
ChatGPT - ChatGPT is very good at explaining code, however, it can make mistakes. Best to verify the code it writes first before using it. Try asking "Can you explain the following code for me? {your code here}" and then continue with follow up questions from there. But always be careful using generated code. Avoid blindly copying something you couldn't reproduce yourself with enough effort.
An example of searching for a matplotlib feature might be:
"how to colour the bars of a matplotlib plot"
Searching this on Google leads to this documentation page on the matplotlib website: https://matplotlib.org/stable/gallery/lines_bars_and_markers/bar_colors.html
The next steps here are to read through the post and see if it relates to your problem. If it does, great, take the code/information you need and rewrite it to suit your own problem.
Ask for help - If you've been through the above 2 steps and you're still stuck, you might want to ask your question on Stack Overflow or in the ZTM Discord chat. Remember to be as specific as possible and provide details on what you've tried.
Remember, you don't have to learn all of these functions off by heart to begin with.
What's most important is remembering to continually ask yourself, "what am I trying to visualize?"
Start by answering that question and then practice finding the code which does it.
Let's get to visualizing some data!
0. Importing matplotlib
We'll start by importing matplotlib.pyplot.
Why pyplot?
Because pyplot is a submodule for creating interactive plots programmatically.
pyplot is often imported as the alias plt.
Note: In older notebooks and tutorials of matplotlib, you may see the magic command %matplotlib inline. This was required to view plots inside a notebook, however, as of 2020 it is mostly no longer required.
1. 2 ways of creating plots
There are two main ways of creating plots in matplotlib.
matplotlib.pyplot.plot() - Recommended for simple plots (e.g. x and y).
matplotlib.pyplot.XX (where XX can be one of many methods, this is known as the object-oriented API) - Recommended for more complex plots (for example, plt.subplots() to create multiple plots on the same Figure, we'll get to this later).
Both of these approaches usually build off import matplotlib.pyplot as plt as a base.
Let's start simple.
A few quick things about a plot:
x is the horizontal axis.
y is the vertical axis.
In a data point, x usually comes first, e.g. (3, 4) would be (x=3, y=4).
The same happens in matplotlib.pyplot.plot(), x comes before y, e.g. plt.plot(x, y).
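Here's a minimal sketch of the pyplot approach. The x and y values are made up purely for illustration.

```python
import matplotlib.pyplot as plt

# Made-up example data
x = [1, 2, 3, 4]       # horizontal values
y = [11, 22, 33, 44]   # vertical values

# x comes first, then y
lines = plt.plot(x, y)
```

plt.plot() returns a list of the line objects it drew, which is why the result is stored in a variable called lines.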
Now let's try using the object-oriented version.
We'll start by creating a figure with plt.figure().
And then we'll add an Axes with add_subplot.
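As a small sketch, those two steps could look like the following. Nothing is plotted yet, so the result is an empty Axes sitting on a Figure.

```python
import matplotlib.pyplot as plt

fig = plt.figure()      # create an empty Figure
ax = fig.add_subplot()  # add a single Axes to the Figure
```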
A note on the terminology:
A Figure (e.g. fig = plt.figure()) is the final image in matplotlib (and it may contain one or more Axes), often shortened to fig.
An Axes is an individual plot (e.g. ax = fig.add_subplot()), often shortened to ax.
One Figure can contain one or more Axes.
An Axis is a single dimension of an Axes: x (horizontal), y (vertical) or z (depth).
Now let's add some data to our previous plot.
But there's an easier way we can use matplotlib.pyplot to help us create a Figure with multiple potential Axes.
And that's with plt.subplots().
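A quick sketch of plt.subplots() with its default settings, which gives back one Figure and one Axes in a single call. The data values are made up for illustration.

```python
import matplotlib.pyplot as plt

# By default plt.subplots() creates one Figure containing one Axes
fig, ax = plt.subplots()

# Plot made-up data on the Axes
ax.plot([1, 2, 3, 4], [11, 22, 33, 44])
```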
Anatomy of a Matplotlib Figure
Matplotlib offers almost unlimited options for creating plots.
However, let's break down some of the main terms.
Figure - The base canvas of all matplotlib plots. The overall thing you're plotting is a Figure, often shortened to fig.
Axes - One Figure can have one or multiple Axes, for example, a Figure with multiple subplots could have 4 Axes (2 rows and 2 columns). Often shortened to ax.
Axis - A particular dimension of an Axes, for example, the x-axis or y-axis.
A quick Matplotlib Workflow
The following workflow is a standard practice when creating a matplotlib plot:
Import matplotlib - For example, import matplotlib.pyplot as plt.
Prepare data - This may be from an existing dataset (data analysis) or from the outputs of a machine learning model (data science).
Setup the plot - In other words, create the Figure and various Axes.
Plot data to the Axes - Send the relevant data to the target Axes.
Customize the plot - Add a title, decorate the colours, label each Axis.
Save (optional) and show - See what your masterpiece looks like and save it to file if necessary.
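The workflow above could be sketched as follows. The data, title, labels and file name are all placeholders chosen for illustration.

```python
# 1. Import matplotlib
import matplotlib.pyplot as plt

# 2. Prepare data (made-up values)
x = [1, 2, 3, 4]
y = [11, 22, 33, 44]

# 3. Set up the plot: one Figure with one Axes
fig, ax = plt.subplots(figsize=(10, 10))

# 4. Plot data to the Axes
ax.plot(x, y)

# 5. Customize the plot
ax.set(title="Sample simple plot", xlabel="x-axis", ylabel="y-axis")

# 6. Save (optional), uncomment to write the Figure to a file
# fig.savefig("simple-plot.png")
```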
2. Making the most common type of plots using NumPy arrays
Most of figuring out what kind of plot to use is getting a feel for the data, then seeing what kind of plot suits it best.
Matplotlib visualizations are built on NumPy arrays. So in this section we'll build some of the most common types of plots using NumPy arrays.
Line plot - ax.plot() (this is the default plot in matplotlib)
Scatter plot - ax.scatter()
Bar plot - ax.bar()
Histogram plot - ax.hist()
We'll see how all of these can be created as methods on the Axes returned by matplotlib.pyplot.subplots().
Resource: Remember you can see many of the different kinds of matplotlib plot types in the documentation.
To make sure we have access to NumPy, we'll import it as np.
Creating a line plot
Line is the default type of visualization in Matplotlib. Usually, unless specified otherwise, your plots will start out as lines.
Line plots are great for seeing trends over time.
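A minimal line plot sketch using a NumPy array. The choice of 100 evenly spaced values and the squared curve is just an assumption for the example.

```python
import numpy as np
import matplotlib.pyplot as plt

# 100 evenly spaced values between 0 and 10
x = np.linspace(0, 10, 100)

# ax.plot() draws a line by default
fig, ax = plt.subplots()
ax.plot(x, x**2)
```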
Creating a scatter plot
Scatter plots can be great for when you've got many different individual data points and you'd like to see how they interact with each other without being connected.
Creating a histogram plot
Histogram plots are excellent for showing the distribution of data.
For example, you might want to show the distribution of ages in a population or wages in a city.
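As a sketch, here's a histogram of 1000 samples drawn from a normal distribution. The random data and the seed are assumptions for the example.

```python
import numpy as np
import matplotlib.pyplot as plt

# Made-up data: 1000 samples from a standard normal distribution
rng = np.random.default_rng(seed=42)
x = rng.normal(size=1000)

# ax.hist() returns the counts per bin, the bin edges and the drawn bars
fig, ax = plt.subplots()
counts, bin_edges, patches = ax.hist(x)
```

With the default of 10 bins, bin_edges holds 11 values (the start and end of each bin).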
Creating Figures with multiple Axes with Subplots
Subplots allow you to create multiple Axes on the same Figure (multiple plots within the same plot).
Subplots are helpful because you start with one plot per Figure but scale it up to more when necessary.
For example, let's create a subplot that shows many of the above datasets on the same Figure.
We can do so by creating multiple Axes with plt.subplots() and setting the nrows (number of rows) and ncols (number of columns) parameters to reflect how many Axes we'd like.
The nrows and ncols parameters are multiplicative, meaning plt.subplots(nrows=2, ncols=2) will create 2*2=4 total Axes.
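Here's a sketch of a 2x2 grid with a different plot type on each Axes. All of the data is made up for illustration.

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(seed=42)
x = np.linspace(0, 10, 100)

# 2 rows * 2 columns = 4 Axes, returned as a 2D NumPy array
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 5))

# Index each Axes by [row, column]
axes[0, 0].plot(x, x / 2)                           # line plot
axes[0, 1].scatter(rng.random(10), rng.random(10))  # scatter plot
axes[1, 0].bar(["A", "B", "C"], [10, 8, 12])        # bar plot
axes[1, 1].hist(rng.normal(size=1000))              # histogram
```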
Resource: You can see a sensational number of examples for creating Subplots in the matplotlib documentation.
3. Plotting data directly with pandas
Matplotlib has a tight integration with pandas too.
You can directly plot from a pandas DataFrame with DataFrame.plot().
Let's see the following plots directly from a pandas DataFrame:
Line
Scatter
Bar
Hist
To plot data with pandas, we first have to import it as pd.
Now we need some data to check out.
Line plot from a pandas DataFrame
To understand examples, I often find I have to repeat them (code them myself) rather than just read them.
To begin understanding plotting with pandas, let's recreate a section of the pandas Chart visualization documents.
Great! We've got some random values across time.
Now let's add up the data cumulatively over time with DataFrame.cumsum() (cumsum is short for cumulative sum, or continually adding one thing to the next and so on).
We can now visualize the values by calling the plot() method on the DataFrame and specifying the kind of plot we'd like with the kind parameter.
In our case, the kind we'd like is a line plot, hence kind="line" (this is the default for the plot() method).
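A sketch of those steps, assuming a single-column DataFrame of random values indexed by date. The column name "value", the start date and the seed are made up for the example.

```python
import numpy as np
import pandas as pd

# Made-up data: 1000 random values, one per day
rng = np.random.default_rng(seed=42)
df = pd.DataFrame(
    {"value": rng.normal(size=1000)},
    index=pd.date_range("2024-01-01", periods=1000),
)

# Add the values up cumulatively over time
df = df.cumsum()

# kind="line" is the default, it's written out here to be explicit
ax = df.plot(kind="line")
```

DataFrame.plot() returns a matplotlib Axes, so everything you know about customizing an Axes still applies.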
Working with actual data
Let's do a little data manipulation on our car_sales DataFrame.
Scatter plot from a pandas DataFrame
You can create scatter plots from a pandas DataFrame by using the kind="scatter" parameter.
However, you'll often find that certain plots require certain kinds of data (e.g. some plots require certain columns to be numeric).
Having the Price column as an int returns a much better looking y-axis.
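To illustrate, here's a sketch with a small stand-in for car_sales. The "Odometer (KM)" column and every value below are assumptions, only the Price column name comes from the notebook.

```python
import pandas as pd

# Hypothetical stand-in for the car_sales DataFrame, with Price stored as text
car_sales = pd.DataFrame({
    "Odometer (KM)": [150043, 87899, 32549, 11179, 213095],
    "Price": ["4000", "5000", "7000", "22000", "3500"],
})

# Convert Price from strings to integers so the y-axis is numeric
car_sales["Price"] = car_sales["Price"].astype(int)

# Scatter plots need an x and a y column
ax = car_sales.plot(x="Odometer (KM)", y="Price", kind="scatter")
```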
Bar plot from a pandas DataFrame
Let's see how we can plot a bar plot from a pandas DataFrame.
First, we'll create some data.
We can plot a bar chart directly with the DataFrame.plot.bar() method.
And we can also do the same thing passing the kind="bar" parameter to DataFrame.plot().
Let's try a bar plot on the car_sales DataFrame.
This time we'll specify the x and y axis values.
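A sketch of both ways of creating a bar plot from a DataFrame. The random data and the column names a to d are made up for illustration.

```python
import numpy as np
import pandas as pd

# Made-up data: 10 rows and 4 columns of random values
rng = np.random.default_rng(seed=42)
df = pd.DataFrame(rng.random((10, 4)), columns=["a", "b", "c", "d"])

# Option 1: the bar() method on the plot accessor
ax_method = df.plot.bar()

# Option 2: the kind parameter of plot()
ax_kind = df.plot(kind="bar")
```

Both calls draw one bar per value, so each Axes ends up with 10 * 4 = 40 bars.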
Histogram plot from a pandas DataFrame
We can plot a histogram plot from our car_sales DataFrame using DataFrame.plot.hist() or DataFrame.plot(kind="hist").
Histograms are great for seeing the distribution or the spread of data.
By changing the bins parameter we can put our data into different numbers of collections.
For example, by default bins=10 (10 groups of data), let's see what happens when we change it to bins=20.
To practice, let's create a histogram of the Price column.
And to practice even further, how about we try another dataset?
Namely, let's create some plots using the heart disease dataset we've worked on before.
What does this tell you about the spread of heart disease data across different ages?
Creating a plot with multiple Axes from a pandas DataFrame
We can also create a series of plots (multiple Axes on one Figure) from a DataFrame using the subplots=True parameter.
First, let's remind ourselves what the data looks like.
Since all of our columns are numeric in value, let's try and create a histogram of each column.
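Here's a sketch using a randomly generated stand-in for heart_disease. The values are made up and only four of the columns are included.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for the heart_disease DataFrame (random values)
rng = np.random.default_rng(seed=42)
heart_disease = pd.DataFrame({
    "age": rng.integers(30, 78, size=200),
    "chol": rng.integers(130, 400, size=200),
    "thalach": rng.integers(70, 200, size=200),
    "target": rng.integers(0, 2, size=200),
})

# subplots=True creates one Axes per column on the same Figure
axes = heart_disease.plot.hist(subplots=True, figsize=(10, 12))
```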
Hmmm... is this a very helpful plot?
Perhaps not.
Sometimes you can visualize too much on the one plot and it becomes confusing.
Best to start with less and gradually increase.
4. Plotting more advanced plots from a pandas DataFrame
It's possible to achieve far more complicated and detailed plots from a pandas DataFrame.
Let's practice using the heart_disease DataFrame.
And as an example, let's do some analysis on people over 50 years of age.
To do so, let's start by creating a plot directly from pandas and then using the object-oriented API (plt.subplots()) to build upon it.
Now let's create a scatter plot directly from the pandas DataFrame.
This is quite easy to do but is a bit limited in terms of customization.
Let's visualize the cholesterol levels of patients over 50.
We can visualize which patients have or don't have heart disease by colouring the samples to be in line with the target column (e.g. 0 = no heart disease, 1 = heart disease).
We can recreate the same plot using plt.subplots() and then passing the Axes variable (ax) to the pandas plot() method.
Now instead of plotting directly from the pandas DataFrame, we can make a bit more of a comprehensive plot by plotting data directly to a target Axes instance.
What if we wanted a horizontal line going across with the mean of heart_disease["chol"]?
We do so with the Axes.axhline() method.
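A sketch of the whole idea, plotting straight to a target Axes and then adding the mean line. The heart_disease DataFrame here is a random stand-in, and the title and labels are assumptions.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Hypothetical stand-in for the heart_disease DataFrame (random values)
rng = np.random.default_rng(seed=42)
heart_disease = pd.DataFrame({
    "age": rng.integers(30, 78, size=200),
    "chol": rng.integers(130, 400, size=200),
    "target": rng.integers(0, 2, size=200),
})

# Keep only the patients over 50 years of age
over_50 = heart_disease[heart_disease["age"] > 50]

# Plot directly to the Axes, colouring each sample by its target value
fig, ax = plt.subplots(figsize=(10, 6))
scatter = ax.scatter(over_50["age"], over_50["chol"], c=over_50["target"])
ax.set(title="Heart Disease and Cholesterol Levels",
       xlabel="Age",
       ylabel="Cholesterol")

# Build a legend from the colours used in the scatter plot
ax.legend(*scatter.legend_elements(), title="Target")

# Horizontal dashed line at the mean cholesterol level
mean_line = ax.axhline(y=heart_disease["chol"].mean(), linestyle="--")
```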
Plotting multiple plots on the same figure (adding another plot to an existing one)
Sometimes you'll want to visualize multiple features of a dataset or results of a model in one Figure.
You can achieve this by adding data to multiple Axes on the same Figure.
The plt.subplots() method helps you create Figures with a desired number of Axes in a desired configuration.
Using the nrows (number of rows) and ncols (number of columns) parameters you can control the number of Axes on the Figure.
For example:
nrows=2, ncols=1 = 2x1 = a Figure with 2 Axes
nrows=5, ncols=5 = 5x5 = a Figure with 25 Axes
Let's create a plot with 2 Axes.
On the first Axes (Axes 0), we'll plot heart disease against cholesterol levels (chol).
On the second Axes (Axes 1), we'll plot heart disease against max heart rate levels (thalach).
5. Customizing your plots (making them look pretty)
If you're not a fan of the default matplotlib styling, there are plenty of ways to make your plots look prettier.
The more visually appealing your plots, the higher the chance people are going to want to look at them.
However, be careful not to overdo the customizations, as they may hinder the information being conveyed.
Some of the things you can customize include:
Axis limits - The range in which your data is displayed.
Colors - The colors that appear on the plot to represent different data.
Overall style - Matplotlib has several different styles built-in which offer different overall themes for your plots, you can see examples of these in the matplotlib style sheets reference documentation.
Legend - One of the most informative pieces of information on a Figure can be the legend, you can modify the legend of an Axes with the plt.legend() method.
Let's start by exploring different styles built into matplotlib.
Customizing the style of plots
Matplotlib comes with several built-in styles that are all created with an overall theme.
You can see what styles are available by using plt.style.available.
Resources:
To see what many of the available styles look like, you can refer to the matplotlib style sheets reference documentation.
For a deeper guide on customizing, refer to the Customizing Matplotlib with style sheets and rcParams tutorial.
Before we change the style of a plot, let's remind ourselves what the default plot style looks like.
Wonderful!
Now let's change the style of our future plots using the plt.style.use(style) method.
Where the style parameter is one of the available matplotlib styles.
How about we try "seaborn-v0_8-whitegrid" (seaborn is another common visualization library built on top of matplotlib)?
Wonderful!
Notice the slightly different styling of the plot?
Some styles change more than others.
How about we try "fivethirtyeight"?
Ohhh that's a nice looking plot!
Does the style carry over for another type of plot?
How about we try a scatter plot?
It does!
Looks like we may need to adjust the spacing on our x-axis though.
What about another style?
Let's try "ggplot".
Cool!
Now how can we go back to the default style?
Hint: with "default".
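Putting the style steps together, a sketch might look like this. The plotted values are made up and "ggplot" is just one of the styles mentioned above.

```python
import matplotlib.pyplot as plt

# See which built-in styles are available
print(plt.style.available)

# Change the style, this affects plots created from here on
plt.style.use("ggplot")
fig, ax = plt.subplots()
ax.plot([1, 2, 3, 4], [11, 22, 33, 44])

# Go back to the default style for future plots
plt.style.use("default")
```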
Customizing the title, legend and axis labels
When you have a matplotlib Figure or Axes object, you can customize many of the attributes by using the Axes.set() method.
For example, you can change the:
xlabel - Labels on the x-axis.
ylim - Limits of the y-axis.
xticks - Positions of the ticks on the x-axis.
And much more in the documentation.
Rather than talking about it, let's practice!
First, we'll create some random data and then put it into a DataFrame.
Then we'll make a plot from that DataFrame and see how to customize it.
Now let's plot the data from the DataFrame in a bar chart.
This time we'll save the plot to a variable called ax (short for Axes).
Excellent!
We can see our ax variable is an Axes object (older versions of matplotlib display its type as AxesSubplot), which allows us to use all of the methods available in matplotlib for Axes.
Let's set a few attributes of the plot with the set() method.
Namely, we'll change the title, xlabel and ylabel to communicate what's being displayed.
Notice the legend is up in the top left corner by default, we can change that if we like with the loc parameter of the legend() method.
loc can be set as a string to reflect where the legend should be.
By default it is set to loc="best" which means matplotlib will try to figure out the best positioning for it.
Let's try changing it to loc="upper right".
Nice!
Is that a better fit?
Perhaps not, but it goes to show how you can change the legend position if needed.
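A sketch of these customization steps on a bar chart. The random data, column names, title and labels are all made up for illustration.

```python
import numpy as np
import pandas as pd

# Made-up data: 10 rows and 4 columns of random values
rng = np.random.default_rng(seed=42)
df = pd.DataFrame(rng.standard_normal((10, 4)), columns=["a", "b", "c", "d"])

# pandas returns the matplotlib Axes it plotted on
ax = df.plot(kind="bar")

# Set several attributes in one call with set()
ax.set(title="Random Number Bar Graph from DataFrame",
       xlabel="Row number",
       ylabel="Random number")

# Move the legend with the loc parameter
ax.legend(loc="upper right")
```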
Customizing the colours of plots with colormaps (cmap)
Colour is one of the most important features of a plot.
It can help to separate different kinds of information.
And with the right colours, plots can be fun to look at and can make people want to learn more.
Matplotlib provides many different colour options through matplotlib.colormaps.
Let's see how we can change the colours of a matplotlib plot via the cmap parameter (cmap is short for colormap).
We'll start by creating a scatter plot with the default cmap value (cmap="viridis").
Wonderful!
That plot doesn't look too bad.
But what if we wanted to change the colours?
There are many different cmap parameter options available in the colormap reference.
How about we try cmap="winter"?
We can also change the colour of the horizontal line using the color parameter and setting it to a string of the colour we'd like (e.g. color="r" for red).
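Here's a sketch of both colour changes together. The data is a random stand-in for the patients over 50, so the exact values are assumptions.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Hypothetical stand-in for patients over 50 (random values)
rng = np.random.default_rng(seed=42)
over_50 = pd.DataFrame({
    "age": rng.integers(51, 78, size=100),
    "chol": rng.integers(130, 400, size=100),
    "target": rng.integers(0, 2, size=100),
})

fig, ax = plt.subplots(figsize=(10, 6))

# cmap controls how the values passed to c are mapped to colours
scatter = ax.scatter(over_50["age"], over_50["chol"],
                     c=over_50["target"], cmap="winter")

# color sets the colour of the horizontal mean line ("r" is red)
mean_line = ax.axhline(y=over_50["chol"].mean(), color="r", linestyle="--")
```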
Woohoo!
The first plot looked nice, but I think I prefer the colours of this new plot better.
For more on choosing colormaps in matplotlib, there's a sensational and in-depth tutorial in the matplotlib documentation.
Customizing the xlim & ylim
Matplotlib is pretty good at setting the ranges of values on the x-axis and the y-axis.
But as you might've guessed, you can customize these to suit your needs.
You can change the ranges of different axis values using the xlim and ylim parameters inside of the set() method.
To practice, let's recreate our double Axes plot from before with the default x-axis and y-axis values.
We'll add in the colour updates from the previous section too.
Now let's recreate the plot from above but this time we'll change the axis limits.
We can do so by using Axes.set(xlim=[50, 80]) or Axes.set(ylim=[60, 220]) where the inputs to xlim and ylim are a list of two numbers defining a range of values.
For example, xlim=[50, 80] will set the x-axis values to start at 50 and end at 80.
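A sketch of a two Axes plot with custom limits. The data is a random stand-in for the patients over 50, and which limit goes on which Axes is an assumption made for the example.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Hypothetical stand-in for patients over 50 (random values)
rng = np.random.default_rng(seed=42)
over_50 = pd.DataFrame({
    "age": rng.integers(51, 78, size=100),
    "chol": rng.integers(130, 400, size=100),
    "thalach": rng.integers(70, 200, size=100),
    "target": rng.integers(0, 2, size=100),
})

# Two Axes stacked vertically
fig, (ax0, ax1) = plt.subplots(nrows=2, ncols=1, figsize=(10, 7))

ax0.scatter(over_50["age"], over_50["chol"], c=over_50["target"], cmap="winter")
ax1.scatter(over_50["age"], over_50["thalach"], c=over_50["target"], cmap="winter")

# Limit the x-axis of the first Axes to ages 50 through 80
ax0.set(xlim=[50, 80], ylabel="Cholesterol")

# Limit the y-axis of the second Axes to 60 through 220
ax1.set(ylim=[60, 220], xlabel="Age", ylabel="Max heart rate")
```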
Now that's a nice looking plot!
Let's figure out how we'd save it.
6. Saving plots
Once you've got a nice looking plot that you're happy with, the next thing is going to be sharing it with someone else.
In a report, blog post, presentation or something similar.
You can save matplotlib Figures with plt.savefig(fname="your_plot_file_name") where fname is the target filename you'd like to save the plot to.
Before we save our plot, let's recreate it.
Nice!
We can save our plots to several different kinds of filetypes.
And we can check these filetypes with fig.canvas.get_supported_filetypes().
Image filetypes such as jpg and png are excellent for blog posts and presentations.
Whereas the pgf or pdf filetypes may be better for reports and papers.
One last look at our Figure, which is saved to the fig variable.
Beautiful!
Now let's save it to file.
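As a sketch, saving could look like the following. The Figure here is a simple made-up one and the file name example-plot.png is a placeholder.

```python
from pathlib import Path

import matplotlib.pyplot as plt

# A simple made-up Figure to save
fig, ax = plt.subplots()
ax.plot([1, 2, 3, 4], [11, 22, 33, 44])

# Check which filetypes the Figure can be saved as
filetypes = fig.canvas.get_supported_filetypes()
print(filetypes)

# Save the Figure, the filetype is taken from the file extension
save_path = Path("example-plot.png")
fig.savefig(save_path)
print(save_path.exists())
```

Calling savefig() on the Figure itself (rather than plt.savefig()) makes it explicit which Figure is being saved.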
File saved!
Let's try and display it.
We can do so with the HTML code:
And changing the cell below to markdown.
Note: Because the plot is highly visual, it's important to make sure there is an alt="some_text_here" attribute available when displaying the image, as this attribute is used to make the plot more accessible to those with visual impairments. For more on displaying images with HTML, see the Mozilla documentation.
Finally, if we wanted to start making more and different Figures, we can reset our fig variable by creating another plot.
If you're creating plots and saving them like this often, to save writing excess code, you might put it into a function.
A function which follows the Matplotlib workflow.
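One possible sketch of such a function is below. The name plotting_workflow, its parameters and the example data are all hypothetical, so adapt them to your own plots.

```python
import matplotlib.pyplot as plt


def plotting_workflow(x, y, title, xlabel, ylabel, save_path=None):
    """Create a line plot following the matplotlib workflow (hypothetical helper)."""
    # Set up the plot
    fig, ax = plt.subplots(figsize=(10, 6))

    # Plot data to the Axes
    ax.plot(x, y)

    # Customize the plot
    ax.set(title=title, xlabel=xlabel, ylabel=ylabel)

    # Save the plot only if a target path was given
    if save_path is not None:
        fig.savefig(save_path)

    return fig, ax


# Example usage with made-up data (nothing is saved because save_path is None)
fig, ax = plotting_workflow(
    x=[1, 2, 3, 4],
    y=[1, 4, 9, 16],
    title="Squares",
    xlabel="Number",
    ylabel="Number squared",
)
```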
Extra resources
We've covered a fair bit here.
But really we've only scratched the surface of what's possible with matplotlib.
So for more, I'd recommend going through the following:
Matplotlib quick start guide - Try rewriting all the code in this guide to get familiar with it.
Matplotlib plot types guide - Inside you'll get an idea of just how many kinds of plots are possible with matplotlib.
Matplotlib lifecycle of a plot guide - A sensational ground-up walkthrough of the many different things you can do with a plot.