📚 The CoCalc Library - books, templates and other resources
License: OTHER
"""This file contains code for use with "Think Stats",1by Allen B. Downey, available from greenteapress.com23Copyright 2010 Allen B. Downey4License: GNU GPLv3 http://www.gnu.org/licenses/gpl.html5"""67import math8import matplotlib9import matplotlib.pyplot as pyplot10import numpy as np1112# customize some matplotlib attributes13#matplotlib.rc('figure', figsize=(4, 3))1415#matplotlib.rc('font', size=14.0)16#matplotlib.rc('axes', labelsize=22.0, titlesize=22.0)17#matplotlib.rc('legend', fontsize=20.0)1819#matplotlib.rc('xtick.major', size=6.0)20#matplotlib.rc('xtick.minor', size=3.0)2122#matplotlib.rc('ytick.major', size=6.0)23#matplotlib.rc('ytick.minor', size=3.0)242526class Brewer(object):27"""Encapsulates a nice sequence of colors.2829Shades of blue that look good in color and can be distinguished30in grayscale (up to a point).3132Borrowed from http://colorbrewer2.org/33"""34color_iter = None3536colors = ['#081D58',37'#253494',38'#225EA8',39'#1D91C0',40'#41B6C4',41'#7FCDBB',42'#C7E9B4',43'#EDF8B1',44'#FFFFD9']4546# lists that indicate which colors to use depending on how many are used47which_colors = [[],48[1],49[1, 3],50[0, 2, 4],51[0, 2, 4, 6],52[0, 2, 3, 5, 6],53[0, 2, 3, 4, 5, 6],54[0, 1, 2, 3, 4, 5, 6],55]5657@classmethod58def Colors(cls):59"""Returns the list of colors.60"""61return cls.colors6263@classmethod64def ColorGenerator(cls, n):65"""Returns an iterator of color strings.6667n: how many colors will be used68"""69for i in cls.which_colors[n]:70yield cls.colors[i]71raise StopIteration('Ran out of colors in Brewer.ColorGenerator')7273@classmethod74def InitializeIter(cls, num):75"""Initializes the color iterator with the given number of colors."""76cls.color_iter = cls.ColorGenerator(num)7778@classmethod79def ClearIter(cls):80"""Sets the color iterator to None."""81cls.color_iter = None8283@classmethod84def GetIter(cls):85"""Gets the color iterator."""86return cls.color_iter878889def PrePlot(num=None, rows=1, cols=1):90"""Takes hints about what's coming.9192num: number of lines that will be plotted93"""94if num:95Brewer.InitializeIter(num)9697# TODO: get sharey and sharex working. probably means switching98# to subplots instead of subplot.99# also, get rid of the gray background.100101if rows > 1 or cols > 1:102pyplot.subplots(rows, cols, sharey=True)103global SUBPLOT_ROWS, SUBPLOT_COLS104SUBPLOT_ROWS = rows105SUBPLOT_COLS = cols106107108def SubPlot(plot_number):109pyplot.subplot(SUBPLOT_ROWS, SUBPLOT_COLS, plot_number)110111112class InfiniteList(list):113"""A list that returns the same value for all indices."""114def __init__(self, val):115"""Initializes the list.116117val: value to be stored118"""119list.__init__(self)120self.val = val121122def __getitem__(self, index):123"""Gets the item with the given index.124125index: int126127returns: the stored value128"""129return self.val130131132def Underride(d, **options):133"""Add key-value pairs to d only if key is not in d.134135If d is None, create a new dictionary.136137d: dictionary138options: keyword args to add to d139"""140if d is None:141d = {}142143for key, val in options.iteritems():144d.setdefault(key, val)145146return d147148149def Clf():150"""Clears the figure and any hints that have been set."""151Brewer.ClearIter()152pyplot.clf()153154155def Figure(**options):156"""Sets options for the current figure."""157Underride(options, figsize=(6, 8))158pyplot.figure(**options)159160161def Plot(xs, ys, style='', **options):162"""Plots a line.163164Args:165xs: sequence of x values166ys: sequence of y values167style: style string passed along to pyplot.plot168options: keyword args passed to pyplot.plot169"""170color_iter = Brewer.GetIter()171172if color_iter:173try:174options = Underride(options, color=color_iter.next())175except StopIteration:176print 'Warning: Brewer ran out of colors.'177Brewer.ClearIter()178179options = Underride(options, linewidth=3, alpha=0.8)180pyplot.plot(xs, ys, style, **options)181182183def Scatter(xs, ys, **options):184"""Makes a scatter plot.185186xs: x values187ys: y values188options: options passed to pyplot.scatter189"""190options = Underride(options, color='blue', alpha=0.2,191s=30, edgecolors='none')192pyplot.scatter(xs, ys, **options)193194195def Pmf(pmf, **options):196"""Plots a Pmf or Hist as a line.197198Args:199pmf: Hist or Pmf object200options: keyword args passed to pyplot.plot201"""202xs, ps = pmf.Render()203if pmf.name:204options = Underride(options, label=pmf.name)205Plot(xs, ps, **options)206207208def Pmfs(pmfs, **options):209"""Plots a sequence of PMFs.210211Options are passed along for all PMFs. If you want different212options for each pmf, make multiple calls to Pmf.213214Args:215pmfs: sequence of PMF objects216options: keyword args passed to pyplot.plot217"""218for pmf in pmfs:219Pmf(pmf, **options)220221222def Hist(hist, **options):223"""Plots a Pmf or Hist with a bar plot.224225Args:226hist: Hist or Pmf object227options: keyword args passed to pyplot.bar228"""229# find the minimum distance between adjacent values230xs, fs = hist.Render()231width = min(Diff(xs))232233if hist.name:234options = Underride(options, label=hist.name)235236options = Underride(options,237align='center',238linewidth=0,239width=width)240241pyplot.bar(xs, fs, **options)242243244def Hists(hists, **options):245"""Plots two histograms as interleaved bar plots.246247Options are passed along for all PMFs. If you want different248options for each pmf, make multiple calls to Pmf.249250Args:251hists: list of two Hist or Pmf objects252options: keyword args passed to pyplot.plot253"""254for hist in hists:255Hist(hist, **options)256257258def Diff(t):259"""Compute the differences between adjacent elements in a sequence.260261Args:262t: sequence of number263264Returns:265sequence of differences (length one less than t)266"""267diffs = [t[i+1] - t[i] for i in range(len(t)-1)]268return diffs269270271def Cdf(cdf, complement=False, transform=None, **options):272"""Plots a CDF as a line.273274Args:275cdf: Cdf object276complement: boolean, whether to plot the complementary CDF277transform: string, one of 'exponential', 'pareto', 'weibull', 'gumbel'278options: keyword args passed to pyplot.plot279280Returns:281dictionary with the scale options that should be passed to282myplot.Save or myplot.Show283"""284xs, ps = cdf.Render()285scale = dict(xscale='linear', yscale='linear')286287if transform == 'exponential':288complement = True289scale['yscale'] = 'log'290291if transform == 'pareto':292complement = True293scale['yscale'] = 'log'294scale['xscale'] = 'log'295296if complement:297ps = [1.0-p for p in ps]298299if transform == 'weibull':300xs.pop()301ps.pop()302ps = [-math.log(1.0-p) for p in ps]303scale['xscale'] = 'log'304scale['yscale'] = 'log'305306if transform == 'gumbel':307xs.pop(0)308ps.pop(0)309ps = [-math.log(p) for p in ps]310scale['yscale'] = 'log'311312if cdf.name:313options = Underride(options, label=cdf.name)314315Plot(xs, ps, **options)316return scale317318319def Cdfs(cdfs, complement=False, transform=None, **options):320"""Plots a sequence of CDFs.321322cdfs: sequence of CDF objects323complement: boolean, whether to plot the complementary CDF324transform: string, one of 'exponential', 'pareto', 'weibull', 'gumbel'325options: keyword args passed to pyplot.plot326"""327for cdf in cdfs:328Cdf(cdf, complement, transform, **options)329330331def Contour(obj, pcolor=False, contour=True, imshow=False, **options):332"""Makes a contour plot.333334d: map from (x, y) to z, or object that provides GetDict335pcolor: boolean, whether to make a pseudocolor plot336contour: boolean, whether to make a contour plot337imshow: boolean, whether to use pyplot.imshow338options: keyword args passed to pyplot.pcolor and/or pyplot.contour339"""340try:341d = obj.GetDict()342except AttributeError:343d = obj344345Underride(options, linewidth=3, cmap=matplotlib.cm.Blues)346347xs, ys = zip(*d.iterkeys())348xs = sorted(set(xs))349ys = sorted(set(ys))350351X, Y = np.meshgrid(xs, ys)352func = lambda x, y: d.get((x, y), 0)353func = np.vectorize(func)354Z = func(X, Y)355356x_formatter = matplotlib.ticker.ScalarFormatter(useOffset=False)357axes = pyplot.gca()358axes.xaxis.set_major_formatter(x_formatter)359360if pcolor:361pyplot.pcolormesh(X, Y, Z, **options)362if contour:363cs = pyplot.contour(X, Y, Z, **options)364pyplot.clabel(cs, inline=1, fontsize=10)365if imshow:366extent = xs[0], xs[-1], ys[0], ys[-1]367pyplot.imshow(Z, extent=extent, **options)368369370def Pcolor(xs, ys, zs, pcolor=True, contour=False, **options):371"""Makes a pseudocolor plot.372373xs:374ys:375zs:376pcolor: boolean, whether to make a pseudocolor plot377contour: boolean, whether to make a contour plot378options: keyword args passed to pyplot.pcolor and/or pyplot.contour379"""380Underride(options, linewidth=3, cmap=matplotlib.cm.Blues)381382X, Y = np.meshgrid(xs, ys)383Z = zs384385x_formatter = matplotlib.ticker.ScalarFormatter(useOffset=False)386axes = pyplot.gca()387axes.xaxis.set_major_formatter(x_formatter)388389if pcolor:390pyplot.pcolormesh(X, Y, Z, **options)391392if contour:393cs = pyplot.contour(X, Y, Z, **options)394pyplot.clabel(cs, inline=1, fontsize=10)395396397def Config(**options):398"""Configures the plot.399400Pulls options out of the option dictionary and passes them to401title, xlabel, ylabel, xscale, yscale, xticks, yticks, axis, legend,402and loc.403"""404title = options.get('title', '')405pyplot.title(title)406407xlabel = options.get('xlabel', '')408pyplot.xlabel(xlabel)409410ylabel = options.get('ylabel', '')411pyplot.ylabel(ylabel)412413if 'xscale' in options:414pyplot.xscale(options['xscale'])415416if 'xticks' in options:417pyplot.xticks(options['xticks'])418419if 'yscale' in options:420pyplot.yscale(options['yscale'])421422if 'yticks' in options:423pyplot.yticks(options['yticks'])424425if 'axis' in options:426pyplot.axis(options['axis'])427428loc = options.get('loc', 0)429legend = options.get('legend', True)430if legend:431pyplot.legend(loc=loc)432433434def Show(**options):435"""Shows the plot.436437For options, see Config.438439options: keyword args used to invoke various pyplot functions440"""441# TODO: figure out how to show more than one plot442Config(**options)443pyplot.show()444445446def Save(root=None, formats=None, **options):447"""Saves the plot in the given formats.448449For options, see Config.450451Args:452root: string filename root453formats: list of string formats454options: keyword args used to invoke various pyplot functions455"""456Config(**options)457458if formats is None:459formats = ['pdf', 'eps']460461if root:462for fmt in formats:463SaveFormat(root, fmt)464Clf()465466467def SaveFormat(root, fmt='eps'):468"""Writes the current figure to a file in the given format.469470Args:471root: string filename root472fmt: string format473"""474filename = '%s.%s' % (root, fmt)475print 'Writing', filename476pyplot.savefig(filename, format=fmt, dpi=300)477478479# provide aliases for calling functons with lower-case names480preplot = PrePlot481subplot = SubPlot482clf = Clf483figure = Figure484plot = Plot485scatter = Scatter486pmf = Pmf487pmfs = Pmfs488hist = Hist489hists = Hists490diff = Diff491cdf = Cdf492cdfs = Cdfs493contour = Contour494pcolor = Pcolor495config = Config496show = Show497save = Save498499500def main():501color_iter = Brewer.ColorGenerator(7)502for color in color_iter:503print color504505if __name__ == '__main__':506main()507508509