Sample from Antoniak Distribution with Python.

rand_antoniak draws a sample from the distribution of tables created by a Chinese restaurant process with parameter alpha after n patrons are seated. Some notes on this distribution are here: http://www.cs.cmu.edu/~tss/antoniak.pdf.

import numpy as np
from numpy.random import choice

def stirling(N, m):
    if N < 0 or m < 0:
        raise Exception("Bad input to stirling.")
    if m == 0 and N > 0:
        return 0
    elif (N, m) == (0, 0):
        return 1
    elif N == 0 and m > 0:
        return m
    elif m > N:
        return 0
    else:
        return stirling(N-1, m-1) + (N-1) * stirling(N-1, m)

assert stirling(9, 3) == 118124
assert stirling(9, 3) == 118124
assert stirling(0, 0) == 1
assert stirling(1, 1) == 1
assert stirling(2, 9) == 0
assert stirling(9, 6) == 4536

def normalized_stirling_numbers(nn):
    #  * stirling(nn) Gives unsigned Stirling numbers of the first
    #  * kind s(nn,*) in ss. ss[i] = s(nn,i). ss is normalized so that maximum
    #  * value is 1. After Teh (npbayes).
    ss = [stirling(nn, i) for i in range(1, nn + 1)]
    max_val = max(ss)
    return np.array(ss, dtype=float) / max_val

ss1 = np.array([1])
ss2 = np.array([1, 1])
ss10 = np.array([  3.09439754e-01,   8.75395242e-01,   1.00000000e+00,
         6.17105824e-01,   2.29662318e-01,   5.39549757e-02,
         8.05832694e-03,   7.41877718e-04,   3.83729854e-05,
         8.52733009e-07]) # Verified with Yee Whye Teh's code

assert np.sqrt(((normalized_stirling_numbers(1) - ss1)**2).sum()) < 0.00001
assert np.sqrt(((normalized_stirling_numbers(2) - ss2)**2).sum()) < 0.00001
assert np.sqrt(((normalized_stirling_numbers(10) - ss10)**2).sum()) < 0.00001



def rand_antoniak(alpha, n):
    # Sample from Antoniak Distribution
    # cf http://www.cs.cmu.edu/~tss/antoniak.pdf
    p = normalized_stirling_numbers(n)
    aa = 1
    for i, _ in enumerate(p):
        p[i] *= aa
        aa *= alpha
    p = np.array(p) / np.array(p).sum()
    return choice(range(1, n+1), p=p)

rand_antoniak(.5, 10)