From MCMC to Variational Inference

In which this became an exercise in deriving the closed-form Gaussian KL expression from Auto-Encoding Variational Bayes!

Almost all observed data is the result of some process with hidden latent factors. Bayesian analysis provides a recipe for learning about the unknown latent variables \(z\) from the observed data \(x\):

1. Specify a prior \(p(z)\) quantifying what is known about \(z\) before any data is observed
2. Learn a likelihood function \(p(x \mid z)\), or decoder
3. Apply Bayes’ rule \(p(z|x) = \frac{p(x|z)p(z)}{\int_z p(x|z)p(z) dz}\) to learn the posterior distribution, which describes what is known about \(z\) after observing the data \(x\)

The issue is that computing \(p(z \mid x)\) is generally infeasible: the marginalization in the denominator becomes computationally intractable when the latent variables are high dimensional. In other words, uncertainty is expensive, and approximate inference methods are all ways of getting at \(p(x)\), and with it the posterior, without integrating over all configurations of the latents.

Markov Chain Monte Carlo (MCMC) methods and Variational Bayesian methods differ in whether they explicitly model \(p(z \mid x)\), the encoder or recognition model.

The MCMC paradigm is to sample \(z_0 \sim p(z)\) and repeatedly apply a transition operator \(q(z_t \mid z_{t-1}, x)\) so that the distribution of \(z_T\) converges to the posterior \(p(z \mid x)\). The VAE methodology instead parameterizes an approximation to \(p(z \mid x)\) and learns it directly.

One particular version of MCMC is the Metropolis-Hastings algorithm, which can be seen as a random walk whose samples gradually come to be distributed according to \(p(z \mid x)\). The initial sample is drawn from a standard Gaussian, and the particle moves to a proposed state (generated by adding noise to the current sample) with an acceptance probability determined by the ratio of the proposed state's probability to the current one's: moves to higher-probability states are always accepted, while moves to lower-probability states are accepted only some of the time. In the implementation for this blog post, each sample is decoded and the reconstructed data is compared to the actual data \(x\). After a given number of steps, the binary cross entropy loss between the current reconstruction and the data point is backpropagated to train the decoder.
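
A minimal sketch of that loop, assuming a PyTorch `decoder` that maps a latent vector to Bernoulli pixel probabilities (the function name, step count, and step size are illustrative, not the exact code from this post):

```python
import torch
import torch.nn.functional as F

def mh_reconstruction_loss(decoder, x, latent_dim, n_steps=50, step_size=0.1):
    """Random-walk Metropolis-Hastings over z, scored against the decoded
    reconstruction of x; returns a BCE loss to backpropagate into the decoder."""
    def log_score(z):
        # unnormalized log p(z|x) = log p(x|z) + log p(z) + const
        recon = decoder(z)
        log_lik = -F.binary_cross_entropy(recon, x, reduction="sum")
        log_prior = -0.5 * (z ** 2).sum()
        return log_lik + log_prior

    with torch.no_grad():                                   # the walk itself needs no gradients
        z = torch.randn(latent_dim)                         # z_0 ~ N(0, I)
        current = log_score(z)
        for _ in range(n_steps):
            proposal = z + step_size * torch.randn_like(z)  # same transition kernel every step
            candidate = log_score(proposal)
            # accept with probability min(1, exp(candidate - current))
            if torch.rand(()) < (candidate - current).exp().clamp(max=1.0):
                z, current = proposal, candidate

    # after the final step, train the decoder on the reconstruction of x
    return F.binary_cross_entropy(decoder(z), x, reduction="sum")
```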

The major inefficiency in Metropolis-Hastings is the random walk, which uses the same transition kernel throughout the algorithm. One wonders, after seeing some training data, isn’t there a more efficient way of approximating the latent posterior?

Among other methods which use information from training to guide sampling, variational inference does so by explicitly learning a parameterization of the encoder \(q(z \mid x)\), also called a recognition model. Assuming the encoder is a Gaussian, we can model it with a neural network which outputs \(\mu, \sigma^2\) given \(x\) as input.
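
A concrete sketch of such a recognition model in PyTorch (the layer sizes are arbitrary choices, and the network predicts \(\log \sigma^2\) rather than \(\sigma^2\) so the variance stays positive):

```python
import torch.nn as nn

class GaussianEncoder(nn.Module):
    """q(z|x) = N(z; mu(x), diag(sigma^2(x))) with a shared hidden layer."""
    def __init__(self, x_dim=784, hidden_dim=400, z_dim=20):
        super().__init__()
        self.hidden = nn.Sequential(nn.Linear(x_dim, hidden_dim), nn.ReLU())
        self.mu = nn.Linear(hidden_dim, z_dim)
        self.logvar = nn.Linear(hidden_dim, z_dim)  # log sigma^2, unconstrained output

    def forward(self, x):
        h = self.hidden(x)
        return self.mu(h), self.logvar(h)
```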

Starting with the expression for the KL divergence between the learned model \(q_{\phi}(z \mid x)\) and the true posterior \(p(z \mid x)\), we can derive the evidence lower bound (ELBO), a lower bound on \(\log p(x)\). Maximizing this bound yields an estimate of \(\log p(x)\), converting the inference problem into an optimization problem.

The derivation is

\[\begin{align*} D_{KL}(q_{\phi}(z|x) \| p(z | x)) &= \int_z q_{\phi}(z|x) \log \frac{q_{\phi}(z|x)}{p(z|x)} dz \\ &= - \int_z q_{\phi}(z|x) \log \frac{p(z|x)}{q_{\phi}(z|x)} dz \\ &= - \int_z q_{\phi}(z|x) \log \frac{p(z,x)}{q_{\phi}(z|x)p(x)} dz \\ &= - \left(\int_z q_{\phi}(z|x) \log \frac{p(z,x)}{q_{\phi}(z|x)} dz - \int_z q_{\phi}(z|x) \log p(x) dz \right) \\ &= - \int_z q_{\phi}(z|x) \log \frac{p(z,x)}{q_{\phi}(z|x)} dz + \log p(x) \end{align*}\]

So

\[\log p(x) = \mathcal{L} + D_{KL}(q_{\phi}(z|x) \| p(z | x))\]

Since \(D_{KL}\) is non-negative, \(\mathcal{L}\) is a lower bound on \(\log p(x)\).

We can then write \(\mathcal{L}\), the ELBO, as

\[\begin{align*} \mathcal{L} &= \int_z q_{\phi}(z|x) \log \frac{p_{\theta}(z,x)}{q_{\phi}(z|x)} dz \\ &= \int_z q_{\phi}(z|x) \log \frac{p_{\theta}(x|z) p(z)}{q_{\phi}(z|x)} dz\\ &= \mathbb{E}_{z \sim q_{\phi}(z|x)}[\log p_{\theta}(x|z)] - D_{KL} (q_{\phi}(z|x) \| p(z)) \end{align*}\]
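
In the AEVB setup this objective is typically estimated with a single reparameterized sample of \(z\) for the expectation and, for a Gaussian encoder and a standard-normal prior, the KL term in closed form (derived in the next section). A sketch, assuming an encoder like the one above and a decoder that outputs Bernoulli probabilities:

```python
import torch
import torch.nn.functional as F

def negative_elbo(encoder, decoder, x):
    """-ELBO for one data point (or batch): reconstruction term plus KL term."""
    mu, logvar = encoder(x)
    std = torch.exp(0.5 * logvar)
    z = mu + std * torch.randn_like(std)  # one reparameterized sample from q(z|x)
    recon = decoder(z)
    # E_q[log p(x|z)] estimated with that sample; Bernoulli likelihood -> BCE
    recon_loss = F.binary_cross_entropy(recon, x, reduction="sum")
    # closed-form KL(q(z|x) || N(0, I)), derived below
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl
```

Minimizing this quantity is the same as maximizing \(\mathcal{L}\).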

Analytic Integral of the KL Divergence of Two Gaussians

The ELBO contains a \(- D_{KL} (q_{\phi}(z|x) \| p(z))\) term. We can evaluate this expression analytically with a combination of algebra, the fact that a probability density integrates to one, and the trace trick for expectations of quadratic forms.

First, let’s be explicit about the expressions for the two PDFs:

\[q(z) = \mathcal{N}(z; \mu, \Sigma) = \frac{1}{\sqrt{(2\pi)^J |\Sigma|}} \exp\left(-\frac{1}{2} (z - \mu)^T \Sigma^{-1} (z - \mu)\right), \quad \Sigma = \text{diag}(\sigma_1^2, \dots, \sigma_J^2)\]
\[p(z) = \mathcal{N}(z; 0, I) = \frac{1}{\sqrt{(2\pi)^J}} \exp\left(-\frac{1}{2} z^T z\right)\]

Multivariate Gaussian Facts

1. \(\text{Cov}(z) = \mathbb{E}_{z}[(z - \mu)(z - \mu)^T] = \Sigma\) where \(\Sigma\) is the covariance matrix
2. \(\mathbb{E}[z] = \mu\)

Trace Trick for Expectations of Quadratic Forms

Let \((z-\mu)^T A (z - \mu)\) be the quadratic form.

1. A quadratic form is a scalar, so it is its own trace: \(\mathbb{E}[(z - \mu)^T A (z - \mu)] = \mathbb{E}[\text{tr}((z - \mu)^T A (z - \mu))]\)

2. Cyclic property of trace means \(\text{tr}(ABC) = \text{tr}(BCA) = \text{tr}(CAB)\)

3. Linearity of expectation lets the expectation move inside the trace: \(\mathbb{E}[\text{tr}(X)] = \text{tr}(\mathbb{E}[X])\) (sanity-checked numerically below)
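
These facts are easy to sanity-check. A quick Monte Carlo estimate (the dimension and sample count below are arbitrary) confirms that \(\mathbb{E}[(z - \mu)^T A (z - \mu)] = \text{tr}(A \Sigma)\):

```python
import numpy as np

rng = np.random.default_rng(0)
J = 3
mu = rng.normal(size=J)
sigma2 = rng.uniform(0.5, 2.0, size=J)                    # diagonal covariance entries
A = rng.normal(size=(J, J))

# samples from N(mu, diag(sigma2))
z = mu + np.sqrt(sigma2) * rng.normal(size=(200_000, J))
d = z - mu
quad = np.einsum("ni,ij,nj->n", d, A, d)                  # (z - mu)^T A (z - mu) per sample

print(quad.mean())           # Monte Carlo estimate of the expectation
print(np.trace(A * sigma2))  # tr(A @ diag(sigma2)) = tr(A Sigma)
```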

The Derivation of Closed Form KL Divergence of Two Gaussians

The overall structure was:
1. Notice that \(-D_{KL} (q(z|x) \| p(z)) = \int q(z)(\log p(z) - \log q(z)) dz\), which decomposes into \(\int q(z) \log p(z) dz\) and \(\int q(z) \log q(z) dz\)
2. Compute \(\int q(z) \log p(z) dz\) and \(\int q(z) \log q(z) dz\) separately and combine them

To calculate \(\int q(z) \log q(z) dz\):
1. Simplify \(\log q(z)\)
2. Distribute \(\int q(z)\)

First, write the \(\log\) of \(q(z)\):

\[\log q(z) = \log \left(\frac{1}{\sqrt{(2\pi)^J \prod_{j=1}^J \sigma_j^2}}\right) - \frac{1}{2}(z - \mu)^T \Sigma^{-1}(z - \mu)\]

Now, distribute \(\int q(z)\):

\[\begin{align*} \int_{z} q(z) \log q(z) dz &= \log \left(\frac{1}{\sqrt{(2\pi)^J \prod_{j=1}^J \sigma_j^2}}\right) \int_{z} q(z) dz - \frac{1}{2} \mathbb{E}_{z} [(z - \mu)^T \Sigma^{-1} (z - \mu)] \\ &= - \log \left((2\pi)^{J/2}\left(\prod_{j=1}^J \sigma_{j}^2\right)^{1/2}\right) - \frac{1}{2} \mathbb{E}_{z} [\text{tr}(\Sigma^{-1} (z - \mu) (z - \mu)^T)] \\ &= - \frac{J}{2} \log(2\pi) - \frac{1}{2} \sum_{j=1}^{J} \log \sigma_{j}^2 - \frac{1}{2} \text{tr}(\Sigma^{-1} \mathbb{E}_z[(z - \mu)(z - \mu)^T]) \\ &= - \frac{J}{2} \log(2\pi) - \frac{1}{2} \sum_{j=1}^{J} \log \sigma_{j}^2 - \frac{1}{2} \text{tr}(\Sigma^{-1} \Sigma) \\ &= - \frac{J}{2} \log(2\pi) - \frac{1}{2} \sum_{j=1}^{J} \log \sigma_{j}^2 - \frac{J}{2} \\ &= -\frac{J}{2} \log(2\pi) - \frac{1}{2} \sum_{j=1}^J (1 + \log \sigma_{j}^2) \end{align*}\]

The computation of \(\int q(z) \log p(z) dz\) follows a similar approach!
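
Carrying that computation out, with \(p(z) = \mathcal{N}(0, I)\) so that \(\mathbb{E}_{z}[z^T z] = \sum_{j=1}^J (\mu_j^2 + \sigma_j^2)\), and combining the two pieces recovers the closed-form expression from Appendix B of Auto-Encoding Variational Bayes:

\[\begin{align*} \int_z q(z) \log p(z) dz &= -\frac{J}{2} \log(2\pi) - \frac{1}{2} \sum_{j=1}^J (\mu_j^2 + \sigma_j^2) \\ -D_{KL}(q_{\phi}(z|x) \| p(z)) &= \int_z q(z) \log p(z) dz - \int_z q(z) \log q(z) dz \\ &= \frac{1}{2} \sum_{j=1}^J \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right) \end{align*}\]

Its negation is exactly the `kl` term used in the ELBO sketch above.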