Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Download

📚 The CoCalc Library - books, templates and other resources

132923 views
License: OTHER
1
"""This file contains code for use with "Think Stats",
2
by Allen B. Downey, available from greenteapress.com
3
4
Copyright 2014 Allen B. Downey
5
License: GNU GPLv3 http://www.gnu.org/licenses/gpl.html
6
"""
7
8
from __future__ import print_function
9
10
import math
11
import matplotlib
12
import matplotlib.pyplot as pyplot
13
import numpy as np
14
import pandas
15
16
import warnings
17
18
# customize some matplotlib attributes
19
#matplotlib.rc('figure', figsize=(4, 3))
20
21
#matplotlib.rc('font', size=14.0)
22
#matplotlib.rc('axes', labelsize=22.0, titlesize=22.0)
23
#matplotlib.rc('legend', fontsize=20.0)
24
25
#matplotlib.rc('xtick.major', size=6.0)
26
#matplotlib.rc('xtick.minor', size=3.0)
27
28
#matplotlib.rc('ytick.major', size=6.0)
29
#matplotlib.rc('ytick.minor', size=3.0)
30
31
32
class _Brewer(object):
33
"""Encapsulates a nice sequence of colors.
34
35
Shades of blue that look good in color and can be distinguished
36
in grayscale (up to a point).
37
38
Borrowed from http://colorbrewer2.org/
39
"""
40
color_iter = None
41
42
colors = ['#081D58',
43
'#253494',
44
'#225EA8',
45
'#1D91C0',
46
'#41B6C4',
47
'#7FCDBB',
48
'#C7E9B4',
49
'#EDF8B1',
50
'#FFFFD9']
51
52
# lists that indicate which colors to use depending on how many are used
53
which_colors = [[],
54
[1],
55
[1, 3],
56
[0, 2, 4],
57
[0, 2, 4, 6],
58
[0, 2, 3, 5, 6],
59
[0, 2, 3, 4, 5, 6],
60
[0, 1, 2, 3, 4, 5, 6],
61
]
62
63
@classmethod
64
def Colors(cls):
65
"""Returns the list of colors.
66
"""
67
return cls.colors
68
69
@classmethod
70
def ColorGenerator(cls, n):
71
"""Returns an iterator of color strings.
72
73
n: how many colors will be used
74
"""
75
for i in cls.which_colors[n]:
76
yield cls.colors[i]
77
raise StopIteration('Ran out of colors in _Brewer.ColorGenerator')
78
79
@classmethod
80
def InitializeIter(cls, num):
81
"""Initializes the color iterator with the given number of colors."""
82
cls.color_iter = cls.ColorGenerator(num)
83
84
@classmethod
85
def ClearIter(cls):
86
"""Sets the color iterator to None."""
87
cls.color_iter = None
88
89
@classmethod
90
def GetIter(cls):
91
"""Gets the color iterator."""
92
if cls.color_iter is None:
93
cls.InitializeIter(7)
94
95
return cls.color_iter
96
97
98
def PrePlot(num=None, rows=None, cols=None):
    """Takes hints about what's coming.

    num: number of lines that will be plotted; when truthy, the shared
         color iterator is initialized for that many colors
    rows: number of rows of subplots
    cols: number of columns of subplots

    When only one of rows/cols is given, the other defaults to 1. Known
    layouts resize the current figure; multi-panel layouts also select
    the first subplot and record the grid in SUBPLOT_ROWS/SUBPLOT_COLS
    for SubPlot.
    """
    global SUBPLOT_ROWS, SUBPLOT_COLS

    if num:
        _Brewer.InitializeIter(num)

    # no layout hints at all: nothing more to do
    if rows is None and cols is None:
        return

    # a missing dimension defaults to 1 when the other one is given
    rows = 1 if rows is None else rows
    cols = 1 if cols is None else cols

    # figure size in inches for the layouts we know about
    figure_sizes = {
        (1, 1): (8, 6),
        (1, 2): (14, 6),
        (1, 3): (14, 6),
        (2, 2): (10, 10),
        (2, 3): (16, 10),
        (3, 1): (8, 10),
    }
    layout = (rows, cols)
    if layout in figure_sizes:
        width, height = figure_sizes[layout]
        pyplot.gcf().set_size_inches(width, height)

    # single-panel layouts need no subplot selection
    if rows <= 1 and cols <= 1:
        return

    pyplot.subplot(rows, cols, 1)
    SUBPLOT_ROWS, SUBPLOT_COLS = rows, cols
136
137
138
def SubPlot(plot_number, rows=None, cols=None):
    """Configures the number of subplots and changes the current plot.

    Falsy rows/cols fall back to the grid recorded by PrePlot in the
    module globals SUBPLOT_ROWS / SUBPLOT_COLS.

    plot_number: int, index of the subplot to select
    rows: int
    cols: int
    """
    if not rows:
        rows = SUBPLOT_ROWS
    if not cols:
        cols = SUBPLOT_COLS
    pyplot.subplot(rows, cols, plot_number)
148
149
150
def _Underride(d, **options):
151
"""Add key-value pairs to d only if key is not in d.
152
153
If d is None, create a new dictionary.
154
155
d: dictionary
156
options: keyword args to add to d
157
"""
158
if d is None:
159
d = {}
160
161
for key, val in options.items():
162
d.setdefault(key, val)
163
164
return d
165
166
167
def Clf():
    """Clears the figure and any hints that have been set.

    Resets the remembered legend location (LOC), drops the shared color
    iterator, clears the current figure and restores its size to 8x6
    inches.
    """
    global LOC
    LOC = None
    _Brewer.ClearIter()
    pyplot.clf()
    pyplot.gcf().set_size_inches(8, 6)
175
176
177
def Figure(**options):
    """Creates (or activates) a figure via pyplot.figure.

    options: keyword args for pyplot.figure; figsize defaults to (6, 8)
    """
    if 'figsize' not in options:
        options['figsize'] = (6, 8)
    pyplot.figure(**options)
181
182
183
def _UnderrideColor(options):
184
if 'color' in options:
185
return options
186
187
color_iter = _Brewer.GetIter()
188
189
if color_iter:
190
try:
191
options['color'] = next(color_iter)
192
except StopIteration:
193
# TODO: reconsider whether this should warn
194
# warnings.warn('Warning: Brewer ran out of colors.')
195
_Brewer.ClearIter()
196
return options
197
198
199
def Plot(obj, ys=None, style='', **options):
    """Plots a line.

    Args:
      obj: sequence of x values, or Series, or anything with Render()
      ys: sequence of y values; when omitted, they come from obj.Render()
        or from a pandas Series (index as x, values as y)
      style: style string passed along to pyplot.plot
      options: keyword args passed to pyplot.plot; color comes from the
        shared palette, and linewidth=3, alpha=0.8 and label (obj.label,
        else '_nolegend_') are used unless overridden
    """
    options = _UnderrideColor(options)
    options = _Underride(options, linewidth=3, alpha=0.8,
                         label=getattr(obj, 'label', '_nolegend_'))

    xs = obj
    if ys is None:
        if hasattr(obj, 'Render'):
            xs, ys = obj.Render()
        if isinstance(obj, pandas.Series):
            xs, ys = obj.index, obj.values

    args = (xs, style) if ys is None else (xs, ys, style)
    pyplot.plot(*args, **options)
224
225
226
def FillBetween(xs, y1, y2=None, where=None, **options):
    """Fills the region between two curves.

    Args:
      xs: sequence of x values
      y1: sequence of y values for one edge of the region
      y2: sequence of y values (or a scalar) for the other edge; None
        means fill down to y=0
      where: sequence of boolean selecting which parts to fill
      options: keyword args passed to pyplot.fill_between; color comes
        from the shared palette, linewidth defaults to 0 and alpha to 0.5
    """
    options = _UnderrideColor(options)
    options = _Underride(options, linewidth=0, alpha=0.5)
    # pyplot.fill_between's own default baseline is y2=0; forwarding None
    # is not a valid baseline, so translate it here.
    if y2 is None:
        y2 = 0
    pyplot.fill_between(xs, y1, y2, where=where, **options)
239
240
241
def Bar(xs, ys, **options):
    """Draws a bar chart.

    Args:
      xs: sequence of x values (bar positions)
      ys: sequence of y values (bar heights)
      options: keyword args passed to pyplot.bar; color comes from the
        shared palette, linewidth defaults to 0 and alpha to 0.6
    """
    bar_options = _Underride(_UnderrideColor(options),
                             linewidth=0, alpha=0.6)
    pyplot.bar(xs, ys, **bar_options)
252
253
254
def Scatter(xs, ys=None, **options):
    """Makes a scatter plot.

    xs: x values, or a pandas Series whose index gives x and whose
        values give y (when ys is omitted)
    ys: y values
    options: options passed to pyplot.scatter; defaults are
        color='blue', alpha=0.2, s=30, edgecolors='none'
    """
    defaults = dict(color='blue', alpha=0.2, s=30, edgecolors='none')
    options = _Underride(options, **defaults)

    series_only = ys is None and isinstance(xs, pandas.Series)
    if series_only:
        xs, ys = xs.index, xs.values

    pyplot.scatter(xs, ys, **options)
269
270
271
def HexBin(xs, ys, **options):
    """Makes a hexagonal-binning plot of the points (xs, ys).

    xs: x values
    ys: y values
    options: options passed to pyplot.hexbin; cmap defaults to
        matplotlib.cm.Blues
    """
    options.setdefault('cmap', matplotlib.cm.Blues)
    pyplot.hexbin(xs, ys, **options)
280
281
282
def Pdf(pdf, **options):
    """Plots a Pdf, Pmf, or Hist as a line.

    The keyword args low, high (default None) and n (default 101) are
    consumed here and forwarded to pdf.Render; the rest go to Plot.

    Args:
      pdf: Pdf, Pmf, or Hist object
      options: keyword args passed to pyplot.plot; label defaults to
        pdf.label
    """
    render_args = dict(low=options.pop('low', None),
                       high=options.pop('high', None),
                       n=options.pop('n', 101))
    xs, ps = pdf.Render(**render_args)
    options.setdefault('label', pdf.label)
    Plot(xs, ps, **options)
294
295
296
def Pdfs(pdfs, **options):
    """Plots each PDF in a sequence, using the same options for all.

    To give each PDF different options, call Pdf once per object.

    Args:
      pdfs: sequence of PDF objects
      options: keyword args passed to Pdf (and from there to pyplot.plot)
    """
    for each in pdfs:
        Pdf(each, **options)
308
309
310
def Hist(hist, **options):
    """Plots a Pmf or Hist with a bar plot.

    The default width of the bars is 0.9 times the minimum difference
    between adjacent values in the Hist. If that's too small, or it
    cannot be computed (non-numeric values, fewer than two values), you
    can override it by providing a width keyword argument, in the same
    units as the values.

    Args:
      hist: Hist or Pmf object
      options: keyword args passed to pyplot.bar; align may also be
        'left' or 'right', which are translated to matplotlib's 'edge'
        alignment
    """
    xs, ys = hist.Render()

    if 'width' not in options:
        try:
            # bars slightly narrower than the smallest gap between values
            options['width'] = 0.9 * np.diff(xs).min()
        except (TypeError, ValueError):
            # TypeError: non-numeric values; ValueError: fewer than two
            # values, so np.diff is empty and has no minimum.
            warnings.warn("Hist: Can't compute bar width automatically. "
                          "Check for non-numeric types in Hist. "
                          "Or try providing width option.")

    options = _Underride(options, label=hist.label)
    options = _Underride(options, align='center')
    if options['align'] == 'left':
        options['align'] = 'edge'
    elif options['align'] == 'right':
        # with align='edge', a negative width puts the bar's right edge at x;
        # 0.8 is pyplot.bar's default width, used if none was computed above
        options['align'] = 'edge'
        options['width'] = -options.get('width', 0.8)

    Bar(xs, ys, **options)
343
344
345
def Hists(hists, **options):
    """Plots each Hist or Pmf in a sequence as a bar plot.

    The same options are used for every histogram; to style each one
    differently, call Hist once per object.

    Args:
      hists: sequence of Hist or Pmf objects
      options: keyword args passed to Hist (and from there to pyplot.bar)
    """
    for h in hists:
        Hist(h, **options)
357
358
359
def Pmf(pmf, **options):
    """Plots a Pmf or Hist as a step line.

    Each value x with probability y becomes a flat segment of height y
    from x to x + width (before alignment). Where consecutive segments
    do not touch, the line drops to zero across the gap.

    Args:
      pmf: Hist or Pmf object
      options: keyword args passed to pyplot.plot, plus:
        width: width of each step; defaults to the minimum gap between
          adjacent values, or 1.0 when that cannot be computed
        align: 'center' (default), 'left' or 'right' -- where each value
          sits relative to its step
    """
    xs, ys = pmf.Render()

    width = options.pop('width', None)
    if width is None:
        try:
            width = np.diff(xs).min()
        except (TypeError, ValueError):
            # TypeError: non-numeric values; ValueError: fewer than two
            # values, so there is no gap to measure.
            warnings.warn("Pmf: Can't compute bar width automatically. "
                          "Check for non-numeric types in Pmf. "
                          "Or try providing width option.")
            # fallback so a single-value Pmf can still be drawn
            width = 1.0

    points = []
    lastx = np.nan  # NaN makes the first gap test below False
    lasty = 0
    for x, y in zip(xs, ys):
        # gap since the previous step: drop to zero across it
        if (x - lastx) > 1e-5:
            points.append((lastx, 0))
            points.append((x, 0))

        points.append((x, lasty))
        points.append((x, y))
        points.append((x + width, y))

        lastx = x + width
        lasty = y
    points.append((lastx, 0))
    pxs, pys = zip(*points)

    align = options.pop('align', 'center')
    if align == 'center':
        pxs = np.array(pxs) - width / 2.0
    if align == 'right':
        pxs = np.array(pxs) - width

    options = _Underride(options, label=pmf.label)
    Plot(pxs, pys, **options)
403
404
405
def Pmfs(pmfs, **options):
    """Plots each PMF in a sequence, using the same options for all.

    To give each PMF different options, call Pmf once per object.

    Args:
      pmfs: sequence of PMF objects
      options: keyword args passed to Pmf (and from there to pyplot.plot)
    """
    for one_pmf in pmfs:
        Pmf(one_pmf, **options)
417
418
419
def Diff(t):
    """Computes the differences between adjacent elements in a sequence.

    Args:
      t: indexable sequence of numbers

    Returns:
      list whose k-th element is t[k+1] - t[k]; it is one element shorter
      than t, and empty when t has fewer than two elements
    """
    return [t[k] - t[k - 1] for k in range(1, len(t))]
430
431
432
def Cdf(cdf, complement=False, transform=None, **options):
    """Plots a CDF as a line.

    Args:
      cdf: Cdf object
      complement: boolean, whether to plot the complementary CDF
      transform: string, one of 'exponential', 'pareto', 'weibull', 'gumbel'
      options: keyword args passed to pyplot.plot; xscale/yscale are
        removed from options and reported in the returned dictionary

    Returns:
      dictionary with the scale options that should be passed to
      Config, Show or Save.
    """
    xs, ps = cdf.Render()
    xs = np.asarray(xs)
    ps = np.asarray(ps)

    scale = dict(xscale='linear', yscale='linear')

    for s in ['xscale', 'yscale']:
        if s in options:
            scale[s] = options.pop(s)

    if transform == 'exponential':
        # complementary CDF on a log y-axis
        complement = True
        scale['yscale'] = 'log'

    if transform == 'pareto':
        # complementary CDF on log-log axes
        complement = True
        scale['yscale'] = 'log'
        scale['xscale'] = 'log'

    if complement:
        ps = [1.0 - p for p in ps]

    if transform == 'weibull':
        # drop the last point; presumably it has p == 1, where
        # log(1 - p) is undefined
        xs = np.delete(xs, -1)
        ps = np.delete(ps, -1)
        ps = [-math.log(1.0 - p) for p in ps]
        scale['xscale'] = 'log'
        scale['yscale'] = 'log'

    if transform == 'gumbel':
        # drop the first point; presumably it may have p == 0, where
        # log(p) is undefined
        xs = np.delete(xs, 0)
        ps = np.delete(ps, 0)
        ps = [-math.log(p) for p in ps]
        scale['yscale'] = 'log'

    options = _Underride(options, label=cdf.label)
    Plot(xs, ps, **options)
    return scale
483
484
485
def Cdfs(cdfs, complement=False, transform=None, **options):
    """Plots each CDF in a sequence with the same settings.

    cdfs: sequence of CDF objects
    complement: boolean, whether to plot the complementary CDF
    transform: string, one of 'exponential', 'pareto', 'weibull', 'gumbel'
    options: keyword args passed to pyplot.plot

    The scale dictionaries returned by Cdf are discarded.
    """
    for one in cdfs:
        Cdf(one, complement=complement, transform=transform, **options)
495
496
497
def Contour(obj, pcolor=False, contour=True, imshow=False, **options):
    """Makes a contour plot.

    obj: map from (x, y) to z, or object that provides GetDict
    pcolor: boolean, whether to make a pseudocolor plot
    contour: boolean, whether to make a contour plot
    imshow: boolean, whether to use pyplot.imshow
    options: keyword args passed to pyplot.pcolormesh, pyplot.contour
        and/or pyplot.imshow; linewidth=3 and cmap=matplotlib.cm.Blues
        unless given

    Grid points that are missing from the mapping are plotted as 0.
    """
    try:
        d = obj.GetDict()
    except AttributeError:
        d = obj

    options.setdefault('linewidth', 3)
    options.setdefault('cmap', matplotlib.cm.Blues)
    # NOTE(review): 'linewidth' is forwarded unchanged to every enabled
    # pyplot call; confirm each accepts it in the installed matplotlib.

    xs, ys = zip(*d.keys())
    xs = sorted(set(xs))
    ys = sorted(set(ys))

    X, Y = np.meshgrid(xs, ys)
    lookup = np.vectorize(lambda x, y: d.get((x, y), 0))
    Z = lookup(X, Y)

    # plain tick labels on the x axis, without an offset
    formatter = matplotlib.ticker.ScalarFormatter(useOffset=False)
    pyplot.gca().xaxis.set_major_formatter(formatter)

    if pcolor:
        pyplot.pcolormesh(X, Y, Z, **options)
    if contour:
        contour_set = pyplot.contour(X, Y, Z, **options)
        pyplot.clabel(contour_set, inline=1, fontsize=10)
    if imshow:
        pyplot.imshow(Z, extent=(xs[0], xs[-1], ys[0], ys[-1]), **options)
534
535
536
def Pcolor(xs, ys, zs, pcolor=True, contour=False, **options):
    """Makes a pseudocolor plot.

    xs: sequence of x coordinates
    ys: sequence of y coordinates
    zs: values on the grid; presumably shaped like the output of
        np.meshgrid(xs, ys), i.e. (len(ys), len(xs)) -- verify at callers
    pcolor: boolean, whether to make a pseudocolor plot
    contour: boolean, whether to make a contour plot
    options: keyword args passed to pyplot.pcolormesh and/or
        pyplot.contour; linewidth=3 and cmap=matplotlib.cm.Blues unless given
    """
    options.setdefault('linewidth', 3)
    options.setdefault('cmap', matplotlib.cm.Blues)

    X, Y = np.meshgrid(xs, ys)

    # plain tick labels on the x axis, without an offset
    formatter = matplotlib.ticker.ScalarFormatter(useOffset=False)
    pyplot.gca().xaxis.set_major_formatter(formatter)

    if pcolor:
        pyplot.pcolormesh(X, Y, zs, **options)

    if contour:
        contour_set = pyplot.contour(X, Y, zs, **options)
        pyplot.clabel(contour_set, inline=1, fontsize=10)
561
562
563
def Text(x, y, s, **options):
    """Puts text in a figure.

    x: number
    y: number
    s: string
    options: keyword args passed to pyplot.text; defaults are
        fontsize=16, verticalalignment='top', horizontalalignment='left'
    """
    for key, value in (('fontsize', 16),
                       ('verticalalignment', 'top'),
                       ('horizontalalignment', 'left')):
        options.setdefault(key, value)
    pyplot.text(x, y, s, **options)
576
577
578
# Module-level legend state: LEGEND says whether Config draws a legend,
# LOC is the last legend location it was given (None until one is set).
LEGEND, LOC = True, None
580
581
def Config(**options):
    """Configures the plot.

    Pulls options out of the option dictionary and passes them to
    the corresponding pyplot functions: title, xlabel, ylabel, xscale,
    yscale, xticks, yticks, axis, xlim and ylim are each called with the
    given value. Other options are ignored, except:

    legend: boolean, whether to draw a legend; stored in the module
        global LEGEND so it also applies to later calls
    loc: legend location passed to pyplot.legend; stored in the module
        global LOC (only updated when a legend is drawn)
    """
    global LEGEND, LOC

    names = ['title', 'xlabel', 'ylabel', 'xscale', 'yscale',
             'xticks', 'yticks', 'axis', 'xlim', 'ylim']

    for name in names:
        if name in options:
            getattr(pyplot, name)(options[name])

    LEGEND = options.get('legend', LEGEND)

    if LEGEND:
        # loc specs such as 'upper right' are passed to matplotlib
        # unchanged; no translation to numeric codes is needed
        LOC = options.get('loc', LOC)
        pyplot.legend(loc=LOC)
614
615
616
def Show(**options):
    """Configures and shows the plot, then (by default) clears it.

    For options, see Config.

    options: keyword args used to invoke various pyplot functions;
        clf=False keeps the figure and hints after showing
    """
    should_clear = options.pop('clf', True)
    Config(**options)
    pyplot.show()
    if should_clear:
        Clf()
628
629
630
def Plotly(**options):
    """Configures the plot and hands the current figure to plotly.

    For options, see Config.

    options: keyword args used to invoke various pyplot functions;
        clf=False keeps the figure after the call

    Returns:
        whatever plotly's plot_mpl returns (presumably the plot URL)
    """
    should_clear = options.pop('clf', True)
    Config(**options)
    # local import: plotly is only needed when this function is called.
    # NOTE(review): plotly.plotly is absent from newer plotly releases;
    # confirm against the installed version.
    import plotly.plotly as plotly
    url = plotly.plot_mpl(pyplot.gcf())
    if should_clear:
        Clf()
    return url
644
645
646
def Save(root=None, formats=None, **options):
    """Configures the plot, saves it in the given formats, and (by
    default) clears the figure.

    For options, see Config.

    Args:
      root: string filename root; nothing is written to disk when falsy
      formats: iterable of string formats (default ['pdf', 'eps']); the
        special format 'plotly' sends the figure to Plotly instead of
        writing a file. The caller's sequence is never modified.
      options: keyword args used to invoke various pyplot functions;
        clf=False keeps the figure after saving
    """
    clf = options.pop('clf', True)
    Config(**options)

    # Work on a private copy so the caller's list is not mutated (and
    # tuples are accepted).
    formats = ['pdf', 'eps'] if formats is None else list(formats)

    if 'plotly' in formats:
        formats.remove('plotly')
        Plotly(clf=False)

    if root:
        for fmt in formats:
            SaveFormat(root, fmt)
    if clf:
        Clf()
673
674
675
def SaveFormat(root, fmt='eps'):
    """Writes the current figure to '<root>.<fmt>' at 300 dpi.

    Args:
      root: string filename root
      fmt: string format, also used as the file extension
    """
    path = '%s.%s' % (root, fmt)
    print('Writing', path)
    pyplot.savefig(path, dpi=300, format=fmt)
685
686
687
# provide aliases for calling functions with lower-case names
preplot = PrePlot
subplot = SubPlot
clf = Clf
figure = Figure
plot = Plot
text = Text
scatter = Scatter
pdf = Pdf
pdfs = Pdfs
pmf = Pmf
pmfs = Pmfs
hist = Hist
hists = Hists
diff = Diff
cdf = Cdf
cdfs = Cdfs
contour = Contour
pcolor = Pcolor
config = Config
show = Show
save = Save
707
708
709
def main():
    """Prints the seven colors of the 7-line palette, one per line."""
    for color in _Brewer.ColorGenerator(7):
        print(color)


if __name__ == '__main__':
    main()
717
718