📚 The CoCalc Library - books, templates and other resources
License: OTHER
"""This file contains code used in "Think Bayes",1by Allen B. Downey, available from greenteapress.com23Copyright 2012 Allen B. Downey4License: GNU GPLv3 http://www.gnu.org/licenses/gpl.html5"""67import matplotlib.pyplot as pyplot8import thinkplot9import numpy1011import csv12import random13import shelve14import sys15import time1617import thinkbayes1819import warnings2021warnings.simplefilter('error', RuntimeWarning)222324FORMATS = ['pdf', 'eps', 'png']252627class Locker(object):28"""Encapsulates a shelf for storing key-value pairs."""2930def __init__(self, shelf_file):31self.shelf = shelve.open(shelf_file)3233def Close(self):34"""Closes the shelf.35"""36self.shelf.close()3738def Add(self, key, value):39"""Adds a key-value pair."""40self.shelf[str(key)] = value4142def Lookup(self, key):43"""Looks up a key."""44return self.shelf.get(str(key))4546def Keys(self):47"""Returns an iterator of keys."""48return self.shelf.iterkeys()4950def Read(self):51"""Returns the contents of the shelf as a map."""52return dict(self.shelf)535455class Subject(object):56"""Represents a subject from the belly button study."""5758def __init__(self, code):59"""60code: string ID61species: sequence of (int count, string species) pairs62"""63self.code = code64self.species = []65self.suite = None66self.num_reads = None67self.num_species = None68self.total_reads = None69self.total_species = None70self.prev_unseen = None71self.pmf_n = None72self.pmf_q = None73self.pmf_l = None7475def Add(self, species, count):76"""Add a species-count pair.7778It is up to the caller to ensure that species names are unique.7980species: string species/genus name81count: int number of individuals82"""83self.species.append((count, species))8485def Done(self, reverse=False, clean_param=0):86"""Called when we are done adding species counts.8788reverse: which order to sort in89"""90if clean_param:91self.Clean(clean_param)9293self.species.sort(reverse=reverse)94counts = self.GetCounts()95self.num_species = len(counts)96self.num_reads = sum(counts)9798def Clean(self, clean_param=50):99"""Identifies and removes bogus data.100101clean_param: parameter that controls the number of legit species102"""103def prob_bogus(k, r):104"""Compute the probability that a species is bogus."""105q = clean_param / r106p = (1-q) ** k107return p108109print self.code, clean_param110111counts = self.GetCounts()112r = 1.0 * sum(counts)113114species_seq = []115for k, species in sorted(self.species):116117if random.random() < prob_bogus(k, r):118continue119species_seq.append((k, species))120self.species = species_seq121122def GetM(self):123"""Gets number of observed species."""124return len(self.species)125126def GetCounts(self):127"""Gets the list of species counts128129Should be in increasing order, if Sort() has been invoked.130"""131return [count for count, _ in self.species]132133def MakeCdf(self):134"""Makes a CDF of total prevalence vs rank."""135counts = self.GetCounts()136counts.sort(reverse=True)137cdf = thinkbayes.MakeCdfFromItems(enumerate(counts))138return cdf139140def GetNames(self):141"""Gets the names of the seen species."""142return [name for _, name in self.species]143144def PrintCounts(self):145"""Prints the counts and species names."""146for count, name in reversed(self.species):147print count, name148149def GetSpecies(self, index):150"""Gets the count and name of the indicated species.151152Returns: count-species pair153"""154return self.species[index]155156def GetCdf(self):157"""Returns cumulative prevalence vs number of species.158"""159counts = 

    def GetPrevalences(self):
        """Returns a sequence of prevalences (normalized counts).
        """
        counts = self.GetCounts()
        total = sum(counts)
        prevalences = numpy.array(counts, dtype=numpy.float) / total
        return prevalences

    def Process(self, low=None, high=500, conc=1, iters=100):
        """Computes the posterior distribution of n and the prevalences.

        Sets attribute: self.suite

        low: minimum number of species
        high: maximum number of species
        conc: concentration parameter
        iters: number of iterations to use in the estimator
        """
        counts = self.GetCounts()
        m = len(counts)
        if low is None:
            low = max(m, 2)
        ns = range(low, high+1)

        #start = time.time()
        self.suite = Species5(ns, conc=conc, iters=iters)
        self.suite.Update(counts)
        #end = time.time()

        #print 'Processing time', end-start

    def MakePrediction(self, num_sims=100):
        """Make predictions for the given subject.

        Precondition: Process has run

        num_sims: how many simulations to run for predictions

        Adds attribute:
        pmf_l: predictive distribution of additional species
        """
        add_reads = self.total_reads - self.num_reads
        curves = self.RunSimulations(num_sims, add_reads)
        self.pmf_l = self.MakePredictive(curves)

    def MakeQuickPrediction(self, num_sims=100):
        """Make predictions for the given subject.

        Precondition: Process has run

        num_sims: how many simulations to run for predictions

        Adds attribute:
        pmf_l: predictive distribution of additional species
        """
        add_reads = self.total_reads - self.num_reads
        pmf = thinkbayes.Pmf()
        _, seen = self.GetSeenSpecies()

        for _ in range(num_sims):
            _, observations = self.GenerateObservations(add_reads)
            all_seen = seen.union(observations)
            l = len(all_seen) - len(seen)
            pmf.Incr(l)

        pmf.Normalize()
        self.pmf_l = pmf

    def DistL(self):
        """Returns the distribution of additional species, l.
        """
        return self.pmf_l

    def MakeFigures(self):
        """Makes figures showing distribution of n and the prevalences."""
        self.PlotDistN()
        self.PlotPrevalences()

    def PlotDistN(self):
        """Plots distribution of n."""
        pmf = self.suite.DistN()
        print '90% CI for N:', pmf.CredibleInterval(90)
        pmf.name = self.code

        thinkplot.Clf()
        thinkplot.PrePlot(num=1)

        thinkplot.Pmf(pmf)

        root = 'species-ndist-%s' % self.code
        thinkplot.Save(root=root,
                       xlabel='Number of species',
                       ylabel='Prob',
                       formats=FORMATS,
                       )

    def PlotPrevalences(self, num=5):
        """Plots dist of prevalence for several species.

        num: how many species (starting with the highest prevalence)
        """
        thinkplot.Clf()
        thinkplot.PrePlot(num=num)

        for rank in range(1, num+1):
            self.PlotPrevalence(rank)

        root = 'species-prev-%s' % self.code
        thinkplot.Save(root=root,
                       xlabel='Prevalence',
                       ylabel='Prob',
                       formats=FORMATS,
                       axis=[0, 0.3, 0, 1],
                       )

    def PlotPrevalence(self, rank=1, cdf_flag=True):
        """Plots dist of prevalence for one species.

        rank: rank order of the species to plot.
        cdf_flag: whether to plot the CDF
        """
        # convert rank to index
        index = self.GetM() - rank

        _, mix = self.suite.DistOfPrevalence(index)
        count, _ = self.GetSpecies(index)
        mix.name = '%d (%d)' % (rank, count)

        print '90%% CI for prevalence of species %d:' % rank,
        print mix.CredibleInterval(90)

        if cdf_flag:
            cdf = mix.MakeCdf()
            thinkplot.Cdf(cdf)
        else:
            thinkplot.Pmf(mix)

    def PlotMixture(self, rank=1):
        """Plots dist of prevalence for all n, and the mix.

        rank: rank order of the species to plot
        """
        # convert rank to index
        index = self.GetM() - rank

        print self.GetSpecies(index)
        print self.GetCounts()[index]

        metapmf, mix = self.suite.DistOfPrevalence(index)

        thinkplot.Clf()
        for pmf in metapmf.Values():
            thinkplot.Pmf(pmf, color='blue', alpha=0.2, linewidth=0.5)

        thinkplot.Pmf(mix, color='blue', alpha=0.9, linewidth=2)

        root = 'species-mix-%s' % self.code
        thinkplot.Save(root=root,
                       xlabel='Prevalence',
                       ylabel='Prob',
                       formats=FORMATS,
                       axis=[0, 0.3, 0, 0.3],
                       legend=False)

    def GetSeenSpecies(self):
        """Makes a set of the names of seen species.

        Returns: number of species, set of string species names
        """
        names = self.GetNames()
        m = len(names)
        seen = set(SpeciesGenerator(names, m))
        return m, seen

    def GenerateObservations(self, num_reads):
        """Generates a series of random observations.

        num_reads: number of reads to generate

        Returns: number of species, sequence of string species names
        """
        n, prevalences = self.suite.SamplePosterior()

        names = self.GetNames()
        name_iter = SpeciesGenerator(names, n)

        items = zip(name_iter, prevalences)

        cdf = thinkbayes.MakeCdfFromItems(items)
        observations = cdf.Sample(num_reads)

        #for ob in observations:
        #    print ob

        return n, observations

    def Resample(self, num_reads):
        """Choose a random subset of the data (without replacement).

        num_reads: number of reads in the subset
        """
        t = []
        for count, species in self.species:
            t.extend([species]*count)

        random.shuffle(t)
        reads = t[:num_reads]

        subject = Subject(self.code)
        hist = thinkbayes.MakeHistFromList(reads)
        for species, count in hist.Items():
            subject.Add(species, count)

        subject.Done()
        return subject

    def Match(self, match):
        """Match up a rarefied subject with a complete subject.

        match: complete Subject

        Assigns attributes:
        total_reads:
        total_species:
        prev_unseen:
        """
        self.total_reads = match.num_reads
        self.total_species = match.num_species

        # compute the prevalence of unseen species (at least approximately,
        # based on all species counts in match)
        _, seen = self.GetSeenSpecies()

        seen_total = 0.0
        unseen_total = 0.0
        for count, species in match.species:
            if species in seen:
                seen_total += count
            else:
                unseen_total += count

        self.prev_unseen = unseen_total / (seen_total + unseen_total)

    def RunSimulation(self, num_reads, frac_flag=False, jitter=0.01):
        """Simulates additional observations and returns a rarefaction curve.

        k is the number of additional observations
        num_new is the number of new species seen

        num_reads: how many new reads to simulate
        frac_flag: whether to convert to fraction of species seen
        jitter: size of jitter added if frac_flag is true

        Returns: list of (k, num_new) pairs
        """
        m, seen = self.GetSeenSpecies()
        n, observations = self.GenerateObservations(num_reads)

        curve = []
        for i, obs in enumerate(observations):
            seen.add(obs)

            if frac_flag:
                frac_seen = len(seen) / float(n)
                frac_seen += random.uniform(-jitter, jitter)
                curve.append((i+1, frac_seen))
            else:
                num_new = len(seen) - m
                curve.append((i+1, num_new))

        return curve

    def RunSimulations(self, num_sims, num_reads, frac_flag=False):
        """Runs simulations and returns a list of curves.

        Each curve is a sequence of (k, num_new) pairs.

        num_sims: how many simulations to run
        num_reads: how many samples to generate in each simulation
        frac_flag: whether to convert num_new to fraction of total
        """
        curves = [self.RunSimulation(num_reads, frac_flag)
                  for _ in range(num_sims)]
        return curves

    def MakePredictive(self, curves):
        """Makes a predictive distribution of additional species.

        curves: list of (k, num_new) curves

        Returns: Pmf of num_new
        """
        pred = thinkbayes.Pmf(name=self.code)
        for curve in curves:
            _, last_num_new = curve[-1]
            pred.Incr(last_num_new)
        pred.Normalize()
        return pred
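

# Illustrative usage: a minimal sketch of the per-subject workflow, mirroring
# RunSubject further down.  It assumes thinkbayes.py, thinkplot.py and both
# CSV data files are available, and uses 'B1242', the subject code that
# main() analyzes.
#
#     subjects = JoinSubjects()
#     subject = subjects['B1242']
#     subject.Process(conc=1, high=100, iters=300)   # posterior over n
#     subject.MakeQuickPrediction()                  # predictive dist of l
#     cdf_l = subject.DistL().MakeCdf()
#     PrintPrediction(cdf_l, 'unknown')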


def MakeConditionals(curves, ks):
    """Makes Cdfs of the distribution of num_new conditioned on k.

    curves: list of (k, num_new) curves
    ks: list of values of k

    Returns: list of Cdfs
    """
    joint = MakeJointPredictive(curves)

    cdfs = []
    for k in ks:
        pmf = joint.Conditional(1, 0, k)
        pmf.name = 'k=%d' % k
        cdf = pmf.MakeCdf()
        cdfs.append(cdf)
        print '90%% credible interval for %d' % k,
        print cdf.CredibleInterval(90)
    return cdfs


def MakeJointPredictive(curves):
    """Makes a joint distribution of k and num_new.

    curves: list of (k, num_new) curves

    Returns: joint Pmf of (k, num_new)
    """
    joint = thinkbayes.Joint()
    for curve in curves:
        for k, num_new in curve:
            joint.Incr((k, num_new))
    joint.Normalize()
    return joint


def MakeFracCdfs(curves, ks):
    """Makes Cdfs of the fraction of species seen.

    curves: list of (k, num_new) curves

    Returns: map from k to Cdf
    """
    d = {}
    for curve in curves:
        for k, frac in curve:
            if k in ks:
                d.setdefault(k, []).append(frac)

    cdfs = {}
    for k, fracs in d.iteritems():
        cdf = thinkbayes.MakeCdfFromList(fracs)
        cdfs[k] = cdf

    return cdfs


def SpeciesGenerator(names, num):
    """Generates a series of names, starting with the given names.

    Additional names are 'unseen' plus a serial number.

    names: list of strings
    num: total number of species names to generate

    Returns: string iterator
    """
    i = 0
    for name in names:
        yield name
        i += 1

    while i < num:
        yield 'unseen-%d' % i
        i += 1


def ReadRarefactedData(filename='journal.pone.0047712.s001.csv',
                       clean_param=0):
    """Reads a data file and returns a map of Subjects.

    Data from http://www.plosone.org/article/
    info%3Adoi%2F10.1371%2Fjournal.pone.0047712#s4

    filename: string filename to read
    clean_param: parameter passed to Clean

    Returns: map from code to Subject
    """
    fp = open(filename)
    reader = csv.reader(fp)
    _ = reader.next()

    subject = Subject('')
    subject_map = {}

    i = 0
    for t in reader:
        code = t[0]
        if code != subject.code:
            # start a new subject
            subject = Subject(code)
            subject_map[code] = subject

        # append a number to the species names so they're unique
        species = t[1]
        species = '%s-%d' % (species, i)
        i += 1

        count = int(t[2])
        subject.Add(species, count)

    for code, subject in subject_map.iteritems():
        subject.Done(clean_param=clean_param)

    return subject_map


def ReadCompleteDataset(filename='BBB_data_from_Rob.csv', clean_param=0):
    """Reads a data file and returns a map of Subjects.

    Data from personal correspondence with Rob Dunn, received 2-7-13.
    Converted from xlsx to csv.

    filename: string filename to read
    clean_param: parameter passed to Clean

    Returns: map from code to Subject, and an uber_subject that
    combines all subjects
    """
    fp = open(filename)
    reader = csv.reader(fp)
    header = reader.next()
    header = reader.next()

    subject_codes = header[1:-1]
    subject_codes = ['B'+code for code in subject_codes]

    # create the subject map
    uber_subject = Subject('uber')
    subject_map = {}
    for code in subject_codes:
        subject_map[code] = Subject(code)

    # read lines
    i = 0
    for t in reader:
        otu_code = t[0]
        if otu_code == '':
            continue

        # pull out a species name and give it a number
        otu_names = t[-1]
        taxons = otu_names.split(';')
        species = taxons[-1]
        species = '%s-%d' % (species, i)
        i += 1

        counts = [int(x) for x in t[1:-1]]

        # print otu_code, species

        for code, count in zip(subject_codes, counts):
            if count > 0:
                subject_map[code].Add(species, count)
                uber_subject.Add(species, count)

    uber_subject.Done(clean_param=clean_param)
    for code, subject in subject_map.iteritems():
        subject.Done(clean_param=clean_param)

    return subject_map, uber_subject


def JoinSubjects():
    """Reads both datasets and computes their inner join.

    Finds all subjects that appear in both datasets.

    For subjects in the rarefacted dataset, looks up the total
    number of reads and stores it as total_reads.  num_reads
    is normally 400.

    Returns: map from code to Subject
    """

    # read the rarefacted dataset
    sampled_subjects = ReadRarefactedData()

    # read the complete dataset
    all_subjects, _ = ReadCompleteDataset()

    for code, subject in sampled_subjects.iteritems():
        if code in all_subjects:
            match = all_subjects[code]
            subject.Match(match)

    return sampled_subjects


def JitterCurve(curve, dx=0.2, dy=0.3):
    """Adds random noise to the pairs in a curve.

    dx and dy control the amplitude of the noise in each dimension.
    """
    curve = [(x+random.uniform(-dx, dx),
              y+random.uniform(-dy, dy)) for x, y in curve]
    return curve


def OffsetCurve(curve, i, n, dx=0.3, dy=0.3):
    """Adds a deterministic offset to the pairs in a curve.

    i is the index of the curve
    n is the number of curves

    dx and dy control the amplitude of the offset in each dimension.
    """
    xoff = -dx + 2 * dx * i / (n-1)
    yoff = -dy + 2 * dy * i / (n-1)
    curve = [(x+xoff, y+yoff) for x, y in curve]
    return curve


def PlotCurves(curves, root='species-rare'):
    """Plots a set of curves.

    curves is a list of curves; each curve is a list of (x, y) pairs.
    """
    thinkplot.Clf()
    color = '#225EA8'

    n = len(curves)
    for i, curve in enumerate(curves):
        curve = OffsetCurve(curve, i, n)
        xs, ys = zip(*curve)
        thinkplot.Plot(xs, ys, color=color, alpha=0.3, linewidth=0.5)

    thinkplot.Save(root=root,
                   xlabel='# samples',
                   ylabel='# species',
                   formats=FORMATS,
                   legend=False)


def PlotConditionals(cdfs, root='species-cond'):
    """Plots cdfs of num_new conditioned on k.

    cdfs: list of Cdf
    root: string filename root
    """
    thinkplot.Clf()
    thinkplot.PrePlot(num=len(cdfs))

    thinkplot.Cdfs(cdfs)

    thinkplot.Save(root=root,
                   xlabel='# new species',
                   ylabel='Prob',
                   formats=FORMATS)


def PlotFracCdfs(cdfs, root='species-frac'):
    """Plots CDFs of the fraction of species seen.

    cdfs: map from k to CDF of fraction of species seen after k samples
    """
    thinkplot.Clf()
    color = '#225EA8'

    for k, cdf in cdfs.iteritems():
        xs, ys = cdf.Render()
        ys = [1-y for y in ys]
        thinkplot.Plot(xs, ys, color=color, linewidth=1)

        x = 0.9
        y = 1 - cdf.Prob(x)
        pyplot.text(x, y, str(k), fontsize=9, color=color,
                    horizontalalignment='center',
                    verticalalignment='center',
                    bbox=dict(facecolor='white', edgecolor='none'))

    thinkplot.Save(root=root,
                   xlabel='Fraction of species seen',
                   ylabel='Probability',
                   formats=FORMATS,
                   legend=False)
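

# Illustrative usage: a minimal sketch of how the plotting helpers above fit
# together, following the same sequence as RunSubject below; it assumes
# `subject` has already been processed (see the sketch after the Subject
# class).
#
#     curves = subject.RunSimulations(100, 400)
#     PlotCurves(curves, root='species-rare-%s' % subject.code)
#
#     curves = subject.RunSimulations(500, 800)
#     cdfs = MakeConditionals(curves, [100, 200, 400, 800])
#     PlotConditionals(cdfs, root='species-cond-%s' % subject.code)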


class Species(thinkbayes.Suite):
    """Represents hypotheses about the number of species."""

    def __init__(self, ns, conc=1, iters=1000):
        hypos = [thinkbayes.Dirichlet(n, conc) for n in ns]
        thinkbayes.Suite.__init__(self, hypos)
        self.iters = iters

    def Update(self, data):
        """Updates the suite based on the data.

        data: list of observed frequencies
        """
        # call Update in the parent class, which calls Likelihood
        thinkbayes.Suite.Update(self, data)

        # update the next level of the hierarchy
        for hypo in self.Values():
            hypo.Update(data)

    def Likelihood(self, data, hypo):
        """Computes the likelihood of the data under this hypothesis.

        hypo: Dirichlet object
        data: list of observed frequencies
        """
        dirichlet = hypo

        # draw sample Likelihoods from the hypothetical Dirichlet dist
        # and add them up
        like = 0
        for _ in range(self.iters):
            like += dirichlet.Likelihood(data)

        # correct for the number of ways the observed species
        # might have been chosen from all species
        m = len(data)
        like *= thinkbayes.BinomialCoef(dirichlet.n, m)

        return like

    def DistN(self):
        """Computes the distribution of n."""
        pmf = thinkbayes.Pmf()
        for hypo, prob in self.Items():
            pmf.Set(hypo.n, prob)
        return pmf


class Species2(object):
    """Represents hypotheses about the number of species.

    Combines two layers of the hierarchy into one object.

    ns and probs represent the distribution of N

    params represents the parameters of the Dirichlet distributions
    """

    def __init__(self, ns, conc=1, iters=1000):
        self.ns = ns
        self.conc = conc
        self.probs = numpy.ones(len(ns), dtype=numpy.float)
        self.params = numpy.ones(self.ns[-1], dtype=numpy.float) * conc
        self.iters = iters
        self.num_reads = 0
        self.m = 0

    def Preload(self, data):
        """Change the initial parameters to fit the data better.

        Just an experiment.  Doesn't work.
        """
        m = len(data)
        singletons = data.count(1)
        num = m - singletons
        print m, singletons, num
        addend = numpy.ones(num, dtype=numpy.float) * 1
        print len(addend)
        print len(self.params[singletons:m])
        self.params[singletons:m] += addend
        print 'Preload', num

    def Update(self, data):
        """Updates the distribution based on data.

        data: numpy array of counts
        """
        self.num_reads += sum(data)

        like = numpy.zeros(len(self.ns), dtype=numpy.float)
        for _ in range(self.iters):
            like += self.SampleLikelihood(data)

        self.probs *= like
        self.probs /= self.probs.sum()

        self.m = len(data)
        #self.params[:self.m] += data * self.conc
        self.params[:self.m] += data

    def SampleLikelihood(self, data):
        """Computes the likelihood of the data for all values of n.

        Draws one sample from the distribution of prevalences.

        data: sequence of observed counts

        Returns: numpy array of m likelihoods
        """
        gammas = numpy.random.gamma(self.params)

        m = len(data)
        row = gammas[:m]
        col = numpy.cumsum(gammas)

        log_likes = []
        for n in self.ns:
            ps = row / col[n-1]
            terms = numpy.log(ps) * data
            log_like = terms.sum()
            log_likes.append(log_like)

        log_likes -= numpy.max(log_likes)
        likes = numpy.exp(log_likes)

        coefs = [thinkbayes.BinomialCoef(n, m) for n in self.ns]
        likes *= coefs

        return likes

    def DistN(self):
        """Computes the distribution of n.

        Returns: new Pmf object
        """
        pmf = thinkbayes.MakePmfFromItems(zip(self.ns, self.probs))
        return pmf

    def RandomN(self):
        """Returns a random value of n."""
        return self.DistN().Random()

    def DistQ(self, iters=100):
        """Computes the distribution of q based on distribution of n.

        Returns: pmf of q
        """
        cdf_n = self.DistN().MakeCdf()
        sample_n = cdf_n.Sample(iters)

        pmf = thinkbayes.Pmf()
        for n in sample_n:
            q = self.RandomQ(n)
            pmf.Incr(q)

        pmf.Normalize()
        return pmf

    def RandomQ(self, n):
        """Returns a random value of q.

        Based on n, self.num_reads and self.conc.

        n: number of species

        Returns: q
        """
        # generate random prevalences
        dirichlet = thinkbayes.Dirichlet(n, conc=self.conc)
        prevalences = dirichlet.Random()

        # generate a simulated sample
        pmf = thinkbayes.MakePmfFromItems(enumerate(prevalences))
        cdf = pmf.MakeCdf()
        sample = cdf.Sample(self.num_reads)
        seen = set(sample)

        # add up the prevalence of unseen species
        q = 0
        for species, prev in enumerate(prevalences):
            if species not in seen:
                q += prev

        return q

    def MarginalBeta(self, n, index):
        """Computes the conditional distribution of the indicated species.

        n: conditional number of species
        index: which species

        Returns: Beta object representing a distribution of prevalence.
        """
        alpha0 = self.params[:n].sum()
        alpha = self.params[index]
        return thinkbayes.Beta(alpha, alpha0-alpha)

    def DistOfPrevalence(self, index):
        """Computes the distribution of prevalence for the indicated species.

        index: which species

        Returns: (metapmf, mix) where metapmf is a MetaPmf and mix is a Pmf
        """
        metapmf = thinkbayes.Pmf()

        for n, prob in zip(self.ns, self.probs):
            beta = self.MarginalBeta(n, index)
            pmf = beta.MakePmf()
            metapmf.Set(pmf, prob)

        mix = thinkbayes.MakeMixture(metapmf)
        return metapmf, mix
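
    # MarginalBeta relies on the standard Dirichlet property that if
    # (p_1, ..., p_n) ~ Dirichlet(alpha_1, ..., alpha_n), the marginal
    # distribution of p_i is Beta(alpha_i, alpha_0 - alpha_i), where
    # alpha_0 is the sum of the alphas.  For example, with params [4, 3, 2]
    # and n = 3, species 0 has marginal Beta(4, 5), whose mean is 4/9 ~= 0.44.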

    def SamplePosterior(self):
        """Draws random n and prevalences.

        Returns: (n, prevalences)
        """
        n = self.RandomN()
        prevalences = self.SamplePrevalences(n)

        #print 'Peeking at n_cheat'
        #n = n_cheat

        return n, prevalences

    def SamplePrevalences(self, n):
        """Draws a sample of prevalences given n.

        n: the number of species assumed in the conditional

        Returns: numpy array of n prevalences
        """
        if n == 1:
            return [1.0]

        q_desired = self.RandomQ(n)
        q_desired = max(q_desired, 1e-6)

        params = self.Unbias(n, self.m, q_desired)

        gammas = numpy.random.gamma(params)
        gammas /= gammas.sum()
        return gammas

    def Unbias(self, n, m, q_desired):
        """Adjusts the parameters to achieve desired prev_unseen (q).

        n: number of species
        m: seen species
        q_desired: prevalence of unseen species
        """
        params = self.params[:n].copy()

        if n == m:
            return params

        x = sum(params[:m])
        y = sum(params[m:])
        a = x + y
        #print x, y, a, x/a, y/a

        g = q_desired * a / y
        f = (a - g * y) / x
        params[:m] *= f
        params[m:] *= g

        return params


class Species3(Species2):
    """Represents hypotheses about the number of species."""

    def Update(self, data):
        """Updates the suite based on the data.

        data: list of observations
        """
        # sample the likelihoods and add them up
        like = numpy.zeros(len(self.ns), dtype=numpy.float)
        for _ in range(self.iters):
            like += self.SampleLikelihood(data)

        self.probs *= like
        self.probs /= self.probs.sum()

        m = len(data)
        self.params[:m] += data

    def SampleLikelihood(self, data):
        """Computes the likelihood of the data under all hypotheses.

        data: list of observations
        """
        # get a random sample
        gammas = numpy.random.gamma(self.params)

        # row is just the first m elements of gammas
        m = len(data)
        row = gammas[:m]

        # col is the cumulative sum of gammas
        col = numpy.cumsum(gammas)[self.ns[0]-1:]

        # each row of the array is a set of ps, normalized
        # for each hypothetical value of n
        array = row / col[:, numpy.newaxis]

        # computing the multinomial PDF under a log transform
        # take the log of the ps and multiply by the data
        terms = numpy.log(array) * data

        # add up the rows
        log_likes = terms.sum(axis=1)

        # before exponentiating, scale into a reasonable range
        log_likes -= numpy.max(log_likes)
        likes = numpy.exp(log_likes)

        # correct for the number of ways we could see m species
        # out of a possible n
        coefs = [thinkbayes.BinomialCoef(n, m) for n in self.ns]
        likes *= coefs

        return likes


class Species4(Species):
    """Represents hypotheses about the number of species."""

    def Update(self, data):
        """Updates the suite based on the data.

        data: list of observed frequencies
        """
        m = len(data)

        # loop through the species and update one at a time
        for i in range(m):
            one = numpy.zeros(i+1)
            one[i] = data[i]

            # call the parent class
            Species.Update(self, one)

    def Likelihood(self, data, hypo):
        """Computes the likelihood of the data under this hypothesis.

        Note: this only works correctly if we update one species at a time.

        hypo: Dirichlet object
        data: list of observed frequencies
        """
        dirichlet = hypo
        like = 0
        for _ in range(self.iters):
            like += dirichlet.Likelihood(data)

        # correct for the number of unseen species the new one
        # could have been
        m = len(data)
        num_unseen = dirichlet.n - m + 1
        like *= num_unseen

        return like


class Species5(Species2):
    """Represents hypotheses about the number of species.

    Combines two layers of the hierarchy into one object.

    ns and probs represent the distribution of N

    params represents the parameters of the Dirichlet distributions
    """

    def Update(self, data):
        """Updates the suite based on the data.

        data: list of observed frequencies in increasing order
        """
        # loop through the species and update one at a time
        m = len(data)
        for i in range(m):
            self.UpdateOne(i+1, data[i])
            self.params[i] += data[i]

    def UpdateOne(self, i, count):
        """Updates the suite based on the data.

        Evaluates the likelihood for all values of n.

        i: which species was observed (1..n)
        count: how many were observed
        """
        # how many species have we seen so far
        self.m = i

        # how many reads have we seen
        self.num_reads += count

        if self.iters == 0:
            return

        # sample the likelihoods and add them up
        likes = numpy.zeros(len(self.ns), dtype=numpy.float)
        for _ in range(self.iters):
            likes += self.SampleLikelihood(i, count)

        # correct for the number of unseen species the new one
        # could have been
        unseen_species = [n-i+1 for n in self.ns]
        likes *= unseen_species

        # multiply the priors by the likelihoods and renormalize
        self.probs *= likes
        self.probs /= self.probs.sum()

    def SampleLikelihood(self, i, count):
        """Computes the likelihood of the data under all hypotheses.

        i: which species was observed
        count: how many were observed
        """
        # get a random sample of p
        gammas = numpy.random.gamma(self.params)

        # sums is the cumulative sum of p, for each value of n
        sums = numpy.cumsum(gammas)[self.ns[0]-1:]

        # get p for the mth species, for each value of n
        ps = gammas[i-1] / sums
        log_likes = numpy.log(ps) * count

        # before exponentiating, scale into a reasonable range
        log_likes -= numpy.max(log_likes)
        likes = numpy.exp(log_likes)

        return likes


def MakePosterior(constructor, data, ns, conc=1, iters=1000):
    """Makes a suite, updates it and returns the posterior suite.

    Prints the elapsed time.

    data: observed species and their counts
    ns: sequence of hypothetical ns
    conc: concentration parameter
    iters: how many samples to draw

    Returns: posterior suite of the given type
    """
    suite = constructor(ns, conc=conc, iters=iters)

    # print constructor.__name__
    start = time.time()
    suite.Update(data)
    end = time.time()
    print 'Processing time', end-start

    return suite


def PlotAllVersions():
    """Makes a graph of posterior distributions of N."""
    data = [1, 2, 3]
    m = len(data)
    n = 20
    ns = range(m, n)

    for constructor in [Species, Species2, Species3, Species4, Species5]:
        suite = MakePosterior(constructor, data, ns)
        pmf = suite.DistN()
        pmf.name = '%s' % (constructor.__name__)
        thinkplot.Pmf(pmf)

    thinkplot.Save(root='species3',
                   xlabel='Number of species',
                   ylabel='Prob')


def PlotMedium():
    """Makes a graph of posterior distributions of N."""
    data = [1, 1, 1, 1, 2, 3, 5, 9]
    m = len(data)
    n = 20
    ns = range(m, n)

    for constructor in [Species, Species2, Species3, Species4, Species5]:
        suite = MakePosterior(constructor, data, ns)
        pmf = suite.DistN()
        pmf.name = '%s' % (constructor.__name__)
        thinkplot.Pmf(pmf)

    thinkplot.Show()


def SimpleDirichletExample():
    """Makes a plot showing posterior distributions for three species.

    This is the case where we know there are exactly three species.
    """
    thinkplot.Clf()
    thinkplot.PrePlot(3)

    names = ['lions', 'tigers', 'bears']
    data = [3, 2, 1]

    dirichlet = thinkbayes.Dirichlet(3)
    for i in range(3):
        beta = dirichlet.MarginalBeta(i)
        print 'mean', names[i], beta.Mean()

    dirichlet.Update(data)
    for i in range(3):
        beta = dirichlet.MarginalBeta(i)
        print 'mean', names[i], beta.Mean()

        pmf = beta.MakePmf(name=names[i])
        thinkplot.Pmf(pmf)

    thinkplot.Save(root='species1',
                   xlabel='Prevalence',
                   ylabel='Prob',
                   formats=FORMATS,
                   )
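

# Worked numbers for SimpleDirichletExample (assuming the default
# concentration parameter of 1 in thinkbayes.Dirichlet): with a
# Dirichlet(1, 1, 1) prior each prevalence has prior mean 1/3; after
# updating with data = [3, 2, 1] the parameters become (4, 3, 2), so the
# posterior means printed above are 4/9, 3/9 and 2/9 (about 0.44, 0.33
# and 0.22).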


def HierarchicalExample():
    """Shows the posterior distribution of n for lions, tigers and bears.
    """
    ns = range(3, 30)
    suite = Species(ns, iters=8000)

    data = [3, 2, 1]
    suite.Update(data)

    thinkplot.Clf()
    thinkplot.PrePlot(num=1)

    pmf = suite.DistN()
    thinkplot.Pmf(pmf)
    thinkplot.Save(root='species2',
                   xlabel='Number of species',
                   ylabel='Prob',
                   formats=FORMATS,
                   )


def CompareHierarchicalExample():
    """Makes a graph of posterior distributions of N."""
    data = [3, 2, 1]
    m = len(data)
    n = 30
    ns = range(m, n)

    constructors = [Species, Species5]
    iters = [1000, 100]

    for constructor, iters in zip(constructors, iters):
        suite = MakePosterior(constructor, data, ns, iters=iters)
        pmf = suite.DistN()
        pmf.name = '%s' % (constructor.__name__)
        thinkplot.Pmf(pmf)

    thinkplot.Show()


def ProcessSubjects(codes):
    """Process subjects with the given codes and plot their posteriors.

    codes: sequence of string codes
    """
    thinkplot.Clf()
    thinkplot.PrePlot(len(codes))

    subjects = ReadRarefactedData()
    pmfs = []
    for code in codes:
        subject = subjects[code]

        subject.Process()
        pmf = subject.suite.DistN()
        pmf.name = subject.code
        thinkplot.Pmf(pmf)

        pmfs.append(pmf)

    print 'ProbGreater', thinkbayes.PmfProbGreater(pmfs[0], pmfs[1])
    print 'ProbLess', thinkbayes.PmfProbLess(pmfs[0], pmfs[1])

    thinkplot.Save(root='species4',
                   xlabel='Number of species',
                   ylabel='Prob',
                   formats=FORMATS,
                   )


def RunSubject(code, conc=1, high=500):
    """Run the analysis for the subject with the given code.

    code: string code
    """
    subjects = JoinSubjects()
    subject = subjects[code]

    subject.Process(conc=conc, high=high, iters=300)
    subject.MakeQuickPrediction()

    PrintSummary(subject)
    actual_l = subject.total_species - subject.num_species
    cdf_l = subject.DistL().MakeCdf()
    PrintPrediction(cdf_l, actual_l)

    subject.MakeFigures()

    num_reads = 400
    curves = subject.RunSimulations(100, num_reads)
    root = 'species-rare-%s' % subject.code
    PlotCurves(curves, root=root)

    num_reads = 800
    curves = subject.RunSimulations(500, num_reads)
    ks = [100, 200, 400, 800]
    cdfs = MakeConditionals(curves, ks)
    root = 'species-cond-%s' % subject.code
    PlotConditionals(cdfs, root=root)

    num_reads = 1000
    curves = subject.RunSimulations(500, num_reads, frac_flag=True)
    ks = [10, 100, 200, 400, 600, 800, 1000]
    cdfs = MakeFracCdfs(curves, ks)
    root = 'species-frac-%s' % subject.code
    PlotFracCdfs(cdfs, root=root)


def PrintSummary(subject):
    """Print a summary of a subject.

    subject: Subject
    """
    print subject.code
    print 'found %d species in %d reads' % (subject.num_species,
                                            subject.num_reads)

    print 'total %d species in %d reads' % (subject.total_species,
                                            subject.total_reads)

    cdf = subject.suite.DistN().MakeCdf()
    print 'n'
    PrintPrediction(cdf, 'unknown')


def PrintPrediction(cdf, actual):
    """Print a summary of a prediction.

    cdf: predictive distribution
    actual: actual value
    """
    median = cdf.Percentile(50)
    low, high = cdf.CredibleInterval(75)

    print 'predicted %0.2f (%0.2f %0.2f)' % (median, low, high)
    print 'actual', actual


def RandomSeed(x):
    """Initialize random.random and numpy.random.

    x: int seed
    """
    random.seed(x)
    numpy.random.seed(x)


def GenerateFakeSample(n, r, tr, conc=1):
    """Generates fake data with the given parameters.

    n: number of species
    r: number of reads in subsample
    tr: total number of reads
    conc: concentration parameter

    Returns: hist of all reads, hist of subsample, prev_unseen
    """
    # generate random prevalences
    dirichlet = thinkbayes.Dirichlet(n, conc=conc)
    prevalences = dirichlet.Random()
    prevalences.sort()

    # generate a simulated sample
    pmf = thinkbayes.MakePmfFromItems(enumerate(prevalences))
    cdf = pmf.MakeCdf()
    sample = cdf.Sample(tr)

    # collect the species counts
    hist = thinkbayes.MakeHistFromList(sample)

    # extract a subset of the data
    if tr > r:
        random.shuffle(sample)
        subsample = sample[:r]
        subhist = thinkbayes.MakeHistFromList(subsample)
    else:
        subhist = hist

    # add up the prevalence of unseen species
    prev_unseen = 0
    for species, prev in enumerate(prevalences):
        if species not in subhist:
            prev_unseen += prev

    return hist, subhist, prev_unseen


def PlotActualPrevalences():
    """Makes a plot comparing actual prevalences with a model.
    """
    # read data
    subject_map, _ = ReadCompleteDataset()

    # for subjects with more than 50 species,
    # PMF of max prevalence, and PMF of max prevalence
    # generated by a simulation
    pmf_actual = thinkbayes.Pmf()
    pmf_sim = thinkbayes.Pmf()

    # concentration parameter used in the simulation
    conc = 0.06

    for code, subject in subject_map.iteritems():
        prevalences = subject.GetPrevalences()
        m = len(prevalences)
        if m < 2:
            continue

        actual_max = max(prevalences)
        print code, m, actual_max

        # incr the PMFs
        if m > 50:
            pmf_actual.Incr(actual_max)
            pmf_sim.Incr(SimulateMaxPrev(m, conc))

    # plot CDFs for the actual and simulated max prevalence
    cdf_actual = pmf_actual.MakeCdf(name='actual')
    cdf_sim = pmf_sim.MakeCdf(name='sim')

    thinkplot.Cdfs([cdf_actual, cdf_sim])
    thinkplot.Show()


def ScatterPrevalences(ms, actual):
    """Make a scatter plot of actual prevalences and expected values.

    ms: sorted sequence of m (number of species)
    actual: sequence of actual max prevalence
    """
    for conc in [1, 0.5, 0.2, 0.1]:
        expected = [ExpectedMaxPrev(m, conc) for m in ms]
        thinkplot.Plot(ms, expected)

    thinkplot.Scatter(ms, actual)
    thinkplot.Show(xscale='log')


def SimulateMaxPrev(m, conc=1):
    """Returns random max prevalence from a Dirichlet distribution.

    m: int number of species
    conc: concentration parameter of the Dirichlet distribution

    Returns: float max of m prevalences
    """
    dirichlet = thinkbayes.Dirichlet(m, conc)
    prevalences = dirichlet.Random()
    return max(prevalences)


def ExpectedMaxPrev(m, conc=1, iters=100):
    """Estimate expected max prevalence.

    m: number of species
    conc: concentration parameter
    iters: how many iterations to run

    Returns: expected max prevalence
    """
    dirichlet = thinkbayes.Dirichlet(m, conc)

    t = []
    for _ in range(iters):
        prevalences = dirichlet.Random()
        t.append(max(prevalences))

    return numpy.mean(t)


class Calibrator(object):
    """Encapsulates the calibration process."""

    def __init__(self, conc=0.1):
        """conc: concentration parameter used in the analysis.
        """
        self.conc = conc

        self.ps = range(10, 100, 10)
        self.total_n = numpy.zeros(len(self.ps))
        self.total_q = numpy.zeros(len(self.ps))
        self.total_l = numpy.zeros(len(self.ps))

        self.n_seq = []
        self.q_seq = []
        self.l_seq = []

    def Calibrate(self, num_runs=100, n_low=30, n_high=400, r=400, tr=1200):
        """Runs calibrations.

        num_runs: how many runs
        """
        for seed in range(num_runs):
            self.RunCalibration(seed, n_low, n_high, r, tr)

        self.total_n *= 100.0 / num_runs
        self.total_q *= 100.0 / num_runs
        self.total_l *= 100.0 / num_runs

    def Validate(self, num_runs=100, clean_param=0):
        """Runs validations.

        num_runs: how many runs
        """
        subject_map, _ = ReadCompleteDataset(clean_param=clean_param)

        i = 0
        for match in subject_map.itervalues():
            if match.num_reads < 400:
                continue
            num_reads = 100

            print 'Validate', match.code
            subject = match.Resample(num_reads)
            subject.Match(match)

            n_actual = None
            q_actual = subject.prev_unseen
            l_actual = subject.total_species - subject.num_species
            self.RunSubject(subject, n_actual, q_actual, l_actual)

            i += 1
            if i == num_runs:
                break

        self.total_n *= 100.0 / num_runs
        self.total_q *= 100.0 / num_runs
        self.total_l *= 100.0 / num_runs

    def PlotN(self, root='species-n'):
        """Makes a scatter plot of predicted vs actual n.
        """
        xs, ys = zip(*self.n_seq)
        if None in xs:
            return

        high = max(xs+ys)

        thinkplot.Plot([0, high], [0, high], color='gray')
        thinkplot.Scatter(xs, ys)
        thinkplot.Save(root=root,
                       xlabel='Actual n',
                       ylabel='Predicted')

    def PlotQ(self, root='species-q'):
        """Makes a scatter plot of predicted vs actual prev_unseen (q).
        """
        thinkplot.Plot([0, 0.2], [0, 0.2], color='gray')
        xs, ys = zip(*self.q_seq)
        thinkplot.Scatter(xs, ys)
        thinkplot.Save(root=root,
                       xlabel='Actual q',
                       ylabel='Predicted')

    def PlotL(self, root='species-l'):
        """Makes a scatter plot of predicted vs actual l.
        """
        thinkplot.Plot([0, 20], [0, 20], color='gray')
        xs, ys = zip(*self.l_seq)
        thinkplot.Scatter(xs, ys)
        thinkplot.Save(root=root,
                       xlabel='Actual l',
                       ylabel='Predicted')

    def PlotCalibrationCurves(self, root='species5'):
        """Plots calibration curves"""
        print self.total_n
        print self.total_q
        print self.total_l

        thinkplot.Plot([0, 100], [0, 100], color='gray', alpha=0.2)

        if self.total_n[0] >= 0:
            thinkplot.Plot(self.ps, self.total_n, label='n')

        thinkplot.Plot(self.ps, self.total_q, label='q')
        thinkplot.Plot(self.ps, self.total_l, label='l')

        thinkplot.Save(root=root,
                       axis=[0, 100, 0, 100],
                       xlabel='Ideal percentages',
                       ylabel='Predictive distributions',
                       formats=FORMATS,
                       )

    def RunCalibration(self, seed, n_low, n_high, r, tr):
        """Runs a single calibration run.

        Generates N and prevalences from a Dirichlet distribution,
        then generates simulated data.

        Runs analysis to get the posterior distributions.
        Generates calibration curves for each posterior distribution.

        seed: int random seed
        """
        # generate a random number of species and their prevalences
        # (from a Dirichlet distribution with alpha_i = conc for all i)
        RandomSeed(seed)
        n_actual = random.randrange(n_low, n_high+1)

        hist, subhist, q_actual = GenerateFakeSample(
            n_actual,
            r,
            tr,
            self.conc)

        l_actual = len(hist) - len(subhist)
        print 'Run low, high, conc', n_low, n_high, self.conc
        print 'Run r, tr', r, tr
        print 'Run n, q, l', n_actual, q_actual, l_actual

        # extract the data
        data = [count for species, count in subhist.Items()]
        data.sort()
        print 'data', data

        # make a Subject and process
        subject = Subject('simulated')
        subject.num_reads = r
        subject.total_reads = tr

        for species, count in subhist.Items():
            subject.Add(species, count)
        subject.Done()

        self.RunSubject(subject, n_actual, q_actual, l_actual)

    def RunSubject(self, subject, n_actual, q_actual, l_actual):
        """Runs the analysis for a subject.

        subject: Subject
        n_actual: number of species
        q_actual: prevalence of unseen species
        l_actual: number of new species
        """
        # process and make prediction
        subject.Process(conc=self.conc, iters=100)
        subject.MakeQuickPrediction()

        # extract the posterior suite
        suite = subject.suite

        # check the distribution of n
        pmf_n = suite.DistN()
        print 'n'
        self.total_n += self.CheckDistribution(pmf_n, n_actual, self.n_seq)

        # check the distribution of q
        pmf_q = suite.DistQ()
        print 'q'
        self.total_q += self.CheckDistribution(pmf_q, q_actual, self.q_seq)

        # check the distribution of additional species
        pmf_l = subject.DistL()
        print 'l'
        self.total_l += self.CheckDistribution(pmf_l, l_actual, self.l_seq)

    def CheckDistribution(self, pmf, actual, seq):
        """Checks a predictive distribution and returns a score vector.

        pmf: predictive distribution
        actual: actual value
        seq: which sequence to append (actual, mean) onto
        """
        mean = pmf.Mean()
        seq.append((actual, mean))

        cdf = pmf.MakeCdf()
        PrintPrediction(cdf, actual)

        sv = ScoreVector(cdf, self.ps, actual)
        return sv


def ScoreVector(cdf, ps, actual):
    """Checks whether the actual value falls in each credible interval.

    cdf: predictive distribution
    ps: percentages to check (0-100)
    actual: actual value

    Returns: numpy array of 0, 0.5, or 1
    """
    scores = []
    for p in ps:
        low, high = cdf.CredibleInterval(p)
        score = Score(low, high, actual)
        scores.append(score)

    return numpy.array(scores)


def Score(low, high, n):
    """Score whether the actual value falls in the range.

    Hitting the posts counts as 0.5; -1 indicates an invalid actual value.

    low: low end of range
    high: high end of range
    n: actual value

    Returns: -1, 0, 0.5 or 1
    """
    if n is None:
        return -1
    if low < n < high:
        return 1
    if n == low or n == high:
        return 0.5
    else:
        return 0
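

# Illustrative usage: a minimal sketch of the calibration/validation entry
# points, mirroring RunCalibration and RunTests below.
#
#     cal = Calibrator(conc=0.1)
#     cal.Calibrate(num_runs=100)        # or cal.Validate(num_runs=100)
#     cal.PlotCalibrationCurves(root='species5-cal')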


def FakeSubject(n=300, conc=0.1, num_reads=400, prevalences=None):
    """Makes a fake Subject.

    If prevalences is provided, n and conc are ignored.

    n: number of species
    conc: concentration parameter
    num_reads: number of reads
    prevalences: numpy array of prevalences (overrides n and conc)
    """
    # generate random prevalences
    if prevalences is None:
        dirichlet = thinkbayes.Dirichlet(n, conc=conc)
        prevalences = dirichlet.Random()
        prevalences.sort()

    # generate a simulated sample
    pmf = thinkbayes.MakePmfFromItems(enumerate(prevalences))
    cdf = pmf.MakeCdf()
    sample = cdf.Sample(num_reads)

    # collect the species counts
    hist = thinkbayes.MakeHistFromList(sample)

    # extract the data
    data = [count for species, count in hist.Items()]
    data.sort()

    # make a Subject and process
    subject = Subject('simulated')

    for species, count in hist.Items():
        subject.Add(species, count)
    subject.Done()

    return subject


def PlotSubjectCdf(code=None, clean_param=0):
    """Checks whether the Dirichlet model can replicate the data.
    """
    subject_map, uber_subject = ReadCompleteDataset(clean_param=clean_param)

    if code is None:
        subjects = subject_map.values()
        subject = random.choice(subjects)
        code = subject.code
    elif code == 'uber':
        subject = uber_subject
    else:
        subject = subject_map[code]

    print subject.code

    m = subject.GetM()

    subject.Process(high=m, conc=0.1, iters=0)
    print subject.suite.params[:m]

    # plot the cdf
    options = dict(linewidth=3, color='blue', alpha=0.5)
    cdf = subject.MakeCdf()
    thinkplot.Cdf(cdf, **options)

    options = dict(linewidth=1, color='green', alpha=0.5)

    # generate fake subjects and plot their CDFs
    for _ in range(10):
        prevalences = subject.suite.SamplePrevalences(m)
        fake = FakeSubject(prevalences=prevalences)
        cdf = fake.MakeCdf()
        thinkplot.Cdf(cdf, **options)

    root = 'species-cdf-%s' % code
    thinkplot.Save(root=root,
                   xlabel='rank',
                   ylabel='CDF',
                   xscale='log',
                   formats=FORMATS,
                   )


def RunCalibration(flag='cal', num_runs=100, clean_param=50):
    """Runs either the calibration or validation process.

    flag: string 'cal' or 'val'
    num_runs: how many runs
    clean_param: parameter used for data cleaning
    """
    cal = Calibrator(conc=0.1)

    if flag == 'val':
        cal.Validate(num_runs=num_runs, clean_param=clean_param)
    else:
        cal.Calibrate(num_runs=num_runs)

    cal.PlotN(root='species-n-%s' % flag)
    cal.PlotQ(root='species-q-%s' % flag)
    cal.PlotL(root='species-l-%s' % flag)
    cal.PlotCalibrationCurves(root='species5-%s' % flag)


def RunTests():
    """Runs calibration code and generates some figures."""
    RunCalibration(flag='val')
    RunCalibration(flag='cal')

    PlotSubjectCdf('B1558.G', clean_param=50)
    PlotSubjectCdf(None)


def main(script):
    RandomSeed(17)
    RunSubject('B1242', conc=1, high=100)

    RandomSeed(17)
    SimpleDirichletExample()

    RandomSeed(17)
    HierarchicalExample()


if __name__ == '__main__':
    main(*sys.argv)
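

# How to run: this is Python 2 code from Think Bayes.  Running the module as
# a script executes main(), which needs thinkbayes.py and thinkplot.py on the
# path and the two CSV data files named in ReadRarefactedData and
# ReadCompleteDataset in the working directory.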