Nonparametric Latent Dirichlet Allocation

In [1]:
%matplotlib inline
%precision 2
Out[1]:
u'%.2f'

Latent Dirichlet Allocation is a generative model for topic modeling. Given a collection of documents, an LDA inference algorithm attempts to determine (in an unsupervised manner) the topics discussed in the documents. It assumes that each document is generated by a probability model, and, when doing inference, we try to find the parameters that best fit the model (as well as the unseen, latent variables generated by the model). If you are unfamiliar with LDA, Edwin Chen has a friendly introduction you should read.

Because LDA is a generative model, we can simulate the construction of documents by forward-sampling from the model. The generative algorithm is as follows (following Heinrich):

  • for each topic $k\in [1,K]$ do
    • sample term distribution for topic $\overrightarrow \phi_k \sim \text{Dir}(\overrightarrow \beta)$
  • for each document $m\in [1, M]$ do
    • sample topic distribution for document $\overrightarrow\theta_m\sim \text{Dir}(\overrightarrow\alpha)$
    • sample document length $N_m\sim\text{Pois}(\xi)$
    • for all words $n\in [1, N_m]$ in document $m$ do
      • sample topic index $z_{m,n}\sim\text{Mult}(\overrightarrow\theta_m)$
      • sample term for word $w_{m,n}\sim\text{Mult}(\overrightarrow\phi_{z_{m,n}})$

You can implement this with a little bit of code and start to simulate documents.
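
A minimal sketch of this forward sampler, assuming NumPy's `numpy.random.default_rng` generator and symmetric priors. `generate_lda_corpus` is a hypothetical helper, not code from this notebook; it follows the steps of the algorithm above one to one:

```python
import numpy as np

def generate_lda_corpus(vocabulary, num_topics, num_documents,
                        alpha=1.0, beta=1.0, xi=5, seed=0):
    """Forward-sample a corpus from finite LDA (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    num_terms = len(vocabulary)
    # phi_k ~ Dir(beta): one term distribution per topic
    phi = rng.dirichlet(np.full(num_terms, beta), size=num_topics)
    documents, topic_assignments = [], []
    for m in range(num_documents):
        # theta_m ~ Dir(alpha): topic proportions for document m
        theta = rng.dirichlet(np.full(num_topics, alpha))
        # N_m ~ Pois(xi): document length
        length = rng.poisson(xi)
        # z_{m,n} ~ Mult(theta_m): a topic index for every word
        z = rng.choice(num_topics, size=length, p=theta)
        # w_{m,n} ~ Mult(phi_{z_{m,n}}): a term for every word
        words = [vocabulary[rng.choice(num_terms, p=phi[k])] for k in z]
        documents.append(words)
        topic_assignments.append(z.tolist())
    return phi, documents, topic_assignments

phi, docs, z = generate_lda_corpus(['see', 'spot', 'run'], num_topics=2, num_documents=5)
for words in docs:
    print(words)
```

Returning the topic assignments alongside the words keeps the latent $z_{m,n}$ available for inspection, which is exactly what an inference algorithm would have to recover.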

In LDA, we assume each word in the document is generated by a two-step process:

  1. Sample a topic from the topic distribution for the document.
  2. Sample a word from the term distribution from the topic.
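
Marginalizing out the topic, these two steps amount to drawing each word from a mixture of the topics' term distributions, $p(w \mid \overrightarrow\theta_m, \underline\Phi) = \sum_{k=1}^{K} \theta_{m,k}\,\phi_{k,w}$. A quick Monte Carlo check of that claim in NumPy; the values of `theta` and `phi` are made up purely for illustration:

```python
import numpy as np

rng = np.random.default_rng(42)
# Illustrative (made-up) parameters: 2 topics over 3 terms
theta = np.array([0.7, 0.3])              # topic distribution for one document
phi = np.array([[0.41, 0.02, 0.57],       # term distribution for topic 0
                [0.38, 0.36, 0.26]])      # term distribution for topic 1

num_words = 200_000
# Step 1: sample a topic for every word from theta
topics = rng.choice(len(theta), size=num_words, p=theta)
# Step 2: sample a term for every word from its topic's term distribution
terms = np.empty(num_words, dtype=int)
for k in range(len(theta)):
    mask = topics == k
    terms[mask] = rng.choice(phi.shape[1], size=mask.sum(), p=phi[k])

empirical = np.bincount(terms, minlength=phi.shape[1]) / num_words
mixture = theta @ phi                     # sum_k theta_k * phi[k]
print("two-step frequencies:", np.round(empirical, 3))
print("mixture theta @ phi: ", np.round(mixture, 3))
```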

When we fit the LDA model to a given text corpus with an inference algorithm, our primary objective is to find the topic distributions $\underline \Theta$ and term distributions $\underline \Phi$ that generated the documents, as well as the latent topic index $z_{m,n}$ for each word.
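
For reference, the generative process above implies the following joint distribution of the words and the latent variables, conditioning on the document lengths $N_m$ (the Poisson factor does not involve the other variables). Here $\underline W$ and $\underline Z$ are shorthand introduced for all words $w_{m,n}$ and all topic indices $z_{m,n}$, $\theta_{m,k}$ is the $k$-th entry of $\overrightarrow\theta_m$, and $\phi_{k,t}$ is the entry of $\overrightarrow\phi_k$ for term $t$:

$$
p(\underline W, \underline Z, \underline\Theta, \underline\Phi \mid \overrightarrow\alpha, \overrightarrow\beta)
= \prod_{k=1}^{K} \text{Dir}(\overrightarrow\phi_k \mid \overrightarrow\beta)
\prod_{m=1}^{M} \left[ \text{Dir}(\overrightarrow\theta_m \mid \overrightarrow\alpha)
\prod_{n=1}^{N_m} \theta_{m, z_{m,n}}\, \phi_{z_{m,n}, w_{m,n}} \right]
$$

Inference targets the posterior over $\underline Z$, $\underline\Theta$ and $\underline\Phi$ given the observed $\underline W$.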

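To run the generative model, we need to specify a vocabulary and each of the model's parameters ($K$, $M$, $\xi$, and the symmetric Dirichlet parameters $\beta$ and $\alpha$):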

In [2]:
vocabulary = ['see', 'spot', 'run']
num_terms = len(vocabulary)
num_topics = 2 # K
num_documents = 5 # M
mean_document_length = 5 # xi
term_dirichlet_parameter = 1 # beta
topic_dirichlet_parameter = 1 # alpha

The term distributions $\underline\Phi$ are a collection of samples from a Dirichlet distribution, one per topic. Each sample describes how probability is spread over our 3 terms within one of the two topics.
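
How peaked each topic is depends on the Dirichlet parameter $\beta$. A small sketch, using NumPy's `Generator.dirichlet` rather than the notebook's `scipy.stats.dirichlet`, that compares three symmetric values of $\beta$ (0.1, 1 and 10 are arbitrary choices for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
num_terms, num_topics = 3, 2
mean_max = {}

for beta in (0.1, 1.0, 10.0):
    # One draw of Phi: each row is a topic's term distribution, phi_k ~ Dir(beta, ..., beta)
    phi = rng.dirichlet(np.full(num_terms, beta), size=num_topics)
    print(f"beta={beta}:\n{np.round(phi, 2)}")
    # Summary over many topics: the average largest term probability.
    # Near 1 means sparse (peaked) topics; near 1/num_terms means near-uniform topics.
    draws = rng.dirichlet(np.full(num_terms, beta), size=10_000)
    mean_max[beta] = draws.max(axis=1).mean()
    print(f"  mean max term probability: {mean_max[beta]:.2f}")
```

With $\beta = 1$ (the value used in this notebook) every point of the simplex is equally likely, so topics range from nearly uniform to heavily skewed, as in the two topics printed below.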

In [3]:
from scipy.stats import dirichlet, poisson
from numpy import round
from collections import defaultdict
from random import choice as stl_choice
In [4]:
term_dirichlet_vector = num_terms * [term_dirichlet_parameter]
# One term distribution per topic; the second argument of dirichlet() is the random seed
term_distributions = dirichlet(term_dirichlet_vector, seed=2).rvs(size=num_topics)
print(term_distributions)
[[ 0.41  0.02  0.57]
 [ 0.38  0.36  0.26]]

Each document corresponds to a categorical distribution over these topics (in this case, a 2-dimensional categorical distribution). Because each topic is itself a distribution over terms, this categorical distribution is a distribution over distributions; we can view it as a sample from a Dirichlet process!

The base distribution of our Dirichlet process is a uniform distribution over topics (remember, topics are term distributions).

In [5]:
base_distribution = lambda: stl_choice(term_distributions)
# A sample from base_distribution is a distribution over terms
# Each of our two topics has equal probability
from collections import Counter
for topic, count in Counter([tuple(base_distribution()) for _ in range(10000)]).most_common():
    print("count:", count, "topic:", [round(prob, 2) for prob in topic])
count: 5066 topic: [0.40999999999999998, 0.02, 0.56999999999999995]
count: 4934 topic: [0.38, 0.35999999999999999, 0.26000000000000001]

Recall that a sample from a Dirichlet process is a distribution that approximates (but varies from) the base distribution. In this case, a sample from the Dirichlet process will be a distribution over topics that varies from the uniform distribution we provided as a base. In the stick-breaking metaphor, we repeatedly break pieces off a stick and label each piece with a topic drawn from the base distribution; because there are only two topics, the stick effectively ends up divided into two portions, and the size of each portion corresponds to the proportion of that topic in the document.
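
The stick-breaking construction can be written down directly: draw $b_i \sim \text{Beta}(1, \alpha)$ and give the $i$-th piece the weight $b_i \prod_{j<i}(1 - b_j)$. A truncated sketch in NumPy (the cut-off at 1,000 pieces is an assumption made only so the computation is finite; `stick_breaking_weights` is a hypothetical name), followed by labelling each piece with one of two topics as described above:

```python
import numpy as np

def stick_breaking_weights(alpha, num_pieces, rng):
    """First num_pieces stick-breaking (GEM) weights: w_i = b_i * prod_{j<i} (1 - b_j)."""
    b = rng.beta(1.0, alpha, size=num_pieces)
    # Length of stick left before the i-th break: prod_{j<i} (1 - b_j)
    left = np.concatenate(([1.0], np.cumprod(1.0 - b)[:-1]))
    return b * left

rng = np.random.default_rng(1)
weights = stick_breaking_weights(alpha=1.0, num_pieces=1000, rng=rng)

# Label each piece with one of two topics drawn uniformly (the base distribution)
# and add up the pieces per topic to get one document's topic proportions
labels = rng.integers(0, 2, size=weights.size)
topic_proportions = np.bincount(labels, weights=weights, minlength=2)
print("first pieces:", np.round(weights[:5], 3))
print("topic proportions:", np.round(topic_proportions, 3))
```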

To construct a sample from the DP, we need to define our DP class again:

In [6]:
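from scipy.stats import beta
from numpy.random import choice

class DirichletProcessSample():
    """A sample G ~ DP(alpha, base_measure), built lazily by stick-breaking."""
    def __init__(self, base_measure, alpha):
        self.base_measure = base_measure
        self.alpha = alpha

        self.cache = []               # atoms: values drawn from the base measure
        self.weights = []             # length of the stick piece assigned to each atom
        self.total_stick_used = 0.

    def __call__(self):
        # Roll a die whose sides are the existing pieces plus the unbroken remainder
        remaining = 1.0 - self.total_stick_used
        i = DirichletProcessSample.roll_die(self.weights + [remaining])
        if i is not None and i < len(self.weights):
            # Landed on an existing piece: return its cached atom
            return self.cache[i]
        else:
            # Landed on the remainder: break off a Beta(1, alpha) fraction of it
            # and attach a fresh draw from the base measure to the new piece
            stick_piece = beta(1, self.alpha).rvs() * remaining
            self.total_stick_used += stick_piece
            self.weights.append(stick_piece)
            new_value = self.base_measure()
            self.cache.append(new_value)
            return new_value

    @staticmethod
    def roll_die(weights):
        # Return index i with probability weights[i], or None if there are no weights
        if weights:
            return choice(range(len(weights)), p=weights)
        else: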
            return None
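
With the stick lengths integrated out, the sequence of values returned by repeated calls to one such sample follows the Blackwell-MacQueen urn (the Chinese restaurant process): the $(n+1)$-th call returns a fresh draw from the base measure with probability $\alpha/(n+\alpha)$, and otherwise repeats one of the previous $n$ draws chosen uniformly at random. A minimal sketch of that equivalent sequential sampler, which is a swapped-in technique rather than the class above; `crp_sampler` is a hypothetical name:

```python
import random

def crp_sampler(base_measure, alpha, rng=None):
    """Sequential draws x_1, x_2, ... from G ~ DP(alpha, base_measure) with G
    integrated out (Blackwell-MacQueen urn). Hypothetical helper."""
    rng = rng or random.Random()
    history = []

    def sample():
        n = len(history)
        if rng.random() < alpha / (n + alpha):
            value = base_measure()       # new value from the base measure
        else:
            value = rng.choice(history)  # repeat a previous draw, chosen uniformly
        history.append(value)
        return value

    return sample

rng = random.Random(0)
distinct = {}
for alpha in (1, 10, 100):
    # Continuous base measure Uniform(0, 1): every fresh draw is a new value
    draw = crp_sampler(rng.random, alpha, rng)
    samples = [draw() for _ in range(1000)]
    distinct[alpha] = len(set(samples))
    print(f"alpha={alpha}: {distinct[alpha]} distinct values in 1000 draws")
```

Small $\alpha$ keeps reusing a handful of values, while large $\alpha$ keeps drawing new ones from the base measure; this is the role `alpha` plays in `DirichletProcessSample`.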

For each document, we will draw a topic distribution from the Dirichlet process:

In [7]:
topic_distribution = DirichletProcessSample(base_measure=base_distribution, 
                                            alpha=topic_dirichlet_parameter)

A sample from this topic distribution is a distribution over terms (a topic). However, unlike our base distribution, which returns each term distribution with equal probability, the DP sample weights the topics unevenly.

In [8]:
for topic, count in Counter([tuple(topic_distribution()) for _ in range(10000)]).most_common():
    print("count:", count, "topic:", [round(prob, 2) for prob in topic])
count: 9589 topic: [0.38, 0.35999999999999999, 0.26000000000000001]
count: 411 topic: [0.40999999999999998, 0.02, 0.56999999999999995]

To generate each word in the document, we draw a topic from the document's topic distribution and then a term from that topic's term distribution.

In [9]:
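topic_index = defaultdict(list)   # document -> topic (term distribution) used for each word
documents = defaultdict(list)     # document -> generated words

for doc in range(num_documents):
    # Each document gets its own DP sample, i.e. its own topic weights
    topic_distribution_rvs = DirichletProcessSample(base_measure=base_distribution, 
                                                    alpha=topic_dirichlet_parameter)
    document_length = poisson(mean_document_length).rvs()
    for word in range(document_length):
        # Draw a topic, i.e. a distribution over terms...
        topic_distribution = topic_distribution_rvs()
        topic_index[doc].append(tuple(topic_distribution))
        # ...then draw the word from that term distribution
        documents[doc].append(choice(vocabulary, p=topic_distribution))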

Here are the documents we generated:

In [10]:
for doc in documents.values():
    print(doc)
['see', 'run', 'see', 'spot', 'see', 'spot']
['see', 'run', 'see']
['see', 'run', 'see', 'see', 'run', 'spot', 'spot']
['run', 'run', 'run', 'spot', 'run']
['run', 'run', 'see', 'spot', 'run', 'run']

We can see how each topic (term-distribution) is distributed across the documents:

In [11]:
for i, doc in enumerate(Counter(term_dist).most_common() for term_dist in topic_index.values()):
    print("Doc:", i)
    for topic, count in doc:
        print(5*" ", "count:", count, "topic:", [round(prob, 2) for prob in topic])
Doc: 0
      count: 6 topic: [0.38, 0.35999999999999999, 0.26000000000000001]
Doc: 1
      count: 3 topic: [0.40999999999999998, 0.02, 0.56999999999999995]
Doc: 2
      count: 5 topic: [0.40999999999999998, 0.02, 0.56999999999999995]
      count: 2 topic: [0.38, 0.35999999999999999, 0.26000000000000001]
Doc: 3
      count: 5 topic: [0.38, 0.35999999999999999, 0.26000000000000001]
Doc: 4
      count: 5 topic: [0.40999999999999998, 0.02, 0.56999999999999995]
      count: 1 topic: [0.38, 0.35999999999999999, 0.26000000000000001]

To recap: for each document we draw a sample from a Dirichlet process. The base distribution for the Dirichlet process is a categorical distribution over term distributions; we can think of the base distribution as a $K$-sided die where $K$ is the number of topics and each side of the die is a distribution over terms for that topic. By sampling from the Dirichlet process, we are effectively reweighting the sides of the die (changing the distribution of the topics).
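
This reweighting has a closed form. By the defining finite-partition property of the Dirichlet process, if the base distribution puts mass $1/K$ on each of $K$ topics, the weights a DP sample with concentration $\alpha$ gives to the $K$ sides are jointly $\text{Dir}(\alpha/K, \dots, \alpha/K)$. With the notebook's values ($K = 2$, $\alpha = 1$) that is $\text{Dir}(0.5, 0.5)$, which favours lopsided splits like the 9589/411 counts seen earlier. The sketch below checks the mean $1/K$ and variance $\frac{(1/K)(1 - 1/K)}{\alpha + 1}$ of one side's weight by simulating truncated stick-breaking in NumPy; the truncation at 100 pieces and the sample count are assumptions for illustration:

```python
import numpy as np

rng = np.random.default_rng(3)
K, alpha = 2, 1.0                  # die sides (topics) and DP concentration
num_samples, num_pieces = 10_000, 100

# Truncated stick-breaking: one row of weights per DP sample
b = rng.beta(1.0, alpha, size=(num_samples, num_pieces))
left = np.hstack([np.ones((num_samples, 1)), np.cumprod(1.0 - b, axis=1)[:, :-1]])
weights = b * left

# Label every stick piece with a side of the die drawn from the uniform base distribution
labels = rng.integers(0, K, size=(num_samples, num_pieces))
# Total weight each DP sample gives to side 0
pi_0 = np.where(labels == 0, weights, 0.0).sum(axis=1)

expected_mean = 1 / K
expected_var = (1 / K) * (1 - 1 / K) / (alpha + 1)
print(f"mean {pi_0.mean():.3f} vs {expected_mean:.3f}")
print(f"var  {pi_0.var():.3f} vs {expected_var:.3f}")
```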

For each word in the document, we draw a sample (a term distribution) from the distribution (over term distributions) sampled from the Dirichlet process (with a distribution over term distributions as its base measure). Each term distribution uniquely identifies the topic for the word. We can sample from this term distribution to get the word.

Given this formulation, we might ask if we can roll an infinite-sided die to draw from an unbounded number of topics (term distributions). We can do exactly this with a hierarchical Dirichlet process. Instead of the base distribution of our Dirichlet process being a finite distribution over topics (term distributions), we will make it an infinite distribution over topics (term distributions) by using yet another Dirichlet process! This base Dirichlet process will have as its base distribution a Dirichlet distribution over terms.

We will again draw a sample from a Dirichlet process for each document. The base distribution for this Dirichlet process is itself a Dirichlet process whose base distribution is a Dirichlet distribution over terms. (Try saying that five times fast.) We can think of this as a countably infinite die, where each side of the die is a distribution over terms for one topic. Each sample we draw is a topic (a distribution over terms).

As before, for each word in the document, we will draw a sample (a term distribution) from the distribution (over term distributions) sampled from the Dirichlet process (with a distribution over term distributions as its base measure). Each term distribution uniquely identifies the topic for the word. We can sample from this term distribution to get the word.

These last few paragraphs are confusing! Let's illustrate with code.

In [12]:
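term_dirichlet_vector = num_terms * [term_dirichlet_parameter]
# Continuous base measure: every call returns a brand-new topic drawn from Dir(beta)
base_distribution = lambda: dirichlet(term_dirichlet_vector).rvs(size=1)[0]

base_dp_parameter = 10
# Memoize base_distribution with a DP so that topics can be drawn more than once
base_dp = DirichletProcessSample(base_distribution, alpha=base_dp_parameter)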

This sample from the base Dirichlet process is our infinite-sided die. It is a probability distribution over a countably infinite number of topics.

The fact that our die is countably infinite is important. The sampler base_distribution draws topics (term distributions) from an uncountable set. If we used it directly as the base distribution of each document's Dirichlet process below, each document would be constructed from a completely unique set of topics. By feeding base_distribution into a Dirichlet process (a stochastic memoizer), we allow the topics to be shared across documents.
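
A small sketch of this difference, assuming NumPy and a truncated stick-breaking approximation to a sample like `base_dp` (truncated at 500 pieces, an arbitrary choice; the notebook's class grows the stick lazily instead). All names here are illustrative:

```python
import numpy as np

rng = np.random.default_rng(7)
num_terms, alpha, num_draws = 3, 10.0, 1000

# Drawing directly from Dir(1, 1, 1): every topic is new (with probability 1)
raw_topics = rng.dirichlet(np.ones(num_terms), size=num_draws)

# Truncated stick-breaking approximation to G ~ DP(alpha, Dir(1, 1, 1))
num_pieces = 500
b = rng.beta(1.0, alpha, size=num_pieces)
weights = b * np.concatenate(([1.0], np.cumprod(1.0 - b)[:-1]))
weights /= weights.sum()                                    # renormalize the truncated stick
atoms = rng.dirichlet(np.ones(num_terms), size=num_pieces)  # one topic per stick piece
dp_topics = atoms[rng.choice(num_pieces, size=num_draws, p=weights)]

num_raw = len(np.unique(raw_topics, axis=0))
num_dp = len(np.unique(dp_topics, axis=0))
print("distinct topics, raw Dirichlet draws:", num_raw)
print("distinct topics, draws from DP sample:", num_dp)
```

Every raw draw is a new topic, whereas the DP sample keeps returning the same atoms, which is what lets different documents end up using the same topics.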

In other words, base_distribution will (with probability 1) never return the same topic twice; however, every topic sampled from base_dp would be sampled an infinite number of times (if we sampled from base_dp forever). At the same time, base_dp will also return an infinite number of distinct topics. In our formulation of the LDA sampler above, our base distribution only ever returned a finite number of topics (num_topics); there is no num_topics parameter here.
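
How quickly do new topics appear? For a DP sample with concentration $\alpha$ and a continuous base measure, the expected number of distinct values among the first $n$ draws is $\sum_{i=0}^{n-1} \frac{\alpha}{\alpha + i}$, which grows without bound but only logarithmically, roughly like $\alpha \log(1 + n/\alpha)$. A quick evaluation for $\alpha = 10$ (the notebook's `base_dp_parameter`); `expected_num_topics` is a hypothetical name:

```python
import math

def expected_num_topics(alpha, n):
    """Expected number of distinct values among the first n draws from
    G ~ DP(alpha, H) with continuous H: sum_{i=0}^{n-1} alpha / (alpha + i)."""
    return sum(alpha / (alpha + i) for i in range(n))

alpha = 10  # same value as base_dp_parameter
for n in (10, 100, 1000, 10_000, 100_000):
    print(f"n={n:>6}: expected topics = {expected_num_topics(alpha, n):6.1f}, "
          f"alpha*log(1 + n/alpha) = {alpha * math.log(1 + n / alpha):6.1f}")
```

Because a document's DP only calls `base_dp` when it breaks off a new piece of its own stick, $n$ here counts those calls rather than words.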

Given this setup, we can generate documents from the hierarchical Dirichlet process with an algorithm that is essentially identical to that of the original latent Dirichlet allocation generative sampler:

In [13]:
nested_dp_parameter = 10

topic_index = defaultdict(list)
documents = defaultdict(list)

for doc in range(num_documents):
    # Each document's DP uses the shared base_dp as its base measure,
    # so a topic created by base_dp can reappear in several documents
    topic_distribution_rvs = DirichletProcessSample(base_measure=base_dp, 
                                                    alpha=nested_dp_parameter)
    document_length = poisson(mean_document_length).rvs()
    for word in range(document_length):
        # Draw a topic (term distribution), then a word from it
        topic_distribution = topic_distribution_rvs()
        topic_index[doc].append(tuple(topic_distribution))
        documents[doc].append(choice(vocabulary, p=topic_distribution))

Here are the documents we generated:

In [14]:
for doc in documents.values():
    print(doc)
['spot', 'spot', 'spot', 'spot', 'run']
['spot', 'spot', 'see', 'spot']
['spot', 'spot', 'spot', 'see', 'spot', 'spot', 'spot']
['run', 'run', 'spot', 'spot', 'spot', 'spot', 'spot', 'spot']
['see', 'run', 'see', 'run', 'run', 'run']

And here are the latent topics used:

In [15]:
for i, doc in enumerate(Counter(term_dist).most_common() for term_dist in topic_index.values()):
    print("Doc:", i)
    for topic, count in doc:
        print(5*" ", "count:", count, "topic:", [round(prob, 2) for prob in topic])
Doc: 0
      count: 2 topic: [0.17999999999999999, 0.79000000000000004, 0.02]
      count: 1 topic: [0.23000000000000001, 0.58999999999999997, 0.17999999999999999]
      count: 1 topic: [0.089999999999999997, 0.54000000000000004, 0.35999999999999999]
      count: 1 topic: [0.22, 0.40000000000000002, 0.38]
Doc: 1
      count: 2 topic: [0.23000000000000001, 0.58999999999999997, 0.17999999999999999]
      count: 1 topic: [0.17999999999999999, 0.79000000000000004, 0.02]
      count: 1 topic: [0.35999999999999999, 0.55000000000000004, 0.089999999999999997]
Doc: 2
      count: 4 topic: [0.11, 0.65000000000000002, 0.23999999999999999]
      count: 2 topic: [0.070000000000000007, 0.65000000000000002, 0.27000000000000002]
      count: 1 topic: [0.28999999999999998, 0.65000000000000002, 0.070000000000000007]
Doc: 3
      count: 2 topic: [0.17999999999999999, 0.79000000000000004, 0.02]
      count: 2 topic: [0.25, 0.55000000000000004, 0.20000000000000001]
      count: 2 topic: [0.28999999999999998, 0.65000000000000002, 0.070000000000000007]
      count: 1 topic: [0.23000000000000001, 0.58999999999999997, 0.17999999999999999]
      count: 1 topic: [0.089999999999999997, 0.54000000000000004, 0.35999999999999999]
Doc: 4
      count: 3 topic: [0.40000000000000002, 0.23000000000000001, 0.37]
      count: 2 topic: [0.42999999999999999, 0.17999999999999999, 0.40000000000000002]
      count: 1 topic: [0.23000000000000001, 0.29999999999999999, 0.46000000000000002]

Our documents were generated by an unspecified number of topics, and yet the topics were shared across the 5 documents. This is the power of the hierarchical Dirichlet process!

This non-parametric formulation of Latent Dirichlet Allocation was first published by Yee Whye Teh et al.

Unfortunately, forward sampling is the easy part. Fitting the model to data requires considerably more complex MCMC or variational inference. There are only a limited number of implementations of HDP-LDA available, and none of them are great.