Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Download

Repository for a workshop on Bayesian statistics

1430 views
1
"""This file contains code for use with "Think Stats",
2
by Allen B. Downey, available from greenteapress.com
3
4
Copyright 2014 Allen B. Downey
5
License: GNU GPLv3 http://www.gnu.org/licenses/gpl.html
6
"""
7
8
from __future__ import print_function
9
10
import math
11
import matplotlib
12
import matplotlib.pyplot as plt
13
import numpy as np
14
import pandas
15
16
import warnings
17
18
# customize some matplotlib attributes
19
#matplotlib.rc('figure', figsize=(4, 3))
20
21
#matplotlib.rc('font', size=14.0)
22
#matplotlib.rc('axes', labelsize=22.0, titlesize=22.0)
23
#matplotlib.rc('legend', fontsize=20.0)
24
25
#matplotlib.rc('xtick.major', size=6.0)
26
#matplotlib.rc('xtick.minor', size=3.0)
27
28
#matplotlib.rc('ytick.major', size=6.0)
29
#matplotlib.rc('ytick.minor', size=3.0)
30
31
32
class _Brewer(object):
    """Encapsulates a nice sequence of colors.

    Shades of blue that look good in color and can be distinguished
    in grayscale (up to a point).

    Borrowed from http://colorbrewer2.org/

    All state lives on the class: `color_iter` is the active color
    generator (or None) and `current_figure` is the figure it was
    created for, so drawing on a new figure restarts the sequence.
    """
    color_iter = None

    # reversed so the darkest shade comes first
    colors = ['#f7fbff', '#deebf7', '#c6dbef',
              '#9ecae1', '#6baed6', '#4292c6',
              '#2171b5', '#08519c', '#08306b'][::-1]

    # lists that indicate which colors to use depending on how many are used
    which_colors = [[],
                    [1],
                    [1, 3],
                    [0, 2, 4],
                    [0, 2, 4, 6],
                    [0, 2, 3, 5, 6],
                    [0, 2, 3, 4, 5, 6],
                    [0, 1, 2, 3, 4, 5, 6],
                    [0, 1, 2, 3, 4, 5, 6, 7],
                    [0, 1, 2, 3, 4, 5, 6, 7, 8],
                    ]

    current_figure = None

    @classmethod
    def Colors(cls):
        """Returns the list of colors."""
        return cls.colors

    @classmethod
    def ColorGenerator(cls, num):
        """Returns an iterator of color strings.

        num: how many colors will be used; values larger than the table
             supports yield every available color.

        When the colors are used up the generator simply ends, so next()
        raises StopIteration (which _UnderrideColor catches to start over).
        """
        # Clamp so a request for more colors than exist does not IndexError.
        index = min(num, len(cls.which_colors) - 1)
        for i in cls.which_colors[index]:
            yield cls.colors[i]
        # Do NOT raise StopIteration here: since PEP 479 (Python 3.7) that
        # is converted into RuntimeError. Falling off the end is enough.

    @classmethod
    def InitIter(cls, num):
        """Initializes the color iterator with the given number of colors
        and remembers the current figure."""
        cls.color_iter = cls.ColorGenerator(num)
        cls.current_figure = plt.gcf()

    @classmethod
    def ClearIter(cls):
        """Sets the color iterator to None."""
        cls.color_iter = None
        cls.current_figure = None

    @classmethod
    def GetIter(cls, num):
        """Gets the color iterator.

        A new iterator with `num` colors is started when the current
        figure has changed or no iterator is active.
        """
        fig = plt.gcf()
        if fig != cls.current_figure or cls.color_iter is None:
            cls.InitIter(num)
        return cls.color_iter
102
103
104
def _UnderrideColor(options):
    """Picks the next color from the shared iterator unless the caller
    already set one.

    options: dict of plotting keyword args; modified in place

    Returns options.
    """
    if 'color' not in options:
        # get the current color iterator; if there is none, init one
        color_source = _Brewer.GetIter(5)
        try:
            options['color'] = next(color_source)
        except StopIteration:
            # out of colors: reset the iterator and try again
            warnings.warn('Ran out of colors. Starting over.')
            _Brewer.ClearIter()
            _UnderrideColor(options)
    return options
123
124
125
def PrePlot(num=None, rows=None, cols=None):
    """Takes hints about what's coming.

    num: number of lines that will be plotted (restarts the color
         iterator with that many colors)
    rows: number of rows of subplots
    cols: number of columns of subplots

    Returns None when neither rows nor cols is given. Otherwise returns
    the Axes of the first subplot, or the current Axes for a 1x1 layout.
    For multi-plot layouts, the layout is also stored in SUBPLOT_ROWS /
    SUBPLOT_COLS for later SubPlot calls.
    """
    global SUBPLOT_ROWS, SUBPLOT_COLS

    if num:
        _Brewer.InitIter(num)

    if rows is None and cols is None:
        return None

    # one missing dimension defaults to 1
    if cols is None:
        cols = 1
    if rows is None:
        rows = 1

    # figure sizes in inches for common layouts, keyed by (rows, cols)
    figure_sizes = {
        (1, 1): (8, 6),
        (1, 2): (12, 6),
        (1, 3): (12, 6),
        (1, 4): (12, 5),
        (1, 5): (12, 4),
        (2, 2): (10, 10),
        (2, 3): (16, 10),
        (3, 1): (8, 10),
        (4, 1): (8, 12),
    }
    layout = (rows, cols)
    if layout in figure_sizes:
        plt.gcf().set_size_inches(*figure_sizes[layout])

    if rows <= 1 and cols <= 1:
        return plt.gca()

    # create the first subplot, then remember the layout
    first_axes = plt.subplot(rows, cols, 1)
    SUBPLOT_ROWS = rows
    SUBPLOT_COLS = cols
    return first_axes
170
171
def SubPlot(plot_number, rows=None, cols=None, **options):
    """Configures the number of subplots and changes the current plot.

    rows: int; if falsy, the layout recorded by PrePlot is used
    cols: int; if falsy, the layout recorded by PrePlot is used
    plot_number: int, index of the subplot to make current
    options: passed to plt.subplot

    Raises ValueError when no layout was given and PrePlot has not
    recorded one (previously this surfaced as a NameError).
    """
    # SUBPLOT_ROWS / SUBPLOT_COLS only exist after PrePlot has set up a
    # multi-plot layout, so look them up without risking a NameError.
    rows = rows or globals().get('SUBPLOT_ROWS')
    cols = cols or globals().get('SUBPLOT_COLS')
    if rows is None or cols is None:
        raise ValueError('SubPlot: no subplot layout; call PrePlot with '
                         'rows/cols first or pass rows and cols.')
    return plt.subplot(rows, cols, plot_number, **options)
182
183
184
def _Underride(d, **options):
    """Add key-value pairs to d only if key is not in d.

    If d is None, create a new dictionary.

    d: dictionary (updated in place) or None
    options: keyword args whose values serve as defaults

    Returns d, or the newly created dictionary.
    """
    target = {} if d is None else d
    for key in options:
        if key not in target:
            target[key] = options[key]
    return target
199
200
201
def Clf():
    """Clears the figure and any hints that have been set.

    Resets the legend location (LOC) and the color iterator, clears the
    current figure, and sets its size back to 8 x 6 inches.
    """
    global LOC
    LOC = None
    _Brewer.ClearIter()
    plt.clf()
    plt.gcf().set_size_inches(8, 6)
209
210
211
def Figure(**options):
    """Sets options for the current figure.

    options: keyword args passed to plt.figure; figsize defaults
             to (6, 8)
    """
    options.setdefault('figsize', (6, 8))
    plt.figure(**options)
215
216
217
def Plot(obj, ys=None, style='', **options):
    """Plots a line.

    Args:
      obj: sequence of x values, or Series, or anything with Render()
      ys: sequence of y values
      style: style string passed along to plt.plot
      options: keyword args passed to plt.plot; color, linewidth (3),
               alpha (0.7) and label (obj.label, if it exists) are
               filled in when not given
    """
    options = _UnderrideColor(options)
    options = _Underride(options, linewidth=3, alpha=0.7,
                         label=getattr(obj, 'label', '_nolegend_'))

    xs = obj
    if ys is None:
        if hasattr(obj, 'Render'):
            xs, ys = obj.Render()
        if isinstance(obj, pandas.Series):
            # a Series plots its values against its index
            xs, ys = obj.index, obj.values

    positional = (xs, style) if ys is None else (xs, ys, style)
    plt.plot(*positional, **options)
242
243
244
def Vlines(xs, y1, y2, **options):
    """Plots a set of vertical lines.

    Args:
      xs: sequence of x values
      y1: sequence of y values (where each line starts)
      y2: sequence of y values (where each line ends)
      options: keyword args passed to plt.vlines; color, linewidth (1)
               and alpha (0.5) are filled in when not given
    """
    styled = _Underride(_UnderrideColor(options), linewidth=1, alpha=0.5)
    plt.vlines(xs, y1, y2, **styled)
256
257
258
def Hlines(ys, x1, x2, **options):
    """Plots a set of horizontal lines.

    Args:
      ys: sequence of y values
      x1: sequence of x values (where each line starts)
      x2: sequence of x values (where each line ends)
      options: keyword args passed to plt.hlines; color, linewidth (1)
               and alpha (0.5) are filled in when not given
    """
    styled = _Underride(_UnderrideColor(options), linewidth=1, alpha=0.5)
    plt.hlines(ys, x1, x2, **styled)
270
271
272
def FillBetween(xs, y1, y2=None, where=None, **options):
    """Fills the space between two lines.

    Args:
      xs: sequence of x values
      y1: sequence of y values
      y2: sequence of y values or a scalar; None (the default) fills
          down to y=0
      where: sequence of boolean
      options: keyword args passed to plt.fill_between; color,
               linewidth (0) and alpha (0.5) are filled in when not given
    """
    options = _UnderrideColor(options)
    options = _Underride(options, linewidth=0, alpha=0.5)
    # plt.fill_between cannot process y2=None; its own default is 0.
    if y2 is None:
        y2 = 0
    plt.fill_between(xs, y1, y2, where=where, **options)
285
286
287
def Bar(xs, ys, **options):
    """Plots a bar chart.

    Args:
      xs: sequence of x values
      ys: sequence of y values (bar heights)
      options: keyword args passed to plt.bar; color, linewidth (0)
               and alpha (0.6) are filled in when not given
    """
    styled = _Underride(_UnderrideColor(options), linewidth=0, alpha=0.6)
    plt.bar(xs, ys, **styled)
298
299
300
def Scatter(xs, ys=None, **options):
    """Makes a scatter plot.

    xs: x values, or a pandas Series when ys is omitted (the index is
        used for x and the values for y)
    ys: y values
    options: options passed to plt.scatter; color ('blue'), alpha (0.2),
             s (30) and edgecolors ('none') are defaults
    """
    options = _Underride(options, color='blue', alpha=0.2,
                         s=30, edgecolors='none')

    if ys is None and isinstance(xs, pandas.Series):
        xs, ys = xs.index, xs.values

    plt.scatter(xs, ys, **options)
315
316
317
def HexBin(xs, ys, **options):
    """Makes a hexbin plot (a 2-D histogram with hexagonal cells).

    xs: x values
    ys: y values
    options: options passed to plt.hexbin; cmap defaults to
             matplotlib.cm.Blues
    """
    options.setdefault('cmap', matplotlib.cm.Blues)
    plt.hexbin(xs, ys, **options)
326
327
328
def Pdf(pdf, **options):
    """Plots a Pdf, Pmf, or Hist as a line.

    Args:
      pdf: Pdf, Pmf, or Hist object
      options: keyword args passed to plt.plot; 'low', 'high' and 'n'
               (default 101) are removed and passed to pdf.Render instead
    """
    render_args = dict(low=options.pop('low', None),
                       high=options.pop('high', None),
                       n=options.pop('n', 101))
    xs, ps = pdf.Render(**render_args)
    Plot(xs, ps, **_Underride(options, label=pdf.label))
340
341
342
def Pdfs(pdfs, **options):
    """Plots a sequence of PDFs.

    The same options are used for every PDF; to style them differently,
    call Pdf once per object instead.

    Args:
      pdfs: sequence of PDF objects
      options: keyword args passed to plt.plot
    """
    for each_pdf in pdfs:
        Pdf(each_pdf, **options)
354
355
356
def Hist(hist, **options):
    """Plots a Pmf or Hist with a bar plot.

    The default width of the bars is based on the minimum difference
    between values in the Hist. If that's too small, you can override
    it by providing a width keyword argument, in the same units
    as the values.

    Values that do not support arithmetic (e.g. strings) are plotted at
    positions 0, 1, 2, ... and shown as tick labels under the bars.

    Args:
      hist: Hist or Pmf object
      options: keyword args passed to plt.bar; 'align' may be 'center'
               (default), 'left' or 'right'
    """
    xs, ys = hist.Render()

    # see if the values support arithmetic
    labels = None
    try:
        xs[0] - xs[0]
    except TypeError:
        # if not, replace values with numbers and keep them as labels
        labels = [str(x) for x in xs]
        xs = np.arange(len(xs))

    if 'width' not in options:
        # find the minimum distance between adjacent values
        try:
            options['width'] = 0.9 * np.diff(xs).min()
        except (TypeError, ValueError):
            # TypeError: non-numeric values; ValueError: fewer than two
            # values, so there is no spacing to measure
            warnings.warn("Hist: Can't compute bar width automatically. "
                          "Check for non-numeric types in Hist. "
                          "Or try providing width option.")

    options = _Underride(options, label=hist.label)
    options = _Underride(options, align='center')

    # plt.bar only understands 'center' and 'edge'; with 'edge', a
    # negative width draws the bar to the left of x (right-aligned).
    if options['align'] == 'left':
        options['align'] = 'edge'
    elif options['align'] == 'right':
        options['align'] = 'edge'
        # 0.8 is plt.bar's default width, used if none could be computed
        options['width'] = -options.get('width', 0.8)

    if labels is not None:
        # put each label under the middle of its bar
        if options['align'] == 'edge':
            offset = options.get('width', 0.8) / 2.0
        else:
            offset = 0
        plt.xticks(xs + offset, labels)

    Bar(xs, ys, **options)
398
399
400
def Hists(hists, **options):
    """Plots a sequence of histograms as bar plots.

    The same options are used for every histogram; to style them
    differently, call Hist once per object instead.

    Args:
      hists: sequence of Hist or Pmf objects
      options: keyword args passed to plt.bar
    """
    for each_hist in hists:
        Hist(each_hist, **options)
412
413
414
def Pmf(pmf, **options):
    """Plots a Pmf or Hist as a line.

    The line is a step function: each value x becomes a flat segment of
    height p from x to x+width. Where consecutive segments do not touch,
    the line drops to 0 in between.

    Args:
      pmf: Hist or Pmf object
      options: keyword args passed to plt.plot; 'width' (default: the
               minimum spacing between values) and 'align' ('center' by
               default, 'right' shifts left by a full width, anything
               else leaves segments starting at x) shape the steps
    """
    xs, ys = pmf.Render()

    width = options.pop('width', None)
    if width is None:
        try:
            width = np.diff(xs).min()
        except TypeError:
            warnings.warn("Pmf: Can't compute bar width automatically. "
                          "Check for non-numeric types in Pmf. "
                          "Or try providing width option.")

    points = []
    lastx = np.nan
    lasty = 0
    for x, y in zip(xs, ys):
        # gap since the previous segment: drop to 0 and come back up
        # (comparing with NaN is False on the first iteration)
        if (x - lastx) > 1e-5:
            points.append((lastx, 0))
            points.append((x, 0))

        points.append((x, lasty))
        points.append((x, y))
        points.append((x + width, y))

        lastx = x + width
        lasty = y
    points.append((lastx, 0))
    pxs, pys = zip(*points)

    align = options.pop('align', 'center')
    if align == 'center':
        pxs = np.array(pxs) - width / 2.0
    if align == 'right':
        pxs = np.array(pxs) - width

    options = _Underride(options, label=pmf.label)
    Plot(pxs, pys, **options)
458
459
460
def Pmfs(pmfs, **options):
    """Plots a sequence of PMFs.

    The same options are used for every PMF; to style them differently,
    call Pmf once per object instead.

    Args:
      pmfs: sequence of PMF objects
      options: keyword args passed to plt.plot
    """
    for each_pmf in pmfs:
        Pmf(each_pmf, **options)
472
473
474
def Diff(t):
    """Compute the differences between adjacent elements in a sequence.

    Args:
      t: indexable sequence of numbers

    Returns:
      list of t[i+1] - t[i] (length one less than t; empty for fewer
      than two elements)
    """
    return [t[k] - t[k - 1] for k in range(1, len(t))]
485
486
487
def Cdf(cdf, complement=False, transform=None, **options):
    """Plots a CDF as a line.

    Args:
      cdf: Cdf object
      complement: boolean, whether to plot the complementary CDF
      transform: string, one of 'exponential', 'pareto', 'weibull', 'gumbel'
      options: keyword args passed to plt.plot; 'xscale' and 'yscale'
               are removed and returned in the scale dictionary instead

    Returns:
      dictionary with the scale options that should be passed to
      Config, Show or Save.
    """
    xs, ps = cdf.Render()
    xs = np.asarray(xs)
    ps = np.asarray(ps)

    scale = dict(xscale='linear', yscale='linear')

    for s in ['xscale', 'yscale']:
        if s in options:
            scale[s] = options.pop(s)

    if transform == 'exponential':
        complement = True
        scale['yscale'] = 'log'

    if transform == 'pareto':
        complement = True
        scale['yscale'] = 'log'
        scale['xscale'] = 'log'

    if complement:
        ps = [1.0 - p for p in ps]

    if transform == 'weibull':
        # drop the last point (presumably p == 1 there, where log(1-p)
        # is undefined)
        xs = np.delete(xs, -1)
        ps = np.delete(ps, -1)
        ps = [-math.log(1.0 - p) for p in ps]
        scale['xscale'] = 'log'
        scale['yscale'] = 'log'

    if transform == 'gumbel':
        # drop the first point before taking -log(p)
        xs = np.delete(xs, 0)
        ps = np.delete(ps, 0)
        ps = [-math.log(p) for p in ps]
        scale['yscale'] = 'log'

    options = _Underride(options, label=cdf.label)
    Plot(xs, ps, **options)
    return scale
538
539
540
def Cdfs(cdfs, complement=False, transform=None, **options):
    """Plots a sequence of CDFs with the same settings.

    cdfs: sequence of CDF objects
    complement: boolean, whether to plot the complementary CDF
    transform: string, one of 'exponential', 'pareto', 'weibull', 'gumbel'
    options: keyword args passed to plt.plot
    """
    for each_cdf in cdfs:
        Cdf(each_cdf, complement, transform, **options)
550
551
552
def Contour(obj, pcolor=False, contour=True, imshow=False, **options):
    """Makes a contour plot.

    obj: map from (x, y) to z, or object that provides GetDict
    pcolor: boolean, whether to make a pseudocolor plot
    contour: boolean, whether to make a contour plot
    imshow: boolean, whether to use plt.imshow
    options: keyword args passed to plt.pcolormesh, plt.contour and/or
             plt.imshow; linewidth defaults to 3 (not passed to imshow)
             and cmap to matplotlib.cm.Blues
    """
    try:
        d = obj.GetDict()
    except AttributeError:
        d = obj

    _Underride(options, linewidth=3, cmap=matplotlib.cm.Blues)

    xs, ys = zip(*d.keys())
    xs = sorted(set(xs))
    ys = sorted(set(ys))

    # Z[i, j] is the value at (xs[j], ys[i]); missing pairs count as 0
    X, Y = np.meshgrid(xs, ys)
    func = lambda x, y: d.get((x, y), 0)
    func = np.vectorize(func)
    Z = func(X, Y)

    x_formatter = matplotlib.ticker.ScalarFormatter(useOffset=False)
    axes = plt.gca()
    axes.xaxis.set_major_formatter(x_formatter)

    if pcolor:
        plt.pcolormesh(X, Y, Z, **options)
    if contour:
        cs = plt.contour(X, Y, Z, **options)
        plt.clabel(cs, inline=1, fontsize=10)
    if imshow:
        # Images have no linewidth property, so imshow rejects it.
        # Row 0 of Z belongs to ys[0], the bottom of the extent, so draw
        # with origin='lower' unless the caller chose otherwise.
        image_options = {key: val for key, val in options.items()
                         if key != 'linewidth'}
        image_options.setdefault('origin', 'lower')
        extent = xs[0], xs[-1], ys[0], ys[-1]
        plt.imshow(Z, extent=extent, **image_options)
589
590
591
def Pcolor(xs, ys, zs, pcolor=True, contour=False, **options):
    """Makes a pseudocolor plot.

    xs: x coordinates (gridded with ys via np.meshgrid)
    ys: y coordinates
    zs: values to plot; presumably shaped (len(ys), len(xs)) to match
        the meshgrid -- TODO confirm
    pcolor: boolean, whether to make a pseudocolor plot
    contour: boolean, whether to make a contour plot
    options: keyword args passed to plt.pcolormesh and/or plt.contour;
             linewidth defaults to 3 and cmap to matplotlib.cm.Blues
    """
    _Underride(options, linewidth=3, cmap=matplotlib.cm.Blues)

    grid_x, grid_y = np.meshgrid(xs, ys)

    # plain x tick labels, without an offset
    plain_formatter = matplotlib.ticker.ScalarFormatter(useOffset=False)
    plt.gca().xaxis.set_major_formatter(plain_formatter)

    if pcolor:
        plt.pcolormesh(grid_x, grid_y, zs, **options)

    if contour:
        contour_set = plt.contour(grid_x, grid_y, zs, **options)
        plt.clabel(contour_set, inline=1, fontsize=10)
616
617
618
def Text(x, y, s, **options):
    """Puts text in a figure.

    x: number
    y: number
    s: string
    options: keyword args passed to plt.text; fontsize (16),
             verticalalignment ('top') and horizontalalignment ('left')
             are defaults
    """
    defaults = dict(fontsize=16,
                    verticalalignment='top',
                    horizontalalignment='left')
    plt.text(x, y, s, **_Underride(options, **defaults))
631
632
633
# Module-level legend state updated by Config (and LOC reset by Clf):
# whether to draw a legend, and the loc value passed to plt.legend.
LEGEND, LOC = True, None
635
636
def Config(**options):
    """Configures the plot.

    Pulls options out of the option dictionary and passes them to
    the corresponding plt functions.

    options:
      title, xlabel, ylabel, xscale, yscale, xticks, yticks, axis, xlim,
        ylim: each passed to the plt function of the same name
      legend: bool, whether to draw a legend (remembered in LEGEND)
      loc: legend location (remembered in LOC)
      frameon: passed to plt.legend (default True)
      xticklabels, yticklabels: 'invisible' hides those tick labels
    """
    global LEGEND, LOC

    names = ['title', 'xlabel', 'ylabel', 'xscale', 'yscale',
             'xticks', 'yticks', 'axis', 'xlim', 'ylim']

    for name in names:
        if name in options:
            getattr(plt, name)(options[name])

    LEGEND = options.get('legend', LEGEND)

    if LEGEND:
        LOC = options.get('loc', LOC)
        frameon = options.get('frameon', True)

        # plt.legend may emit a UserWarning (e.g. when no artist has a
        # label); turn it into an exception so the legend is skipped
        # quietly. catch_warnings restores the caller's filters afterwards
        # instead of resetting every UserWarning filter to 'default'.
        with warnings.catch_warnings():
            warnings.simplefilter('error', category=UserWarning)
            try:
                plt.legend(loc=LOC, frameon=frameon)
            except UserWarning:
                pass

    # x and y ticklabels can be made invisible
    val = options.get('xticklabels', None)
    if val is not None:
        if val == 'invisible':
            ax = plt.gca()
            labels = ax.get_xticklabels()
            plt.setp(labels, visible=False)

    val = options.get('yticklabels', None)
    if val is not None:
        if val == 'invisible':
            ax = plt.gca()
            labels = ax.get_yticklabels()
            plt.setp(labels, visible=False)
678
679
680
def Show(**options):
    """Shows the plot.

    For options, see Config.

    options: keyword args used to invoke various plt functions;
             clf (default True) clears the figure afterwards
    """
    clear_after = options.pop('clf', True)
    Config(**options)
    plt.show()
    if clear_after:
        Clf()
692
693
694
def Plotly(**options):
    """Shows the plot via Plotly.

    For options, see Config.

    options: keyword args used to invoke various plt functions;
             clf (default True) clears the figure afterwards

    Returns whatever plotly.plot_mpl returns (presumably the plot URL).
    """
    clear_after = options.pop('clf', True)
    Config(**options)
    # NOTE(review): plotly.plotly / plot_mpl may exist only in older
    # plotly releases -- confirm against the installed version.
    import plotly.plotly as plotly
    plot_url = plotly.plot_mpl(plt.gcf())
    if clear_after:
        Clf()
    return plot_url
708
709
710
def Save(root=None, formats=None, **options):
    """Saves the plot in the given formats and clears the figure.

    For options, see Config.

    Args:
      root: string filename root; no files are written when it is falsy
      formats: list of string formats, or a single format string
               (default ['pdf', 'eps']). The special format 'plotly'
               sends the figure to Plotly instead of writing a file.
               The caller's list is not modified.
      options: keyword args used to invoke various plt functions;
               bbox_inches and pad_inches go to plt.savefig, clf
               (default True) clears the figure afterwards
    """
    clf = options.pop('clf', True)

    save_options = {}
    for option in ['bbox_inches', 'pad_inches']:
        if option in options:
            save_options[option] = options.pop(option)

    Config(**options)

    if formats is None:
        formats = ['pdf', 'eps']
    elif isinstance(formats, str):
        formats = [formats]

    # build a new list rather than calling remove() on the caller's
    # list (which mutated it and failed outright for tuples)
    file_formats = [fmt for fmt in formats if fmt != 'plotly']
    if len(file_formats) != len(formats):
        Plotly(clf=False)

    if root:
        for fmt in file_formats:
            SaveFormat(root, fmt, **save_options)
    if clf:
        Clf()
743
744
745
def SaveFormat(root, fmt='eps', **options):
    """Writes the current figure to a file in the given format.

    Args:
      root: string filename root
      fmt: string format, also used as the file extension
      options: keyword args passed to plt.savefig; dpi defaults to 300
    """
    options.setdefault('dpi', 300)
    target = '%s.%s' % (root, fmt)
    print('Writing', target)
    plt.savefig(target, format=fmt, **options)
756
757
758
# provide aliases for calling functions with lower-case names
preplot = PrePlot
subplot = SubPlot
clf = Clf
figure = Figure
plot = Plot
vlines = Vlines
hlines = Hlines
fill_between = FillBetween
text = Text
scatter = Scatter
pmf = Pmf
pmfs = Pmfs
hist = Hist
hists = Hists
diff = Diff
cdf = Cdf
cdfs = Cdfs
contour = Contour
pcolor = Pcolor
config = Config
show = Show
save = Save
# aliases for the remaining public plotting functions
bar = Bar
hexbin = HexBin
pdf = Pdf
pdfs = Pdfs
781
782
783
def main():
    """Prints the colors _Brewer hands out for a seven-line plot."""
    for shade in _Brewer.ColorGenerator(7):
        print(shade)


if __name__ == '__main__':
    main()
791
792