# Repository for a workshop on Bayesian statistics
"""This file contains code for use with "Think Stats",
by Allen B. Downey, available from greenteapress.com

Copyright 2014 Allen B. Downey
License: GNU GPLv3 http://www.gnu.org/licenses/gpl.html
"""

from __future__ import print_function

import math
import warnings

import matplotlib
import matplotlib.cm
import matplotlib.pyplot as plt
import matplotlib.ticker
import numpy as np
import pandas

# customize some matplotlib attributes
#matplotlib.rc('figure', figsize=(4, 3))

#matplotlib.rc('font', size=14.0)
#matplotlib.rc('axes', labelsize=22.0, titlesize=22.0)
#matplotlib.rc('legend', fontsize=20.0)

#matplotlib.rc('xtick.major', size=6.0)
#matplotlib.rc('xtick.minor', size=3.0)

#matplotlib.rc('ytick.major', size=6.0)
#matplotlib.rc('ytick.minor', size=3.0)


class _Brewer(object):
    """Encapsulates a nice sequence of colors.

    Shades of blue that look good in color and can be distinguished
    in grayscale (up to a point).

    Borrowed from http://colorbrewer2.org/

    All state lives on the class itself (color_iter, current_figure), so
    there is a single shared color iterator for the whole module.
    """
    # iterator over the colors still available, or None if not initialized
    color_iter = None

    # palette, reversed so that index 0 is the darkest shade
    colors = ['#f7fbff', '#deebf7', '#c6dbef',
              '#9ecae1', '#6baed6', '#4292c6',
              '#2171b5', '#08519c', '#08306b'][::-1]

    # lists that indicate which colors to use depending on how many are used
    which_colors = [[],
                    [1],
                    [1, 3],
                    [0, 2, 4],
                    [0, 2, 4, 6],
                    [0, 2, 3, 5, 6],
                    [0, 2, 3, 4, 5, 6],
                    [0, 1, 2, 3, 4, 5, 6],
                    [0, 1, 2, 3, 4, 5, 6, 7],
                    [0, 1, 2, 3, 4, 5, 6, 7, 8],
                    ]

    # figure the current color iterator was created for
    current_figure = None

    @classmethod
    def Colors(cls):
        """Returns the list of colors.
        """
        return cls.colors

    @classmethod
    def ColorGenerator(cls, num):
        """Returns an iterator of color strings.

        num: how many colors will be used (index into which_colors)
        """
        # The generator simply ends when the colors run out, so next()
        # raises StopIteration for the caller.  Raising StopIteration
        # explicitly inside a generator is converted to RuntimeError
        # by PEP 479, which would defeat the handler in _UnderrideColor.
        for i in cls.which_colors[num]:
            yield cls.colors[i]

    @classmethod
    def InitIter(cls, num):
        """Initializes the color iterator with the given number of colors."""
        cls.color_iter = cls.ColorGenerator(num)
        cls.current_figure = plt.gcf()

    @classmethod
    def ClearIter(cls):
        """Sets the color iterator to None."""
        cls.color_iter = None
        cls.current_figure = None

    @classmethod
    def GetIter(cls, num):
        """Gets the color iterator.

        A new iterator is created if the current figure has changed
        since the last one was initialized, or if there is none yet.

        num: number of colors to use if a new iterator is created
        """
        fig = plt.gcf()
        if fig != cls.current_figure:
            cls.InitIter(num)
            cls.current_figure = fig

        if cls.color_iter is None:
            cls.InitIter(num)

        return cls.color_iter


def _UnderrideColor(options):
    """If color is not in the options, chooses a color.

    options: dictionary of plot options, modified in place

    Returns: the same dictionary
    """
    if 'color' in options:
        return options

    # get the current color iterator; if there is none, init one
    color_iter = _Brewer.GetIter(5)

    try:
        options['color'] = next(color_iter)
    except StopIteration:
        # if you run out of colors, initialize the color iterator
        # and try again
        warnings.warn('Ran out of colors. Starting over.')
        _Brewer.ClearIter()
        _UnderrideColor(options)

    return options


# Subplot grid remembered by PrePlot for later calls to SubPlot.
# Defaults of 1 keep SubPlot usable before PrePlot has set up a grid.
SUBPLOT_ROWS = 1
SUBPLOT_COLS = 1


def PrePlot(num=None, rows=None, cols=None):
    """Takes hints about what's coming.

    num: number of lines that will be plotted
    rows: number of rows of subplots
    cols: number of columns of subplots

    Returns: the current axes, or None if neither rows nor cols is given
    """
    global SUBPLOT_ROWS, SUBPLOT_COLS

    if num:
        _Brewer.InitIter(num)

    if rows is None and cols is None:
        return

    if rows is not None and cols is None:
        cols = 1

    if cols is not None and rows is None:
        rows = 1

    # resize the image, depending on the number of rows and cols
    size_map = {(1, 1): (8, 6),
                (1, 2): (12, 6),
                (1, 3): (12, 6),
                (1, 4): (12, 5),
                (1, 5): (12, 4),
                (2, 2): (10, 10),
                (2, 3): (16, 10),
                (3, 1): (8, 10),
                (4, 1): (8, 12),
                }

    if (rows, cols) in size_map:
        fig = plt.gcf()
        fig.set_size_inches(*size_map[rows, cols])

    # create the first subplot
    if rows > 1 or cols > 1:
        ax = plt.subplot(rows, cols, 1)
        SUBPLOT_ROWS = rows
        SUBPLOT_COLS = cols
    else:
        ax = plt.gca()

    return ax


def SubPlot(plot_number, rows=None, cols=None, **options):
    """Configures the number of subplots and changes the current plot.

    plot_number: int
    rows: int, defaults to the grid last set by PrePlot
    cols: int, defaults to the grid last set by PrePlot
    options: passed to subplot

    Returns: the result of plt.subplot
    """
    rows = rows or SUBPLOT_ROWS
    cols = cols or SUBPLOT_COLS
    return plt.subplot(rows, cols, plot_number, **options)


def _Underride(d, **options):
    """Add key-value pairs to d only if key is not in d.

    If d is None, create a new dictionary.

    d: dictionary
    options: keyword args to add to d

    Returns: d (or the new dictionary)
    """
    if d is None:
        d = {}

    for key, val in options.items():
        d.setdefault(key, val)

    return d


def Clf():
    """Clears the figure and any hints that have been set."""
    global LOC
    LOC = None
    _Brewer.ClearIter()
    plt.clf()
    fig = plt.gcf()
    fig.set_size_inches(8, 6)


def Figure(**options):
    """Creates a new figure.

    options: keyword args passed to plt.figure; figsize defaults to (6, 8)
    """
    _Underride(options, figsize=(6, 8))
    plt.figure(**options)


def Plot(obj, ys=None, style='', **options):
    """Plots a line.

    Args:
      obj: sequence of x values, or Series, or anything with Render()
      ys: sequence of y values
      style: style string passed along to plt.plot
      options: keyword args passed to plt.plot
    """
    options = _UnderrideColor(options)
    label = getattr(obj, 'label', '_nolegend_')
    options = _Underride(options, linewidth=3, alpha=0.7, label=label)

    xs = obj
    if ys is None:
        if hasattr(obj, 'Render'):
            xs, ys = obj.Render()
        if isinstance(obj, pandas.Series):
            ys = obj.values
            xs = obj.index

    if ys is None:
        plt.plot(xs, style, **options)
    else:
        plt.plot(xs, ys, style, **options)


def Vlines(xs, y1, y2, **options):
    """Plots a set of vertical lines.

    Args:
      xs: sequence of x values
      y1: sequence of y values
      y2: sequence of y values
      options: keyword args passed to plt.vlines
    """
    options = _UnderrideColor(options)
    options = _Underride(options, linewidth=1, alpha=0.5)
    plt.vlines(xs, y1, y2, **options)


def Hlines(ys, x1, x2, **options):
    """Plots a set of horizontal lines.

    Args:
      ys: sequence of y values
      x1: sequence of x values
      x2: sequence of x values
      options: keyword args passed to plt.hlines
    """
    options = _UnderrideColor(options)
    options = _Underride(options, linewidth=1, alpha=0.5)
    plt.hlines(ys, x1, x2, **options)


def FillBetween(xs, y1, y2=None, where=None, **options):
    """Fills the space between two lines.

    Args:
      xs: sequence of x values
      y1: sequence of y values
      y2: sequence of y values; if None, fills down to y=0
      where: sequence of boolean
      options: keyword args passed to plt.fill_between
    """
    options = _UnderrideColor(options)
    options = _Underride(options, linewidth=0, alpha=0.5)
    if y2 is None:
        # plt.fill_between expects a number or sequence here; 0 is its
        # own default for the second curve
        y2 = 0
    plt.fill_between(xs, y1, y2, where=where, **options)


def Bar(xs, ys, **options):
    """Plots a bar chart.

    Args:
      xs: sequence of x values
      ys: sequence of y values
      options: keyword args passed to plt.bar
    """
    options = _UnderrideColor(options)
    options = _Underride(options, linewidth=0, alpha=0.6)
    plt.bar(xs, ys, **options)


def Scatter(xs, ys=None, **options):
    """Makes a scatter plot.

    xs: x values, or a Series (its index and values are used if ys is None)
    ys: y values
    options: options passed to plt.scatter
    """
    options = _Underride(options, color='blue', alpha=0.2,
                         s=30, edgecolors='none')

    if ys is None and isinstance(xs, pandas.Series):
        ys = xs.values
        xs = xs.index

    plt.scatter(xs, ys, **options)


def HexBin(xs, ys, **options):
    """Makes a hexagonal binning plot.

    xs: x values
    ys: y values
    options: options passed to plt.hexbin
    """
    options = _Underride(options, cmap=matplotlib.cm.Blues)
    plt.hexbin(xs, ys, **options)


def Pdf(pdf, **options):
    """Plots a Pdf, Pmf, or Hist as a line.

    The options low, high and n (default 101) are removed from options
    and passed to pdf.Render.

    Args:
      pdf: Pdf, Pmf, or Hist object
      options: keyword args passed to plt.plot
    """
    low, high = options.pop('low', None), options.pop('high', None)
    n = options.pop('n', 101)
    xs, ps = pdf.Render(low=low, high=high, n=n)
    options = _Underride(options, label=pdf.label)
    Plot(xs, ps, **options)


def Pdfs(pdfs, **options):
    """Plots a sequence of PDFs.

    Options are passed along for all PDFs.  If you want different
    options for each pdf, make multiple calls to Pdf.

    Args:
      pdfs: sequence of PDF objects
      options: keyword args passed to plt.plot
    """
    for pdf in pdfs:
        Pdf(pdf, **options)


def Hist(hist, **options):
    """Plots a Pmf or Hist with a bar plot.

    The default width of the bars is based on the minimum difference
    between values in the Hist.  If that's too small, you can override
    it by providing a width keyword argument, in the same units
    as the values.

    Args:
      hist: Hist or Pmf object
      options: keyword args passed to plt.bar
    """
    xs, ys = hist.Render()

    # see if the values support arithmetic
    try:
        xs[0] - xs[0]
    except TypeError:
        # if not, replace values with numbers
        labels = [str(x) for x in xs]
        xs = np.arange(len(xs))
        plt.xticks(xs+0.5, labels)

    # find the minimum distance between adjacent values; with fewer
    # than two values there is no distance to measure, so plt.bar is
    # left to use its own default width
    if 'width' not in options and len(xs) > 1:
        try:
            options['width'] = 0.9 * np.diff(xs).min()
        except TypeError:
            warnings.warn("Hist: Can't compute bar width automatically. "
                          "Check for non-numeric types in Hist. "
                          "Or try providing width option."
                          )

    options = _Underride(options, label=hist.label)
    options = _Underride(options, align='center')
    if options['align'] == 'left':
        options['align'] = 'edge'
    elif options['align'] == 'right':
        # plt.bar has no right alignment; a negative width with
        # align='edge' aligns the right edge of each bar to x.
        # 0.8 is plt.bar's default width, used if none was computed.
        options['align'] = 'edge'
        options['width'] = -options.get('width', 0.8)

    Bar(xs, ys, **options)


def Hists(hists, **options):
    """Plots a sequence of histograms as bar plots.

    Options are passed along for all Hists.  If you want different
    options for each hist, make multiple calls to Hist.

    Args:
      hists: sequence of Hist or Pmf objects
      options: keyword args passed to plt.bar
    """
    for hist in hists:
        Hist(hist, **options)


def Pmf(pmf, **options):
    """Plots a Pmf or Hist as a line.

    The Pmf is drawn as a step function: each value gets a flat
    segment of the given width at the height of its probability.

    The options width and align ('center', 'left' or 'right';
    default 'center') are removed from options before plotting.

    Args:
      pmf: Hist or Pmf object
      options: keyword args passed to plt.plot
    """
    xs, ys = pmf.Render()

    width = options.pop('width', None)
    if width is None:
        try:
            width = np.diff(xs).min()
        except TypeError:
            warnings.warn("Pmf: Can't compute bar width automatically. "
                          "Check for non-numeric types in Pmf. "
                          "Or try providing width option.")
    points = []

    # lastx starts as nan so the first comparison below is False and
    # no leading gap is drawn
    lastx = np.nan
    lasty = 0
    for x, y in zip(xs, ys):
        # if there is a gap since the previous step, drop to zero
        if (x - lastx) > 1e-5:
            points.append((lastx, 0))
            points.append((x, 0))

        points.append((x, lasty))
        points.append((x, y))
        points.append((x+width, y))

        lastx = x + width
        lasty = y
    points.append((lastx, 0))
    pxs, pys = zip(*points)

    align = options.pop('align', 'center')
    if align == 'center':
        pxs = np.array(pxs) - width/2.0
    if align == 'right':
        pxs = np.array(pxs) - width

    options = _Underride(options, label=pmf.label)
    Plot(pxs, pys, **options)


def Pmfs(pmfs, **options):
    """Plots a sequence of PMFs.

    Options are passed along for all PMFs.  If you want different
    options for each pmf, make multiple calls to Pmf.

    Args:
      pmfs: sequence of PMF objects
      options: keyword args passed to plt.plot
    """
    for pmf in pmfs:
        Pmf(pmf, **options)


def Diff(t):
    """Compute the differences between adjacent elements in a sequence.

    Args:
        t: sequence of number

    Returns:
        sequence of differences (length one less than t)
    """
    return [b - a for a, b in zip(t, t[1:])]


def Cdf(cdf, complement=False, transform=None, **options):
    """Plots a CDF as a line.

    The options xscale and yscale are removed from options and
    returned in the scale dictionary.

    Args:
      cdf: Cdf object
      complement: boolean, whether to plot the complementary CDF
      transform: string, one of 'exponential', 'pareto', 'weibull', 'gumbel'
      options: keyword args passed to plt.plot

    Returns:
      dictionary with the scale options that should be passed to
      Config, Show or Save.
    """
    xs, ps = cdf.Render()
    xs = np.asarray(xs)
    ps = np.asarray(ps)

    scale = dict(xscale='linear', yscale='linear')

    for s in ['xscale', 'yscale']:
        if s in options:
            scale[s] = options.pop(s)

    if transform == 'exponential':
        complement = True
        scale['yscale'] = 'log'

    if transform == 'pareto':
        complement = True
        scale['yscale'] = 'log'
        scale['xscale'] = 'log'

    if complement:
        ps = 1.0 - ps

    if transform == 'weibull':
        # drop the last point before taking log(1-p)
        xs = np.delete(xs, -1)
        ps = np.delete(ps, -1)
        ps = [-math.log(1.0-p) for p in ps]
        scale['xscale'] = 'log'
        scale['yscale'] = 'log'

    if transform == 'gumbel':
        # drop the first point before taking log(p)
        xs = np.delete(xs, 0)
        ps = np.delete(ps, 0)
        ps = [-math.log(p) for p in ps]
        scale['yscale'] = 'log'

    options = _Underride(options, label=cdf.label)
    Plot(xs, ps, **options)
    return scale


def Cdfs(cdfs, complement=False, transform=None, **options):
    """Plots a sequence of CDFs.

    cdfs: sequence of CDF objects
    complement: boolean, whether to plot the complementary CDF
    transform: string, one of 'exponential', 'pareto', 'weibull', 'gumbel'
    options: keyword args passed to plt.plot
    """
    for cdf in cdfs:
        Cdf(cdf, complement, transform, **options)


def Contour(obj, pcolor=False, contour=True, imshow=False, **options):
    """Makes a contour plot.

    obj: map from (x, y) to z, or object that provides GetDict
    pcolor: boolean, whether to make a pseudocolor plot
    contour: boolean, whether to make a contour plot
    imshow: boolean, whether to use plt.imshow
    options: keyword args passed to plt.pcolor and/or plt.contour
    """
    if hasattr(obj, 'GetDict'):
        d = obj.GetDict()
    else:
        d = obj

    _Underride(options, linewidth=3, cmap=matplotlib.cm.Blues)

    xs, ys = zip(*d.keys())
    xs = sorted(set(xs))
    ys = sorted(set(ys))

    X, Y = np.meshgrid(xs, ys)

    def lookup(x, y):
        # grid points missing from the map are treated as 0
        return d.get((x, y), 0)

    Z = np.vectorize(lookup)(X, Y)

    x_formatter = matplotlib.ticker.ScalarFormatter(useOffset=False)
    axes = plt.gca()
    axes.xaxis.set_major_formatter(x_formatter)

    if pcolor:
        plt.pcolormesh(X, Y, Z, **options)
    if contour:
        cs = plt.contour(X, Y, Z, **options)
        plt.clabel(cs, inline=1, fontsize=10)
    if imshow:
        extent = xs[0], xs[-1], ys[0], ys[-1]
        plt.imshow(Z, extent=extent, **options)


def Pcolor(xs, ys, zs, pcolor=True, contour=False, **options):
    """Makes a pseudocolor plot.

    xs: sequence of x values, passed to np.meshgrid
    ys: sequence of y values, passed to np.meshgrid
    zs: z values, plotted against the meshgrid of xs and ys
    pcolor: boolean, whether to make a pseudocolor plot
    contour: boolean, whether to make a contour plot
    options: keyword args passed to plt.pcolor and/or plt.contour
    """
    _Underride(options, linewidth=3, cmap=matplotlib.cm.Blues)

    X, Y = np.meshgrid(xs, ys)
    Z = zs

    x_formatter = matplotlib.ticker.ScalarFormatter(useOffset=False)
    axes = plt.gca()
    axes.xaxis.set_major_formatter(x_formatter)

    if pcolor:
        plt.pcolormesh(X, Y, Z, **options)

    if contour:
        cs = plt.contour(X, Y, Z, **options)
        plt.clabel(cs, inline=1, fontsize=10)


def Text(x, y, s, **options):
    """Puts text in a figure.

    x: number
    y: number
    s: string
    options: keyword args passed to plt.text
    """
    options = _Underride(options,
                         fontsize=16,
                         verticalalignment='top',
                         horizontalalignment='left')
    plt.text(x, y, s, **options)


# Legend settings remembered between calls to Config; Clf resets LOC.
LEGEND = True
LOC = None


def Config(**options):
    """Configures the plot.

    Pulls options out of the option dictionary and passes them to
    the corresponding plt functions.

    The legend and loc options are remembered in the module-level
    LEGEND and LOC, so they stay in effect for later calls.

    options: keyword args; recognized names are title, xlabel, ylabel,
             xscale, yscale, xticks, yticks, axis, xlim, ylim, legend,
             loc, frameon, xticklabels and yticklabels
    """
    global LEGEND, LOC

    names = ['title', 'xlabel', 'ylabel', 'xscale', 'yscale',
             'xticks', 'yticks', 'axis', 'xlim', 'ylim']

    for name in names:
        if name in options:
            getattr(plt, name)(options[name])

    LEGEND = options.get('legend', LEGEND)

    if LEGEND:
        LOC = options.get('loc', LOC)
        frameon = options.get('frameon', True)

        # Turn UserWarnings raised while drawing the legend into
        # exceptions so they can be ignored quietly.  catch_warnings
        # restores the caller's warning filters afterwards.
        with warnings.catch_warnings():
            warnings.simplefilter('error', category=UserWarning)
            try:
                plt.legend(loc=LOC, frameon=frameon)
            except UserWarning:
                pass

    # x and y ticklabels can be made invisible
    if options.get('xticklabels', None) == 'invisible':
        ax = plt.gca()
        labels = ax.get_xticklabels()
        plt.setp(labels, visible=False)

    if options.get('yticklabels', None) == 'invisible':
        ax = plt.gca()
        labels = ax.get_yticklabels()
        plt.setp(labels, visible=False)


def Show(**options):
    """Shows the plot.

    For options, see Config.

    options: keyword args used to invoke various plt functions;
             clf (default True) controls whether the figure is
             cleared afterwards
    """
    clf = options.pop('clf', True)
    Config(**options)
    plt.show()
    if clf:
        Clf()


def Plotly(**options):
    """Shows the plot using plotly.

    For options, see Config.

    options: keyword args used to invoke various plt functions;
             clf (default True) controls whether the figure is
             cleared afterwards

    Returns: the result of plotly.plot_mpl for the current figure
    """
    clf = options.pop('clf', True)
    Config(**options)
    # imported here so plotly is only required when this function is used
    import plotly.plotly as plotly
    url = plotly.plot_mpl(plt.gcf())
    if clf:
        Clf()
    return url


def Save(root=None, formats=None, **options):
    """Saves the plot in the given formats and clears the figure.

    For options, see Config.

    Args:
      root: string filename root; if empty or None, no files are written
      formats: list of string formats, default ['pdf', 'eps']; the
               special format 'plotly' calls Plotly instead of
               writing a file
      options: keyword args used to invoke various plt functions;
               bbox_inches and pad_inches are passed to SaveFormat,
               and clf (default True) controls whether the figure is
               cleared afterwards
    """
    clf = options.pop('clf', True)

    save_options = {}
    for option in ['bbox_inches', 'pad_inches']:
        if option in options:
            save_options[option] = options.pop(option)

    Config(**options)

    if formats is None:
        formats = ['pdf', 'eps']

    # work on a filtered copy so the caller's list is not modified
    file_formats = [fmt for fmt in formats if fmt != 'plotly']
    if len(file_formats) != len(formats):
        Plotly(clf=False)

    if root:
        for fmt in file_formats:
            SaveFormat(root, fmt, **save_options)
    if clf:
        Clf()


def SaveFormat(root, fmt='eps', **options):
    """Writes the current figure to a file in the given format.

    Args:
      root: string filename root
      fmt: string format
      options: keyword args passed to plt.savefig; dpi defaults to 300
    """
    _Underride(options, dpi=300)
    filename = '%s.%s' % (root, fmt)
    print('Writing', filename)
    plt.savefig(filename, format=fmt, **options)


# provide aliases for calling functions with lower-case names
preplot = PrePlot
subplot = SubPlot
clf = Clf
figure = Figure
plot = Plot
vlines = Vlines
hlines = Hlines
fill_between = FillBetween
text = Text
scatter = Scatter
pmf = Pmf
pmfs = Pmfs
hist = Hist
hists = Hists
diff = Diff
cdf = Cdf
cdfs = Cdfs
contour = Contour
pcolor = Pcolor
config = Config
show = Show
save = Save


def main():
    """Prints the colors used when seven colors are requested."""
    color_iter = _Brewer.ColorGenerator(7)
    for color in color_iter:
        print(color)


if __name__ == '__main__':
    main()