mrdbourke
GitHub Repository: mrdbourke/zero-to-mastery-ml
Path: blob/master/section-2-data-science-and-ml-tools/introduction-to-matplotlib-video.ipynb
Kernel: Python 3 (ipykernel)

Introduction to Matplotlib

Get straight into plotting data; that's what we're focused on.

Video 0 will cover concepts and details like the anatomy of a figure. The rest of the videos will be purely code-based.

  1. Concepts in Matplotlib

  2. 2 ways of creating plots (pyplot & OO) - use the OO method

  3. Plotting data (NumPy arrays), line, scatter, bar, hist, subplots

  4. Plotting data directly with Pandas (using the pandas matplotlib wrapper)

  5. Plotting data (pandas DataFrames) with the OO method, line, scatter, bar, hist, subplots

  6. Customizing your plots, limits, colors, styles, legends

  7. Saving plots

0. Concepts in Matplotlib

  • What is Matplotlib?

  • Why Matplotlib?

  • Anatomy of a figure

  • Where does Matplotlib fit into the ecosystem?

    • A Matplotlib workflow
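To make the "anatomy of a figure" idea a bit more concrete, here's a minimal sketch (not part of the original notebook) that pokes at the two main objects: the Figure, which is the whole canvas, and the Axes, which is a single plotting area living inside it.

```python
import matplotlib.pyplot as plt

# A Figure is the top-level container; an Axes is one plotting area inside it
fig, ax = plt.subplots()

print(type(fig).__name__)   # the canvas object
print(ax.figure is fig)     # the Axes knows which Figure it belongs to
print(fig.axes == [ax])     # the Figure keeps a list of its Axes
print(type(ax.xaxis).__name__, type(ax.yaxis).__name__)  # each Axes owns an x-axis and a y-axis
```

Note the naming trap: an Axes is the plotting area, while `ax.xaxis` and `ax.yaxis` are the individual axis objects with the ticks and labels.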

1. 2 ways of creating plots

Start by importing Matplotlib and setting up the %matplotlib inline magic command.

# Import matplotlib and set up the figures to display within the notebook
%matplotlib inline
import matplotlib.pyplot as plt

# Create a simple plot, without the semi-colon
plt.plot()
[]
Image in a Jupyter notebook
# With the semi-colon
plt.plot();
Image in a Jupyter notebook
# You could use plt.show() if you want
plt.plot()
plt.show()
Image in a Jupyter notebook
# Let's add some data
plt.plot([1, 2, 3, 4])
[<matplotlib.lines.Line2D at 0x11366a490>]
Image in a Jupyter notebook
# Create some data
x = [1, 2, 3, 4]
y = [11, 22, 33, 44]

# With a semi-colon and now a y value
plt.plot(x, y);
Image in a Jupyter notebook
# Creating a plot with the OO version, confusing way first
fig = plt.figure()
ax = fig.add_subplot()
plt.show()
Image in a Jupyter notebook
# Confusing #2
fig = plt.figure()
ax = fig.add_axes([1, 1, 1, 1])  # [left, bottom, width, height] as fractions of the figure
ax.plot(x, y)
plt.show()
Image in a Jupyter notebook
# Easier and more robust going forward (what we're going to use)
fig, ax = plt.subplots()
ax.plot(x, y);
Image in a Jupyter notebook
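If the difference between the two interfaces still feels fuzzy, this small sketch (an illustration only, reusing the same x and y values) shows that pyplot quietly draws on a "current" Axes, while the OO method hands you the Figure and Axes to work with directly.

```python
import matplotlib.pyplot as plt

x = [1, 2, 3, 4]
y = [11, 22, 33, 44]

# pyplot interface: plt.plot() draws on the "current" Axes, creating one if needed
plt.figure()
pyplot_lines = plt.plot(x, y)
current_ax = plt.gca()

# Object-oriented interface: keep explicit handles to the Figure and Axes
fig, ax = plt.subplots()
oo_lines = ax.plot(x, y)

# Both routes produce the same kind of object holding the same data
print(type(pyplot_lines[0]).__name__, type(oo_lines[0]).__name__)
print(pyplot_lines[0].axes is current_ax, oo_lines[0].axes is ax)
```

With several figures open at once, the explicit `ax` handle removes any guessing about which plot a call will land on, which is a big part of why the OO method is the one to stick with.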

-> Show figure/plot anatomy here <-

# This is where the object-oriented name comes from
type(fig), type(ax)
(matplotlib.figure.Figure, matplotlib.axes._subplots.AxesSubplot)
# A matplotlib workflow

# 0. Import and get matplotlib ready
%matplotlib inline
import matplotlib.pyplot as plt

# 1. Prepare data
x = [1, 2, 3, 4]
y = [11, 22, 33, 44]

# 2. Setup plot
fig, ax = plt.subplots(figsize=(10,10))

# 3. Plot data
ax.plot(x, y)

# 4. Customize plot
ax.set(title="Sample Simple Plot", xlabel="x-axis", ylabel="y-axis")

# 5. Save & show
fig.savefig("../images/simple-plot.png")
Image in a Jupyter notebook

2. Making the most common type of plots using NumPy arrays

Most of figuring out what kind of plot to use is getting a feel for the data, then seeing what suits it best.

Matplotlib visualizations are built on NumPy arrays. So in this section we'll build some of the most common types of plots using NumPy arrays.

  • line

  • scatter

  • bar

  • hist

  • subplots()

To make sure we have access to NumPy, we'll import it as np.

import numpy as np

Line

Line is the default type of visualization in Matplotlib. Usually, unless specified otherwise, your plots will start out as lines.

# Create an array
x = np.linspace(0, 10, 100)
x[:10]
array([0.        , 0.1010101 , 0.2020202 , 0.3030303 , 0.4040404 ,
       0.50505051, 0.60606061, 0.70707071, 0.80808081, 0.90909091])
# The default plot is line
fig, ax = plt.subplots()
ax.plot(x, x**2);
Image in a Jupyter notebook
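One way to see "line is the default" for yourself is to look at what `ax.plot()` hands back. The sketch below assumes Matplotlib's default settings, and the `"o"` format string is just one example of overriding the default.

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0, 10, 100)
fig, ax = plt.subplots()

# ax.plot() returns a list of Line2D objects, one per plotted line
lines = ax.plot(x, x**2)
print(len(lines), lines[0].get_linestyle())

# A format string switches the same call to markers only (no connecting line)
markers = ax.plot(x, x**2, "o")
print(markers[0].get_linestyle(), markers[0].get_marker())
```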

Scatter

# Need to recreate our figure and axis instances when we want a new figure
fig, ax = plt.subplots()
ax.scatter(x, np.exp(x));
Image in a Jupyter notebook
fig, ax = plt.subplots()
ax.scatter(x, np.sin(x));
Image in a Jupyter notebook

Bar

  • Vertical

  • Horizontal

# You can make plots from a dictionary
nut_butter_prices = {"Almond butter": 10,
                     "Peanut butter": 8,
                     "Cashew butter": 12}
fig, ax = plt.subplots()
ax.bar(nut_butter_prices.keys(), nut_butter_prices.values())
ax.set(title="Dan's Nut Butter Store", ylabel="Price ($)");
Image in a Jupyter notebook
fig, ax = plt.subplots()
ax.barh(list(nut_butter_prices.keys()), list(nut_butter_prices.values()));
Image in a Jupyter notebook
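To compare vertical and horizontal bars side by side, here's a small sketch using the same nut butter prices. The variable names (`ax_v`, `ax_h` and so on) are made up for this example.

```python
import matplotlib.pyplot as plt

nut_butter_prices = {"Almond butter": 10, "Peanut butter": 8, "Cashew butter": 12}
names = list(nut_butter_prices.keys())
prices = list(nut_butter_prices.values())

fig, (ax_v, ax_h) = plt.subplots(ncols=2, figsize=(10, 4))

# Vertical bars: the values become bar heights
vertical = ax_v.bar(names, prices)
# Horizontal bars: the same values become bar widths
horizontal = ax_h.barh(names, prices)

print([bar.get_height() for bar in vertical])
print([bar.get_width() for bar in horizontal])
```

Converting the dictionary keys and values to lists first keeps the call working even where a dictionary view is not accepted directly.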

Histogram (hist)

  • Could show image of normal distribution here

# Make some data from a normal distribution
x = np.random.randn(1000) # pulls data from a normal distribution
fig, ax = plt.subplots()
ax.hist(x);
Image in a Jupyter notebook
x = np.random.random(1000) # random data from a uniform distribution over [0, 1)
fig, ax = plt.subplots()
ax.hist(x);
Image in a Jupyter notebook
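A histogram is just counts per bin, and `ax.hist()` returns those counts along with the bin edges. The sketch below assumes default settings, uses an arbitrary seed so the numbers repeat, and puts a normal sample next to a uniform one.

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(seed=0)       # seed chosen arbitrarily for repeatability
normal_data = rng.standard_normal(1000)   # bell-shaped, centred on 0
uniform_data = rng.random(1000)           # flat, spread over [0, 1)

fig, (ax_n, ax_u) = plt.subplots(ncols=2, figsize=(10, 4))
counts_n, edges_n, _ = ax_n.hist(normal_data)
counts_u, edges_u, _ = ax_u.hist(uniform_data)

# With default settings there are 10 bins, so 11 bin edges
print(len(counts_n), len(edges_n))
# Every sample lands in exactly one bin
print(counts_n.sum(), counts_u.sum())
```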

Subplots

# Option 1: Create multiple subplots
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2,
                                             ncols=2,
                                             figsize=(10, 5))

# Plot data to each axis
ax1.plot(x, x/2);
ax2.scatter(np.random.random(10), np.random.random(10));
ax3.bar(nut_butter_prices.keys(), nut_butter_prices.values());
ax4.hist(np.random.randn(1000));
Image in a Jupyter notebook
# Option 2: Create multiple subplots
fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(10, 5))

# Index to plot data
ax[0, 0].plot(x, x/2);
ax[0, 1].scatter(np.random.random(10), np.random.random(10));
ax[1, 0].bar(nut_butter_prices.keys(), nut_butter_prices.values());
ax[1, 1].hist(np.random.randn(1000));
Image in a Jupyter notebook
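Both options work because `plt.subplots()` with several rows and columns returns a NumPy array of Axes. A quick sketch to check this (the panel titles are made up for illustration):

```python
import numpy as np
import matplotlib.pyplot as plt

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 5))

# With more than one row and column, `axes` is a 2-D NumPy array of Axes
print(type(axes).__name__, axes.shape)

# Flattening it lets you loop over every panel in row-major order
for i, ax in enumerate(axes.flat):
    ax.set_title(f"Panel {i}")

# Row 1, column 0 is the third panel visited (index 2)
print(axes[1, 0].get_title())
```

Looping over `axes.flat` is handy when every panel needs the same styling.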

3. Plotting data directly with pandas

This section uses the pandas DataFrame.plot() method (there is no pd.plot() function) to plot columns directly.

To plot data with pandas, we first have to import it as pd.

import pandas as pd

Now we need some data to check out.

# Let's import the car_sales dataset
car_sales = pd.read_csv("../data/car-sales.csv")
car_sales

Line

  • Concept

  • DataFrame

Often, reading things won't make sense. Practice writing code for yourself, get it out of the docs and into your workspace. See what happens when you run it.

Let's start by trying to replicate the examples in the pandas visualization documentation.

# Start with some dummy data
ts = pd.Series(np.random.randn(1000),
               index=pd.date_range('1/1/2020', periods=1000))
ts
2020-01-01    0.738301
2020-01-02   -0.436335
2020-01-03    1.552973
2020-01-04   -0.721055
2020-01-05   -0.522301
                ...
2022-09-22   -0.529207
2022-09-23   -0.760224
2022-09-24    0.399311
2022-09-25   -0.669529
2022-09-26    0.238585
Freq: D, Length: 1000, dtype: float64
# What does cumsum() do?
ts.cumsum()
2020-01-01     0.738301
2020-01-02     0.301966
2020-01-03     1.854938
2020-01-04     1.133883
2020-01-05     0.611582
                ...
2022-09-22   -36.324290
2022-09-23   -37.084515
2022-09-24   -36.685204
2022-09-25   -37.354733
2022-09-26   -37.116148
Freq: D, Length: 1000, dtype: float64
ts.cumsum().plot();
Image in a Jupyter notebook
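If `cumsum()` is new to you, a tiny Series with easy numbers (made up for this example) shows it's just a running total:

```python
import pandas as pd

s = pd.Series([1, 2, 3, 4])

# Each entry is the sum of everything up to and including that position
running_total = s.cumsum()
print(running_total.tolist())  # [1, 3, 6, 10]
```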

Working with actual data

Let's do a little data manipulation on our car_sales DataFrame.

# Remove price column symbols (regex=True is required in pandas 2.0+)
car_sales["Price"] = car_sales["Price"].str.replace(r'[\$\,\.]', '', regex=True)
car_sales

# Remove last two zeros
car_sales["Price"] = car_sales["Price"].str[:-2]
car_sales

# Add a date column
car_sales["Sale Date"] = pd.date_range("1/1/2020", periods=len(car_sales))
car_sales

# Make total sales column (doesn't work, adds as string)
#car_sales["Total Sales"] = car_sales["Price"].cumsum()

# Oops... want them as ints, not strings
car_sales["Total Sales"] = car_sales["Price"].astype(int).cumsum()
car_sales

car_sales.plot(x='Sale Date', y='Total Sales');
Image in a Jupyter notebook
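Here's the price cleaning in isolation so you can watch each step. The price strings are hypothetical values written in the same style as the car sales data. One thing to watch: recent pandas versions treat the `str.replace` pattern as plain text unless you pass `regex=True`.

```python
import pandas as pd

# Hypothetical price strings in the same style as the car-sales data
prices = pd.Series(["$4,000.00", "$5,000.00", "$7,000.00"])

# regex=True makes pandas treat the pattern as a character class
digits = prices.str.replace(r"[\$\,\.]", "", regex=True)  # "400000", ...
dollars = digits.str[:-2].astype(int)                     # drop the two cents digits

print(dollars.tolist())
print(dollars.cumsum().tolist())
```

Converting to `int` before calling `cumsum()` is the important bit: it makes the running total add numbers rather than work on text.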

Scatter

  • Concept

  • DataFrame

# Doesn't work
car_sales.plot(x="Odometer (KM)", y="Price", kind="scatter")
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-34-540f318a89d0> in <module>
      1 # Doesn't work
----> 2 car_sales.plot(x="Odometer (KM)", y="Price", kind="scatter")

~/Desktop/ml-course/work-in-progress/env/lib/python3.7/site-packages/pandas/plotting/_core.py in __call__(self, *args, **kwargs)
    736         if kind in self._dataframe_kinds:
    737             if isinstance(data, ABCDataFrame):
--> 738                 return plot_backend.plot(data, x=x, y=y, kind=kind, **kwargs)
    739             else:
    740                 raise ValueError(

~/Desktop/ml-course/work-in-progress/env/lib/python3.7/site-packages/pandas/plotting/_matplotlib/__init__.py in plot(data, kind, **kwargs)
     59             ax = plt.gca()
     60         kwargs["ax"] = getattr(ax, "left_ax", ax)
---> 61     plot_obj = PLOT_CLASSES[kind](data, **kwargs)
     62     plot_obj.generate()
     63     plot_obj.draw()

~/Desktop/ml-course/work-in-progress/env/lib/python3.7/site-packages/pandas/plotting/_matplotlib/core.py in __init__(self, data, x, y, s, c, **kwargs)
    928             # the handling of this argument later
    929             s = 20
--> 930         super().__init__(data, x, y, s=s, **kwargs)
    931         if is_integer(c) and not self.data.columns.holds_integer():
    932             c = self.data.columns[c]

~/Desktop/ml-course/work-in-progress/env/lib/python3.7/site-packages/pandas/plotting/_matplotlib/core.py in __init__(self, data, x, y, **kwargs)
    870             raise ValueError(self._kind + " requires x column to be numeric")
    871         if len(self.data[y]._get_numeric_data()) == 0:
--> 872             raise ValueError(self._kind + " requires y column to be numeric")
    873 
    874         self.x = x

ValueError: scatter requires y column to be numeric
# Convert Price to int
car_sales["Price"] = car_sales["Price"].astype(int)
car_sales.plot(x="Odometer (KM)", y="Price", kind='scatter');
Image in a Jupyter notebook

Bar

  • Concept

  • DataFrame

x = np.random.rand(10, 4)
x
array([[0.91054912, 0.65668407, 0.75347508, 0.1488774 ],
       [0.4739657 , 0.65199569, 0.80087623, 0.25613654],
       [0.20515965, 0.14991211, 0.07454593, 0.15030318],
       [0.17102306, 0.97405707, 0.69580935, 0.41898253],
       [0.22654692, 0.1848998 , 0.01482526, 0.0647843 ],
       [0.54732069, 0.68484856, 0.71222659, 0.70537797],
       [0.50304196, 0.68331734, 0.0471555 , 0.94868537],
       [0.96833686, 0.19313494, 0.11765464, 0.13561539],
       [0.96998806, 0.50634506, 0.02096006, 0.32375073],
       [0.30732541, 0.10588319, 0.72021475, 0.07767541]])
df = pd.DataFrame(x, columns=['a', 'b', 'c', 'd'])
df
df.plot.bar();
Image in a Jupyter notebook
# Can do the same thing with 'kind' keyword
df.plot(kind='bar');
Image in a Jupyter notebook
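To convince yourself the two spellings really are the same thing, this short sketch (random data, illustrative variable names) counts the bars each one draws:

```python
import numpy as np
import pandas as pd

df = pd.DataFrame(np.random.rand(10, 4), columns=["a", "b", "c", "d"])

ax_method = df.plot.bar()        # accessor-method form
ax_kind = df.plot(kind="bar")    # keyword form

# Both return a Matplotlib Axes holding one rectangle per cell of the DataFrame
print(len(ax_method.patches), len(ax_kind.patches), df.size)
```

Because both calls return a Matplotlib Axes, you can keep customizing the result with the usual `ax.set(...)` style calls.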
car_sales.plot(x='Make', y='Odometer (KM)', kind='bar');
Image in a Jupyter notebook

Histograms

car_sales["Odometer (KM)"].plot.hist();
Image in a Jupyter notebook
car_sales["Odometer (KM)"].plot(kind="hist");
Image in a Jupyter notebook
# Default number of bins is 10
car_sales["Odometer (KM)"].plot.hist(bins=20);
Image in a Jupyter notebook
car_sales["Price"].plot.hist(bins=10);
Image in a Jupyter notebook
# Let's try with another dataset
heart_disease = pd.read_csv("../data/heart-disease.csv")
heart_disease.head()
heart_disease["age"].plot.hist(bins=50);
Image in a Jupyter notebook

Subplots

  • Concept

  • DataFrame

heart_disease.head()
heart_disease.plot.hist(figsize=(10, 30), subplots=True);
Image in a Jupyter notebook
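With `subplots=True`, pandas gives each column its own Axes. The sketch below uses stand-in random data with made-up column names rather than the heart disease dataset.

```python
import numpy as np
import pandas as pd

# Stand-in numeric data; the column names are made up for this sketch
df = pd.DataFrame(np.random.rand(50, 3), columns=["a", "b", "c"])

axes = df.plot.hist(subplots=True, figsize=(6, 8))

# One Axes per column, returned together as a NumPy array
print(type(axes).__name__, len(axes))
```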

4. Plotting with pandas using the OO method

For more complicated plots, you'll want to use the OO method.

# Perform data analysis on patients over 50
over_50 = heart_disease[heart_disease["age"] > 50]
over_50
over_50.plot(kind='scatter',
             x='age',
             y='chol',
             c='target',
             figsize=(10, 6));
Image in a Jupyter notebook
fig, ax = plt.subplots(figsize=(10, 6))
over_50.plot(kind='scatter',
             x="age",
             y="chol",
             c='target',
             ax=ax);
ax.set_xlim([45, 100]);
Image in a Jupyter notebook
# Make a bit more of a complicated plot

# Create the plot
fig, ax = plt.subplots(figsize=(10, 6))

# Plot the data
scatter = ax.scatter(over_50["age"],
                     over_50["chol"],
                     c=over_50["target"])

# Customize the plot
ax.set(title="Heart Disease and Cholesterol Levels",
       xlabel="Age",
       ylabel="Cholesterol");
ax.legend(*scatter.legend_elements(), title="Target");
Image in a Jupyter notebook
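The `*scatter.legend_elements()` bit can look like magic, so here's a small sketch that unpacks it. The values are made up to stand in for the age, chol and target columns.

```python
import matplotlib.pyplot as plt

# Made-up values standing in for age, chol and target
age = [52, 58, 61, 67]
chol = [210, 250, 230, 290]
target = [0, 1, 0, 1]

fig, ax = plt.subplots()
scatter = ax.scatter(age, chol, c=target)

# legend_elements() returns matching lists of legend handles and labels,
# one pair per distinct value in `c` (when there are only a few distinct values)
handles, labels = scatter.legend_elements()
print(len(handles), labels)

# Equivalent to ax.legend(*scatter.legend_elements(), title="Target")
ax.legend(handles, labels, title="Target")
```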

What if we wanted a horizontal line going across with the mean of heart_disease["chol"]?

https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.axes.Axes.axhline.html

# Make a bit more of a complicated plot

# Create the plot
fig, ax = plt.subplots(figsize=(10, 6))

# Plot the data
scatter = ax.scatter(over_50["age"],
                     over_50["chol"],
                     c=over_50["target"])

# Customize the plot
ax.set(title="Heart Disease and Cholesterol Levels",
       xlabel="Age",
       ylabel="Cholesterol");
ax.legend(*scatter.legend_elements(), title="Target")

# Add a meanline
ax.axhline(over_50["chol"].mean(),
           linestyle="--");
Image in a Jupyter notebook
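A stripped-down version of the mean line, using three made-up cholesterol readings so the mean is easy to check by hand:

```python
import pandas as pd
import matplotlib.pyplot as plt

chol = pd.Series([200, 250, 300])  # made-up cholesterol readings

fig, ax = plt.subplots()
ax.scatter([50, 60, 70], chol)

# axhline draws at a fixed y value and spans the full width of the Axes
mean_line = ax.axhline(chol.mean(), linestyle="--")
print(list(mean_line.get_ydata()))
```

Both ends of the line sit at the same y value, the mean of 250, which is what makes it horizontal.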

Adding another plot to existing styled one

# Setup plot (2 rows, 1 column)
fig, (ax0, ax1) = plt.subplots(nrows=2, # 2 rows
                               ncols=1,
                               sharex=True,
                               figsize=(10, 8))

# Add data for ax0
scatter = ax0.scatter(over_50["age"],
                      over_50["chol"],
                      c=over_50["target"])
# Customize ax0
ax0.set(title="Heart Disease and Cholesterol Levels",
        ylabel="Cholesterol")
ax0.legend(*scatter.legend_elements(), title="Target")

# Setup a mean line
ax0.axhline(y=over_50["chol"].mean(),
            color='b',
            linestyle='--',
            label="Average")

# Add data for ax1
scatter = ax1.scatter(over_50["age"],
                      over_50["thalach"],
                      c=over_50["target"])

# Customize ax1
ax1.set(title="Heart Disease and Max Heart Rate Levels",
        xlabel="Age",
        ylabel="Max Heart Rate")
ax1.legend(*scatter.legend_elements(), title="Target")

# Setup a mean line
ax1.axhline(y=over_50["thalach"].mean(),
            color='b',
            linestyle='--',
            label="Average")

# Title the figure
fig.suptitle('Heart Disease Analysis', fontsize=16, fontweight='bold');
Image in a Jupyter notebook

5. Customizing your plots

  • limits (xlim, ylim), colors, styles, legends

Style

plt.style.available
['seaborn-dark', 'seaborn-darkgrid', 'seaborn-ticks', 'fivethirtyeight', 'seaborn-whitegrid', 'classic', '_classic_test', 'fast', 'seaborn-talk', 'seaborn-dark-palette', 'seaborn-bright', 'seaborn-pastel', 'grayscale', 'seaborn-notebook', 'ggplot', 'seaborn-colorblind', 'seaborn-muted', 'seaborn', 'Solarize_Light2', 'seaborn-paper', 'bmh', 'tableau-colorblind10', 'seaborn-white', 'dark_background', 'seaborn-poster', 'seaborn-deep']
# Plot before changing style
car_sales["Price"].plot();
Image in a Jupyter notebook
# Change the style...
# (Matplotlib 3.6+ names this style 'seaborn-v0_8-whitegrid')
plt.style.use('seaborn-whitegrid')
car_sales["Price"].plot();
Image in a Jupyter notebook
plt.style.use('seaborn')
car_sales["Price"].plot();
Image in a Jupyter notebook
car_sales.plot(x="Odometer (KM)", y="Price", kind="scatter");
Image in a Jupyter notebook
plt.style.use('ggplot')
car_sales["Price"].plot.hist();
Image in a Jupyter notebook
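Two things can trip you up with styles: `plt.style.use()` changes every plot that follows, and newer Matplotlib releases list the seaborn styles under `seaborn-v0_8-...` names. The sketch below is one possible way to handle both. The fallback to `'ggplot'` is just a choice made for this example.

```python
import matplotlib.pyplot as plt

# Pick whichever whitegrid name this Matplotlib install provides
candidates = ["seaborn-v0_8-whitegrid", "seaborn-whitegrid"]
style_name = next((s for s in candidates if s in plt.style.available), "ggplot")

before = plt.rcParams["axes.grid"]

# style.context() applies the style only inside the `with` block
with plt.style.context(style_name):
    fig, ax = plt.subplots()
    ax.plot([1, 2, 3], [4, 5, 6])
    inside = plt.rcParams["axes.grid"]

after = plt.rcParams["axes.grid"]
print(style_name, before, inside, after)
```

Once the block ends, the settings go back to whatever they were before, so one styled plot doesn't leak into the rest of the notebook.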

Changing the title, legend, axes

x = np.random.randn(10, 4)
x
array([[-1.45604975,  0.44398039,  1.21617191,  0.46778121],
       [ 0.09043707,  0.64565222, -0.92772261,  0.63044677],
       [-0.20260212, -0.30685306, -1.07970088,  0.13664664],
       [ 2.01577535,  1.49857223, -0.05013591, -1.24773112],
       [-1.1872596 , -0.73286209,  0.45447678,  1.9601397 ],
       [ 1.24279567,  0.03839697,  0.1417006 ,  0.50332953],
       [-0.76513327, -0.14311738, -0.32238378,  1.02932238],
       [ 0.37193522,  0.44785763,  0.85386339,  0.60622919],
       [ 1.44272823, -0.86638843, -0.48638364, -0.30357948],
       [-0.90626553, -0.44139532, -1.99812976,  0.83367383]])
df = pd.DataFrame(x, columns=['a', 'b', 'c', 'd'])
df
ax = df.plot(kind='bar')
type(ax)
matplotlib.axes._subplots.AxesSubplot
Image in a Jupyter notebook
ax = df.plot(kind='bar')
ax.set(title="Random Number Bar Graph from DataFrame",
       xlabel="Row number",
       ylabel="Random number")
ax.legend().set_visible(True)
Image in a Jupyter notebook

Changing the cmap

plt.style.use('seaborn-whitegrid')
# No cmap change
fig, ax = plt.subplots(figsize=(10, 6))
scatter = ax.scatter(over_50["age"],
                     over_50["chol"],
                     c=over_50["target"])
ax.set(title="Heart Disease and Cholesterol Levels",
       xlabel="Age",
       ylabel="Cholesterol");
ax.axhline(y=over_50["chol"].mean(),
           c='b',
           linestyle='--',
           label="Average");
ax.legend(*scatter.legend_elements(), title="Target");
Image in a Jupyter notebook
# Change cmap and horizontal line to be a different colour
fig, ax = plt.subplots(figsize=(10, 6))
scatter = ax.scatter(over_50["age"],
                     over_50["chol"],
                     c=over_50["target"],
                     cmap="winter")
ax.set(title="Heart Disease and Cholesterol Levels",
       xlabel="Age",
       ylabel="Cholesterol")
ax.axhline(y=over_50["chol"].mean(),
           color='r',
           linestyle='--',
           label="Average");
ax.legend(*scatter.legend_elements(), title="Target");
Image in a Jupyter notebook
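The `cmap` argument decides how the numbers passed to `c` get turned into colours. A minimal sketch with made-up points, checking which colormap each scatter ends up with:

```python
import matplotlib.pyplot as plt

x = [1, 2, 3, 4]
y = [4, 3, 2, 1]
c = [0, 1, 0, 1]

fig, (ax0, ax1) = plt.subplots(ncols=2, figsize=(10, 4))
default_scatter = ax0.scatter(x, y, c=c)                 # default colormap
winter_scatter = ax1.scatter(x, y, c=c, cmap="winter")   # explicit colormap

print(default_scatter.get_cmap().name, winter_scatter.get_cmap().name)
```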

Changing the xlim & ylim

## Before the change (we've had color updates)

fig, (ax0, ax1) = plt.subplots(nrows=2,
                               ncols=1,
                               sharex=True,
                               figsize=(10, 10))
scatter = ax0.scatter(over_50["age"],
                      over_50["chol"],
                      c=over_50["target"],
                      cmap='winter')
ax0.set(title="Heart Disease and Cholesterol Levels",
        ylabel="Cholesterol")

# Setup a mean line
ax0.axhline(y=over_50["chol"].mean(),
            color='r',
            linestyle='--',
            label="Average");
ax0.legend(*scatter.legend_elements(), title="Target")

# Axis 1 (row 1, column 0)
scatter = ax1.scatter(over_50["age"],
                      over_50["thalach"],
                      c=over_50["target"],
                      cmap='winter')
ax1.set(title="Heart Disease and Max Heart Rate Levels",
        xlabel="Age",
        ylabel="Max Heart Rate")

# Setup a mean line
ax1.axhline(y=over_50["thalach"].mean(),
            color='r',
            linestyle='--',
            label="Average");
ax1.legend(*scatter.legend_elements(), title="Target")

# Title the figure
fig.suptitle('Heart Disease Analysis', fontsize=16, fontweight='bold');
Image in a Jupyter notebook
## After adding in different x & y limitations

fig, (ax0, ax1) = plt.subplots(nrows=2,
                               ncols=1,
                               sharex=True,
                               figsize=(10, 10))
scatter = ax0.scatter(over_50["age"],
                      over_50["chol"],
                      c=over_50["target"],
                      cmap='winter')
ax0.set(title="Heart Disease and Cholesterol Levels",
        ylabel="Cholesterol")

# Set the x axis
ax0.set_xlim([50, 80])

# Setup a mean line
ax0.axhline(y=over_50["chol"].mean(),
            color='r',
            linestyle='--',
            label="Average");
ax0.legend(*scatter.legend_elements(), title="Target")

# Axis 1 (row 1, column 0)
scatter = ax1.scatter(over_50["age"],
                      over_50["thalach"],
                      c=over_50["target"],
                      cmap='winter')
ax1.set(title="Heart Disease and Max Heart Rate Levels",
        xlabel="Age",
        ylabel="Max Heart Rate")

# Set the y axis
ax1.set_ylim([60, 200])

# Setup a mean line
ax1.axhline(y=over_50["thalach"].mean(),
            color='r',
            linestyle='--',
            label="Average");
ax1.legend(*scatter.legend_elements(), title="Target")

# Title the figure
fig.suptitle('Heart Disease Analysis', fontsize=16, fontweight='bold');
Image in a Jupyter notebook
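Something worth noticing in the plot above: because the subplots were created with `sharex=True`, setting the x limit on one Axes moves the other too, while the y limits stay independent. A small sketch with made-up points:

```python
import matplotlib.pyplot as plt

fig, (ax0, ax1) = plt.subplots(nrows=2, ncols=1, sharex=True)
ax0.scatter([45, 60, 75, 90], [200, 240, 260, 300])
ax1.scatter([45, 60, 75, 90], [150, 140, 130, 120])

ax0.set_xlim([50, 80])    # shared x-axis, so ax1 follows
ax1.set_ylim([60, 200])   # y-axes are independent

print(ax0.get_xlim(), ax1.get_xlim())
print(ax0.get_ylim(), ax1.get_ylim())
```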

6. Saving plots

  • Saving plots to images using savefig()

If you're doing something like this often, to save writing excess code, you might put it into a function.

A function which follows the Matplotlib workflow.

# Axis 0 (row 0, column 0)
fig, (ax0, ax1) = plt.subplots(nrows=2,
                               ncols=1,
                               sharex=True,
                               figsize=(10, 10))
scatter = ax0.scatter(over_50["age"],
                      over_50["chol"],
                      c=over_50["target"],
                      cmap='winter')
ax0.set(title="Heart Disease and Cholesterol Levels",
        ylabel="Cholesterol")

# Set the x axis
ax0.set_xlim([50, 80])

# Setup a mean line
ax0.axhline(y=over_50["chol"].mean(),
            color='r',
            linestyle='--',
            label="Average");
ax0.legend(*scatter.legend_elements(), title="Target")

# Axis 1 (row 1, column 0)
scatter = ax1.scatter(over_50["age"],
                      over_50["thalach"],
                      c=over_50["target"],
                      cmap='winter')
ax1.set(title="Heart Disease and Max Heart Rate Levels",
        xlabel="Age",
        ylabel="Max Heart Rate")

# Set the y axis
ax1.set_ylim([60, 200])

# Setup a mean line
ax1.axhline(y=over_50["thalach"].mean(),
            color='r',
            linestyle='--',
            label="Average");
ax1.legend(*scatter.legend_elements(), title="Target")

# Title the figure
fig.suptitle('Heart Disease Analysis', fontsize=16, fontweight='bold');
Image in a Jupyter notebook
# Check the supported filetypes
fig.canvas.get_supported_filetypes()
{'ps': 'Postscript', 'eps': 'Encapsulated Postscript', 'pdf': 'Portable Document Format', 'pgf': 'PGF code for LaTeX', 'png': 'Portable Network Graphics', 'raw': 'Raw RGBA bitmap', 'rgba': 'Raw RGBA bitmap', 'svg': 'Scalable Vector Graphics', 'svgz': 'Scalable Vector Graphics'}
fig
Image in a Jupyter notebook
# Save the file
fig.savefig("../images/heart-disease-analysis.png")

# Resets figure
fig, ax = plt.subplots()
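One gotcha when saving: `savefig()` writes into a folder that already exists but won't create a missing one for you. The sketch below uses a temporary folder as a stand-in for an images directory. The paths here are illustrative, not the notebook's real ones.

```python
import os
import tempfile
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([1, 2, 3, 4], [11, 22, 33, 44])

# A temporary folder stands in for the notebook's images directory
out_dir = os.path.join(tempfile.mkdtemp(), "images")
os.makedirs(out_dir, exist_ok=True)  # make sure the target folder exists first

# The file extension picks the output format (PNG here)
out_path = os.path.join(out_dir, "simple-plot.png")
fig.savefig(out_path)

print(os.path.exists(out_path), os.path.getsize(out_path) > 0)
```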
Image in a Jupyter notebook
# Potential function
def plotting_workflow(data):
    # 1. Manipulate data
    # 2. Create plot
    fig, ax = plt.subplots()
    # 3. Plot data
    ax.plot(data)
    # 4. Customize plot
    # 5. Save plot
    # 6. Return plot
    return fig