GitHub Repository: mrdbourke/zero-to-mastery-ml
Path: blob/master/section-2-data-science-and-ml-tools/matplotlib-exercises-solutions.ipynb
Kernel: Python 3

Matplotlib Practice (solutions)

This notebook offers a set of solutions to different tasks with Matplotlib.

Note that there may be more than one way to answer a question or complete an exercise.

Each task is described by a code comment or a text cell.

For further reference and resources, check out the Matplotlib documentation.

If you're stuck, remember that you can always search for a function by name. For example, if you want to create a plot with plt.subplots(), search for "plt.subplots()".

```python
# Import the pyplot module from matplotlib as plt and make sure
# plots appear in the notebook using '%matplotlib inline'
%matplotlib inline
import matplotlib.pyplot as plt
```
```python
# Create a simple plot using plt.plot()
plt.plot()
```

[]
```python
# Plot a single Python list
plt.plot([2, 6, 8, 17])
```

[<matplotlib.lines.Line2D at 0x11333e860>]
```python
# Create two lists, one called X, one called y, each with 5 numbers in them
X = [22, 88, 98, 103, 45]
y = [7, 8, 9, 10, 11]
```
```python
# Plot X & y (the lists you've created)
plt.plot(X, y)
```

[<matplotlib.lines.Line2D at 0x10578f278>]

There's another way to create plots with Matplotlib, known as the object-oriented (OO) method. Let's try it.

```python
# Create a plot using plt.subplots()
fig, ax = plt.subplots()
```
```python
# Create a plot using plt.subplots() and then add X & y on the axes
fig, ax = plt.subplots()
ax.plot(X, y)
```

[<matplotlib.lines.Line2D at 0x1134b0588>]
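To make the difference between the two interfaces concrete, here is a minimal sketch that draws the same data twice: once with the pyplot (state-based) functions, which act on whatever figure is "current", and once with the object-oriented methods on explicit `fig` and `ax` objects. The titles are illustrative additions, not part of the exercise.

```python
import matplotlib.pyplot as plt

X = [22, 88, 98, 103, 45]
y = [7, 8, 9, 10, 11]

# pyplot (state-based) interface: plt.* functions act on the current figure/axes
plt.figure()
plt.plot(X, y)
plt.title("pyplot interface")

# Object-oriented interface: call methods on explicit Figure and Axes objects
fig, ax = plt.subplots()
ax.plot(X, y)
ax.set_title("Object-oriented interface")
```

The OO style tends to scale better once a figure holds several axes, because every call names the axes it changes.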

Now let's try a small matplotlib workflow.

```python
# Import and get matplotlib ready
%matplotlib inline
import matplotlib.pyplot as plt

# Prepare data (create two lists of 5 numbers, X & y)
X = [34, 77, 21, 54, 9]
y = [9, 45, 89, 66, 4]

# Setup figure and axes using plt.subplots()
fig, ax = plt.subplots()

# Add data (X, y) to axes
ax.plot(X, y)

# Customize plot by adding a title, xlabel and ylabel
ax.set(title="Sample simple plot", xlabel="x-axis", ylabel="y-axis")

# Save the plot to file using fig.savefig()
# (the ../images folder must already exist, savefig() does not create it)
fig.savefig("../images/simple-plot.png")
```

Okay, this is a simple line plot. How about something a little different?

To help us, we'll import NumPy.

```python
# Import NumPy as np
import numpy as np
```
```python
# Create an array of 100 evenly spaced numbers between 0 and 10 using NumPy
# and save it to variable X
X = np.linspace(0, 10, 100)
```
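It helps to know exactly what `np.linspace()` returns. A small sketch: `linspace(start, stop, num)` includes both endpoints by default, so the gap between neighbouring points is (stop - start) / (num - 1). `np.arange()` is shown only for contrast, because it excludes the stop value.

```python
import numpy as np

X = np.linspace(0, 10, 100)  # 100 points, both 0 and 10 included

# Spacing between neighbouring points is (stop - start) / (num - 1)
step = X[1] - X[0]
print(X[0], X[-1], len(X))        # 0.0 10.0 100
print(np.isclose(step, 10 / 99))  # True

# np.arange() takes a step instead of a count and stops before the stop value
print(np.arange(0, 10)[-1])       # 9
```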
```python
# Create a plot using plt.subplots() and plot X versus X^2 (X squared)
fig, ax = plt.subplots()
ax.plot(X, X**2)
```

[<matplotlib.lines.Line2D at 0x113570c50>]

We'll start with scatter plots.

```python
# Create a scatter plot of X versus the exponential of X (np.exp(X))
fig, ax = plt.subplots()
ax.scatter(X, np.exp(X))
```

<matplotlib.collections.PathCollection at 0x11372b5f8>
```python
# Create a scatter plot of X versus np.sin(X)
fig, ax = plt.subplots()
ax.scatter(X, np.sin(X))
```

<matplotlib.collections.PathCollection at 0x1134a06a0>

How about we try another type of plot? This time let's look at a bar plot. First we'll make some data.

```python
# Create a Python dictionary of 3 of your favourite foods with
# The keys of the dictionary should be the food name and the values their price
favourite_food_prices = {"Almond butter": 10,
                         "Blueberries": 5,
                         "Eggs": 6}
```
```python
# Create a bar graph where the x-axis is the keys of the dictionary
# and the y-axis is the values of the dictionary
fig, ax = plt.subplots()
ax.bar(favourite_food_prices.keys(), favourite_food_prices.values())

# Add a title, xlabel and ylabel to the plot
ax.set(title="Daniel's favourite foods", xlabel="Food", ylabel="Price ($)")
```
```python
# Make the same plot as above, except this time make the bars go horizontal
fig, ax = plt.subplots()
ax.barh(list(favourite_food_prices.keys()), list(favourite_food_prices.values()))
ax.set(title="Daniel's favourite foods", xlabel="Price ($)", ylabel="Food")
```

[Text(0, 0.5, 'Food'), Text(0.5, 0, 'Price ($)'), Text(0.5, 1.0, "Daniel's favourite foods")]

All this food plotting is making me hungry. But we've got a couple of plots to go.

Let's see a histogram.

```python
# Create a random NumPy array of 1000 normally distributed numbers using NumPy and save it to X
X = np.random.randn(1000)

# Create a histogram plot of X
fig, ax = plt.subplots()
ax.hist(X)
```

(array([ 3., 19., 73., 184., 295., 269., 118., 34., 3., 2.]), array([-3.71922182, -2.93930767, -2.15939353, -1.37947939, -0.59956525, 0.1803489 , 0.96026304, 1.74017718, 2.52009133, 3.30000547, 4.07991961]), <a list of 10 Patch objects>)
```python
# Create a NumPy array of 1000 random numbers and save it to X
X = np.random.random(1000)

# Create a histogram plot of X
fig, ax = plt.subplots()
ax.hist(X)
```

(array([109., 110., 114., 93., 107., 99., 90., 86., 94., 98.]), array([4.33144794e-04, 1.00050572e-01, 1.99667999e-01, 2.99285426e-01, 3.98902853e-01, 4.98520281e-01, 5.98137708e-01, 6.97755135e-01, 7.97372562e-01, 8.96989989e-01, 9.96607416e-01]), <a list of 10 Patch objects>)
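The tuple printed under each histogram is the return value of `ax.hist()`. It holds the count in each bin, the bin edges, and the drawn bar artists. A minimal sketch that unpacks it. A seeded generator is used here only so the data is reproducible.

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(42)
x = rng.standard_normal(1000)

fig, ax = plt.subplots()
# ax.hist() returns (counts per bin, bin edges, drawn bar artists)
counts, bin_edges, patches = ax.hist(x, bins=10)

print(len(counts), len(bin_edges))  # 10 11: n bins need n + 1 edges
print(counts.sum())                 # 1000.0: every sample lands in some bin
```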

Notice how the distributions (spread of data) are different. Why do they differ?

What else can you find out about the normal distribution?

Can you think of any other kinds of data which may be normally distributed?

These questions aren't directly related to plotting or Matplotlib but they're helpful to think of.
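One way to explore the first question is to compare summary statistics. `np.random.randn()` draws from a standard normal distribution, which is centred on 0 with standard deviation 1 and has unbounded tails. `np.random.random()` draws uniformly from [0, 1), so every value in that range is equally likely. The sketch below uses the equivalent `Generator` methods with a seed; the large sample size is an arbitrary choice to make the statistics stable.

```python
import numpy as np

rng = np.random.default_rng(0)
normal = rng.standard_normal(100_000)  # same distribution as np.random.randn
uniform = rng.random(100_000)          # same distribution as np.random.random

# Standard normal: mean near 0, standard deviation near 1
print(round(normal.mean(), 2), round(normal.std(), 2))

# Uniform on [0, 1): mean near 0.5, and no value falls outside [0, 1)
print(round(uniform.mean(), 2), uniform.min() >= 0, uniform.max() < 1)
```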

Now let's try making some subplots. Subplots let a single figure hold multiple plots, each drawn on its own axes.

```python
# Create an empty subplot with 2 rows and 2 columns (4 subplots total)
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2)
```

Notice how the figure now contains multiple axes, one per subplot. Now let's add data to each of them.

```python
# Create the same plot as above with 2 rows and 2 columns and figsize of (10, 5)
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2, figsize=(10, 5))

# Plot X versus X/2 on the top left axes
ax1.plot(X, X/2)

# Plot a scatter plot of 10 random numbers on each axis on the top right subplot
ax2.scatter(np.random.random(10), np.random.random(10))

# Plot a bar graph of the favourite food keys and values on the bottom left subplot
ax3.bar(favourite_food_prices.keys(), favourite_food_prices.values())

# Plot a histogram of 1000 random normally distributed numbers on the bottom right subplot
ax4.hist(np.random.randn(1000));
```

Woah. There's a lot going on there.
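Unpacking into `((ax1, ax2), (ax3, ax4))` works well for small grids. For larger grids, a common alternative is to keep the returned axes as a NumPy array and loop over it. A minimal sketch; the titles are illustrative.

```python
import numpy as np
import matplotlib.pyplot as plt

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 5))
print(axes.shape)  # (2, 2): a NumPy array of Axes objects

# axes.flat walks the grid row by row: top left, top right, bottom left, bottom right
for i, ax in enumerate(axes.flat):
    ax.hist(np.random.randn(1000))
    ax.set_title(f"Subplot {i + 1}")
```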

Now we've seen how to plot data directly with Matplotlib. Let's practice using Matplotlib to plot with pandas.

First we'll need to import pandas and create a DataFrame to work with.

```python
# Import pandas as pd
import pandas as pd
```
```python
# Import '../data/car-sales.csv' into a DataFrame called car_sales and view it
car_sales = pd.read_csv("../data/car-sales.csv")
car_sales
```
```python
# Try to plot the 'Price' column using the plot() function
car_sales["Price"].plot()
```

```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-23-628fda813398> in <module>
      1 # Try to plot the 'Price' column using the plot() function
----> 2 car_sales["Price"].plot()

~/Desktop/ml-course/zero-to-mastery-ml/env/lib/python3.6/site-packages/pandas/plotting/_core.py in __call__(self, *args, **kwargs)
    792                 data.columns = label_name
    793
--> 794         return plot_backend.plot(data, kind=kind, **kwargs)
    795
    796     def line(self, x=None, y=None, **kwargs):

~/Desktop/ml-course/zero-to-mastery-ml/env/lib/python3.6/site-packages/pandas/plotting/_matplotlib/__init__.py in plot(data, kind, **kwargs)
     60             kwargs["ax"] = getattr(ax, "left_ax", ax)
     61     plot_obj = PLOT_CLASSES[kind](data, **kwargs)
---> 62     plot_obj.generate()
     63     plot_obj.draw()
     64     return plot_obj.result

~/Desktop/ml-course/zero-to-mastery-ml/env/lib/python3.6/site-packages/pandas/plotting/_matplotlib/core.py in generate(self)
    277     def generate(self):
    278         self._args_adjust()
--> 279         self._compute_plot_data()
    280         self._setup_subplots()
    281         self._make_plot()

~/Desktop/ml-course/zero-to-mastery-ml/env/lib/python3.6/site-packages/pandas/plotting/_matplotlib/core.py in _compute_plot_data(self)
    412         # no non-numeric frames or series allowed
    413         if is_empty:
--> 414             raise TypeError("no numeric data to plot")
    415
    416         # GH25587: cast ExtensionArray of pandas (IntegerArray, etc.) to

TypeError: no numeric data to plot
```

Why doesn't it work?

Hint: It's not numeric data.
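You can confirm the hint by checking the column's dtype. The sketch below uses a hypothetical Series whose values follow a "$4,000.00" style (an assumption inferred from the cleaning step in this exercise, not the real CSV contents). It also shows one possible conversion with `pd.to_numeric()`, as an alternative to the string slicing used in the exercise.

```python
import pandas as pd

# Hypothetical prices in "$4,000.00" style, not the actual car-sales.csv values
prices = pd.Series(["$4,000.00", "$5,000.00", "$7,000.00"], name="Price")

# The values are strings, so pandas finds no numeric data to plot
print(pd.api.types.is_numeric_dtype(prices))                       # False
print(prices.to_frame().select_dtypes(include="number").shape[1])  # 0

# Strip "$" and "," with a regex, then parse the remaining text as numbers
cleaned = pd.to_numeric(prices.str.replace(r"[$,]", "", regex=True))
print(cleaned.tolist())  # [4000.0, 5000.0, 7000.0]
```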

While converting it to numeric data, let's also create one column with the cumulative total of sales and another showing the date each car was sold.

Hint: To add a column up cumulatively, look up the cumsum() function. And to create a column of dates, look up the date_range() function.
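Here is a minimal sketch of both functions on a tiny hypothetical DataFrame, with prices already converted to integers. `cumsum()` keeps a running total down the column. `pd.date_range()` builds consecutive daily dates (daily is its default frequency) starting from a given date, with `periods` setting how many.

```python
import pandas as pd

# Hypothetical integer prices standing in for the cleaned 'Price' column
sales = pd.DataFrame({"Price": [4000, 5000, 7000]})

# Running total: each row adds its price to the sum of all previous rows
sales["Total Sales"] = sales["Price"].cumsum()

# One date per row, one day apart, starting from an ISO-formatted date
sales["Sale Date"] = pd.date_range("2020-01-13", periods=len(sales))
print(sales)
```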

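```python
# Remove the "$", "," and "." symbols from the 'Price' column, then drop the final
# two characters (the cents), so a price such as "$4,000.00" becomes "4000"
# regex=True is needed in pandas 2.0+, where str.replace() matches literally by default
car_sales["Price"] = car_sales["Price"].str.replace(r"[\$\,\.]", "", regex=True)
car_sales["Price"] = car_sales["Price"].str[:-2]
```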
```python
# Add a column called 'Total Sales' to car_sales which cumulatively adds the 'Price' column
car_sales["Total Sales"] = car_sales["Price"].astype(int).cumsum()

# Add a column called 'Sale Date' which lists a series of successive dates starting
# from a chosen date (here 13 January 2020; use your own today)
# An ISO date string ("YYYY-MM-DD") avoids day/month ambiguity when parsing
car_sales["Sale Date"] = pd.date_range("2020-01-13", periods=len(car_sales))

# View the car_sales DataFrame
car_sales
```

Now we've got a numeric column (Total Sales) and a dates column (Sale Date), let's visualize them.

```python
# Use the plot() function to plot the 'Sale Date' column versus the 'Total Sales' column
car_sales.plot(x="Sale Date", y="Total Sales")
```

<matplotlib.axes._subplots.AxesSubplot at 0x115a48f60>
```python
# Convert the 'Price' column to integers
car_sales["Price"] = car_sales["Price"].astype(int)

# Create a scatter plot of the 'Odometer (KM)' and 'Price' columns using the plot() function
car_sales.plot(x="Odometer (KM)", y="Price", kind="scatter")
```

<matplotlib.axes._subplots.AxesSubplot at 0x113d14cf8>
```python
# Create a NumPy array of random numbers of size (10, 4) and save it to X
X = np.random.rand(10, 4)

# Turn the NumPy array X into a DataFrame with columns called ['a', 'b', 'c', 'd']
df = pd.DataFrame(X, columns=["a", "b", "c", "d"])

# Create a bar graph of the DataFrame
df.plot(kind="bar")
```

<matplotlib.axes._subplots.AxesSubplot at 0x115cb87f0>
```python
# Create a bar graph of the 'Make' and 'Odometer (KM)' columns in the car_sales DataFrame
car_sales.plot(x="Make", y="Odometer (KM)", kind="bar")
```

<matplotlib.axes._subplots.AxesSubplot at 0x115e37fd0>
```python
# Create a histogram of the 'Odometer (KM)' column
car_sales["Odometer (KM)"].plot(kind="hist")
```

<matplotlib.axes._subplots.AxesSubplot at 0x115f729e8>
```python
# Create a histogram of the 'Price' column with 20 bins
car_sales["Price"].plot.hist(bins=20)
```

<matplotlib.axes._subplots.AxesSubplot at 0x115a36b00>

Now we've seen a few examples of plotting directly from DataFrames using the car_sales dataset.
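The `AxesSubplot` objects printed above are a reminder that pandas draws with Matplotlib (its default plotting backend) and returns the Axes it drew on. That means you can keep customizing a pandas plot with the same OO methods used earlier. A minimal sketch with hypothetical numbers, not taken from car-sales.csv:

```python
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical data, not from car-sales.csv
df = pd.DataFrame({"Odometer (KM)": [150000, 87000, 32000],
                   "Price": [4000, 5000, 7000]})

# DataFrame.plot() returns a Matplotlib Axes object
ax = df.plot(x="Odometer (KM)", y="Price", kind="scatter")

# ...so OO methods work on it, and its Figure can be saved as usual
ax.set(title="Price versus odometer", ylabel="Price ($)")
fig = ax.get_figure()
```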

Let's try using a different dataset.

```python
# Import "../data/heart-disease.csv" and save it to the variable "heart_disease"
heart_disease = pd.read_csv("../data/heart-disease.csv")
```

```python
# View the first 10 rows of the heart_disease DataFrame
heart_disease.head(10)
```
```python
# Create a histogram of the "age" column with 50 bins
heart_disease["age"].plot.hist(bins=50);
```
```python
# Call plot.hist() on the heart_disease DataFrame and toggle the
# "subplots" parameter to True
heart_disease.plot.hist(subplots=True);
```

That plot looks pretty squished. Let's change the figsize.

```python
# Call the same line of code from above except change the "figsize" parameter
# to be (10, 30)
heart_disease.plot.hist(figsize=(10, 30), subplots=True);
```
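Why does figsize fix the squished look? With `subplots=True`, pandas draws one histogram per column inside a single figure and returns an array of Axes. Every column's subplot shares the figure's height, so a taller `figsize` gives each one more room. A minimal sketch with a hypothetical three-column DataFrame standing in for `heart_disease`:

```python
import numpy as np
import pandas as pd

# Hypothetical random data, not the heart disease dataset
df = pd.DataFrame(np.random.randn(200, 3), columns=["a", "b", "c"])

# One histogram per column, all in one figure; pandas returns an array of Axes
axes = df.plot.hist(subplots=True, figsize=(10, 9))
print(len(axes))  # 3

# All subplots live in the same figure, whose size is set by figsize (in inches)
fig = axes[0].get_figure()
print(fig.get_size_inches())
```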

Now let's try comparing two variables versus the target variable.

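More specifically, we'll see how age and cholesterol combined affect the target in patients over 50 years old.

For this next challenge, we're going to replicate a scatter plot of age versus cholesterol for patients over 50, with points coloured by the target and a dashed line at the mean cholesterol level.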

```python
# Replicate the above plot in whichever way you see fit
# Note: The method below is only one way of doing it, yours might be
# slightly different

# Create DataFrame with patients over 50 years old
over_50 = heart_disease[heart_disease["age"] > 50]

# Create the plot
fig, ax = plt.subplots(figsize=(10, 6))

# Plot the data
scatter = ax.scatter(over_50["age"],
                     over_50["chol"],
                     c=over_50["target"])

# Customize the plot
ax.set(title="Heart Disease and Cholesterol Levels",
       xlabel="Age",
       ylabel="Cholesterol");
ax.legend(*scatter.legend_elements(), title="Target")

# Add a meanline
ax.axhline(over_50["chol"].mean(),
           linestyle="--");
```
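Two steps in that solution are worth isolating. `heart_disease["age"] > 50` builds a boolean mask that keeps only matching rows. `scatter.legend_elements()` returns legend handles and labels, one per distinct value passed to `c=` (when there are only a few distinct values). A minimal sketch on hypothetical toy data, not the real dataset:

```python
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical toy data standing in for heart_disease
patients = pd.DataFrame({
    "age": [45, 52, 61, 58, 39, 67],
    "chol": [210, 250, 190, 300, 180, 240],
    "target": [1, 0, 1, 0, 1, 0],
})

# Boolean mask: True where age > 50; indexing with it keeps only those rows
over_50 = patients[patients["age"] > 50]

fig, ax = plt.subplots()
scatter = ax.scatter(over_50["age"], over_50["chol"], c=over_50["target"])

# One handle/label pair per distinct target value (here 0 and 1)
handles, labels = scatter.legend_elements()
ax.legend(handles, labels, title="Target")
```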

Beautiful! Now that you've created a plot of two different variables, let's change the style.

```python
# Check what styles are available under plt
plt.style.available
```

['seaborn-dark', 'seaborn-darkgrid', 'seaborn-ticks', 'fivethirtyeight', 'seaborn-whitegrid', 'classic', '_classic_test', 'fast', 'seaborn-talk', 'seaborn-dark-palette', 'seaborn-bright', 'seaborn-pastel', 'grayscale', 'seaborn-notebook', 'ggplot', 'seaborn-colorblind', 'seaborn-muted', 'seaborn', 'Solarize_Light2', 'seaborn-paper', 'bmh', 'tableau-colorblind10', 'seaborn-white', 'dark_background', 'seaborn-poster', 'seaborn-deep']
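```python
# Change the style to use "seaborn-whitegrid"
# Matplotlib 3.6 renamed the seaborn styles to "seaborn-v0_8-<name>" and later
# versions dropped the old names, so fall back to the new name if needed
style = "seaborn-whitegrid"
if style not in plt.style.available:
    style = "seaborn-v0_8-whitegrid"
plt.style.use(style)
```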
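`plt.style.use()` changes the style for every figure drawn afterwards. If you only want a style for one figure, Matplotlib's `plt.style.context()` applies it inside a `with` block and restores the previous settings afterwards. A minimal sketch using "ggplot" from the list above. Reading `axes.facecolor` is just a convenient way to see the change.

```python
import matplotlib.pyplot as plt

before = plt.rcParams["axes.facecolor"]

# The style applies only inside the with-block
with plt.style.context("ggplot"):
    inside = plt.rcParams["axes.facecolor"]
    fig, ax = plt.subplots()  # this figure is drawn with the ggplot style
    ax.plot([1, 2, 3], [1, 4, 9])

# Outside the block the previous settings are back
after = plt.rcParams["axes.facecolor"]
print(before, inside, after)
```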

Now the style has been changed, we'll replot the same figure from above and see what it looks like.

If you've changed the style correctly, the figure should now have a white background with grid lines.

```python
# Reproduce the same figure as above with the "seaborn-whitegrid" style

# Create the plot
fig, ax = plt.subplots(figsize=(10, 6))

# Plot the data
scatter = ax.scatter(over_50["age"],
                     over_50["chol"],
                     c=over_50["target"])

# Customize the plot
ax.set(title="Heart Disease and Cholesterol Levels",
       xlabel="Age",
       ylabel="Cholesterol");
ax.legend(*scatter.legend_elements(), title="Target")

# Add a meanline
ax.axhline(over_50["chol"].mean(),
           linestyle="--");
```

Wonderful, you've changed the style of the plots and the figure is looking different but the dots aren't a very good colour.

Let's change the cmap parameter of scatter() as well as the color parameter of axhline() to fix it.

Completing this step correctly should result in points drawn with the blue-to-green "winter" colormap and a red mean line.

```python
# Replot the same figure as above except change the "cmap" parameter
# of scatter() to "winter"
# Also change the "color" parameter of axhline() to "red"

# Create the plot
fig, ax = plt.subplots(figsize=(10, 6))

# Plot the data
scatter = ax.scatter(over_50["age"],
                     over_50["chol"],
                     c=over_50["target"],
                     cmap="winter") # changed cmap parameter

# Customize the plot
ax.set(title="Heart Disease and Cholesterol Levels",
       xlabel="Age",
       ylabel="Cholesterol");
ax.legend(*scatter.legend_elements(), title="Target")

# Add a meanline
ax.axhline(over_50["chol"].mean(),
           linestyle="--",
           color="red"); # changed color parameter
```

Beautiful! Now that our figure has an upgraded color scheme, let's save it to file.

```python
# Save the current figure using savefig(), the file name can be anything you want
fig.savefig("../images/matplotlib-heart-disease-chol-age-plot-saved.png")
```
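`savefig()` also accepts options that control the output file. `dpi` sets the resolution, and `bbox_inches="tight"` trims surplus whitespace around the axes. A minimal sketch that writes to a temporary folder so it runs anywhere. The file name, data and dpi value are arbitrary choices for illustration.

```python
import os
import tempfile
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot([1, 2, 3], [2, 4, 1])
ax.set(title="Example figure", xlabel="x", ylabel="y")

# Write to a temporary folder so the sketch doesn't depend on ../images existing
out_path = os.path.join(tempfile.mkdtemp(), "example-plot.png")

# dpi sets the resolution; bbox_inches="tight" trims surplus whitespace
fig.savefig(out_path, dpi=150, bbox_inches="tight")
print(os.path.exists(out_path))  # True
```

`savefig()` does not create missing folders. If `../images` might not exist, calling `os.makedirs("../images", exist_ok=True)` first avoids a `FileNotFoundError`.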
```python
# Reset the figure by calling plt.subplots()
fig, ax = plt.subplots()
```

Extensions

For more exercises, check out the Matplotlib tutorials page. A good practice would be to read through it and for the parts you find interesting, add them into the end of this notebook.

The next place you could go is the Stack Overflow page for the top questions and answers for Matplotlib. Often, you'll find some of the most common and useful Matplotlib functions here. Don't forget to play around with the Stack Overflow filters! You'll likely find something helpful here.

Finally, as always, remember, the best way to learn something new is to try it. And try it relentlessly. Always be asking yourself, "is there a better way this data could be visualized so it's easier to understand?"