Auto-Encoding Variational Bayes

Diederik P Kingma, Max Welling

ICLR 2014 | arXiv

Reminder: Bayes' theorem

\[ P(H|D) = \frac{P(D|H)P(H)}{P(D)}\] \(P(H)\), ==the prior==, is the initial degree of belief in \(H\). \(P(H|D)\), the ==posterior==, is the degree of belief in the hypothesis \(H\) having accounted for \(D\). The quotient \(P(D|H)/P(D)\) represents the support \(D\) provides for \(H\).
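A quick numerical sanity check (numbers chosen arbitrarily for illustration): with \(P(H) = 0.01\), \(P(D|H) = 0.9\), and \(P(D) = 0.1\),

\[ P(H|D) = \frac{0.9 \times 0.01}{0.1} = 0.09 \]

so observing \(D\) raises the degree of belief in \(H\) from 1% to 9% (the support quotient is \(0.9/0.1 = 9\)).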

Objective and motivation: scalable inference and learning in directed probabilistic graphical models in the presence of continuous latent variables \(z\) with an intractable posterior.

Contributions:

  1. A reparameterization of the variational lower bound into an unbiased and differentiable estimator, \(\equiv\) Stochastic Gradient Variational Bayes (SGVB).
  2. Posterior inference with an approximate model fitted to the intractable posterior using the lower bound (in the case of an i.i.d. dataset with a continuous latent variable per data point), \(\equiv\) Auto-Encoding VB (AEVB).

Typical graphical models:

Figure: solid lines denote the generative model \(p\_\theta(z)p\_\theta(x|z)\), dashed lines the variational approximation \(q\_\phi(z|x)\) to the intractable posterior \(p\_\theta(z|x)\). The recognition model \(q\_\phi(z|x)\) is a probabilistic encoder and \(p\_\theta(x|z)\) is a probabilistic decoder.

#### Setting:

Given \(N\) i.i.d. samples \(\mathcal X = \\{x\_i\\}\_{1\leq i \leq N}\) of some continuous or discrete variable \(x\), we assume the data is generated by some random process involving an unobserved, continuous latent variable \(z\). The process consists of two steps:

  1. A value \(z\_i\) is sampled from the prior \(p\_\theta(z)\).
  2. A value \(x\_i\) is sampled from the conditional likelihood \(p\_\theta(x|z\_i)\).

Both the prior and the likelihood come from parametric families of distributions indexed by \(\theta \in \Theta\), with PDFs that are differentiable a.s. w.r.t. both \(\theta\) and \(z\).
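A minimal sketch of this two-step generative process on a toy example (assuming a standard Gaussian prior and a Gaussian likelihood whose mean is produced by a tiny decoder; all shapes, parameters, and function names here are illustrative, not from the paper):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy "decoder": maps a latent z to the mean of a Gaussian over x.
# (Illustrative choice; in the paper p_theta(x|z) is parameterized by an MLP.)
J, D = 2, 5                      # latent dimension, data dimension
W = rng.normal(size=(D, J))      # decoder parameters (part of theta)
b = rng.normal(size=D)

def sample_x(n):
    # Step 1: z_i ~ p_theta(z) = N(0, I)
    z = rng.normal(size=(n, J))
    # Step 2: x_i ~ p_theta(x|z_i), here N(decoder(z_i), 0.1^2 I)
    mean = np.tanh(z @ W.T + b)
    return mean + 0.1 * rng.normal(size=(n, D))

X = sample_x(1000)               # a synthetic i.i.d. dataset with N = 1000
print(X.shape)                   # (1000, 5)
```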

The variational lower bound:

The marginal log-likelihood (ML) of the dataset sums over the i.i.d. samples: \[\log p_\theta(\mathcal X) = \sum_{i=1}^N \log p_\theta(x_i)\] with, for a single data point (Bayes' rule, valid for any \(z\)):

\[\log p_\theta(x) = \log(p_\theta(x|z)) + \log(p_\theta(z)) - \log(p_\theta(z|x))\]

Given that:

\[D_{KL}(q_\phi(z|x)||p_\theta(z|x)) = E_{q_\phi(z|x)}\left[\log(q_\phi(z|x)) - \log(p_\theta(z|x))\right]\] \[D_{KL}(q_\phi(z|x)||p_\theta(z)) = E_{q_\phi(z|x)}\left[\log(q_\phi(z|x)) - \log(p_\theta(z))\right]\]

We can rewrite the ML of a sample \(x\) as:

\[\log p_\theta(x) = D_{KL}(q_\phi(z|x)||p_\theta(z|x)) - D_{KL}(q_\phi(z|x)||p_\theta(z)) + E_{q_\phi(z|x)}\left[\log(p_\theta(x|z))\right] \]
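This follows by taking the expectation of the Bayes-rule identity above under \(q\_\phi(z|x)\) (the left-hand side does not depend on \(z\)) and adding and subtracting \(\log q\_\phi(z|x)\) inside the expectation:

\[\begin{align} \log p_\theta(x) & = E_{q_\phi(z|x)}\left[\log(p_\theta(x|z)) + \log(p_\theta(z)) - \log(p_\theta(z|x))\right]\\ & = E_{q_\phi(z|x)}\left[\log(p_\theta(x|z))\right] - E_{q_\phi(z|x)}\left[\log(q_\phi(z|x)) - \log(p_\theta(z))\right] + E_{q_\phi(z|x)}\left[\log(q_\phi(z|x)) - \log(p_\theta(z|x))\right] \end{align}\]

where the last two expectations are exactly the two KL divergences defined above.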

Since the KL divergence is nonnegative, we define the variational lower bound \(\mathcal L(\theta, \phi;x)\) as:

\[\mathcal L(\theta, \phi;x) = - D_{KL}(q_\phi(z|x)||p_\theta(z)) + E_{q_\phi(z|x)}\left[\log(p_\theta(x|z))\right] \]

The gradient of this variational lower bound, as is, w.r.t. the recognition/variational parameters \(\phi\) is problematic: the usual (score-function) Monte Carlo estimator exhibits very high variance.

SGVB:

For a chosen approximate posterior \(q\_\phi(z|x)\), we can reparameterize the latent variable \(z\) using a differentiable transformation \(\tilde z = g\_\phi(\epsilon, x)\), where \(\epsilon \sim p(\epsilon)\) is an auxiliary noise variable. The MC estimate of the expectation of some function \(f(z)\) w.r.t. \(q\_\phi(z|x)\) becomes:

\[\begin{align} E_{q_\phi(z|x)}\left[f(z)\right] & = E_{p(\epsilon)}\left[f(g_\phi(\epsilon, x))\right]\\ & \approx \frac{1}{L_{MC}} \sum_l f(g_\phi(\epsilon_l, x)),\phantom{abcd} \epsilon_l \sim p(\epsilon) \end{align}\]
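A small numerical illustration of why this matters for gradients (a toy example, not from the paper: \(q\_\phi = \mathcal N(\mu, \sigma^2)\) in one dimension and \(f(z)=z^2\), so the true gradient \(\partial\_\mu E\_{q\_\phi}[f(z)] = \partial\_\mu(\mu^2+\sigma^2) = 2\mu\) is known in closed form). It compares the score-function estimator \(f(z)\,\partial\_\mu \log q\_\phi(z)\) with the reparameterized estimator \(f'(\mu+\sigma\epsilon)\):

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 1.5, 1.0             # variational parameters phi = (mu, sigma)
f  = lambda z: z**2              # toy integrand; E_q[f(z)] = mu^2 + sigma^2
df = lambda z: 2 * z             # its derivative, used by the reparameterized estimator

n = 100_000
eps = rng.normal(size=n)         # auxiliary noise, eps ~ p(eps) = N(0, 1)
z = mu + sigma * eps             # z = g_phi(eps) ~ q_phi(z)

# Score-function (log-derivative) estimator of d/dmu E_q[f(z)]:
#   f(z) * d/dmu log q_phi(z) = f(z) * (z - mu) / sigma^2
score_grads = f(z) * (z - mu) / sigma**2

# Reparameterized estimator: d/dmu f(g_phi(eps)) = f'(z)
reparam_grads = df(z)

print("true gradient      :", 2 * mu)   # d/dmu (mu^2 + sigma^2) = 3.0
print("score-function     : mean %.3f, std %.3f" % (score_grads.mean(), score_grads.std()))
print("reparameterization : mean %.3f, std %.3f" % (reparam_grads.mean(), reparam_grads.std()))
```

Both estimators are unbiased, but the per-sample standard deviation of the score-function estimator is several times larger here, which is the variance problem the reparameterization avoids.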

Applied to the lower bound \(\mathcal L\), this yields the SGVB estimator, denoted \(\tilde{\mathcal L}\):

\[\tilde{\mathcal L}(\theta, \phi;x) = - D_{KL}(q_\phi(z|x)||p_\theta(z)) + \frac{1}{L_{MC}} \sum_l \log(p_\theta(x|z_l)) \]

where \(z\_l = g\_\phi(\epsilon\_l, x)\) and \(\epsilon\_l\sim p(\epsilon)\). The KL divergence can be interpreted as a regularization term on \(\phi\), encouraging the approximate posterior to stay close to the prior.

This term can be integrated analytically (in most cases) and only the reconstruction error needs MC sampling.

E.g. in the case where both the prior and the approximate posterior are Gaussian, \(p\_\theta(z) = \mathcal N(0,I)\) and \(q\_\phi(z|x) = \mathcal N(\mu, \mathrm{diag}(\sigma^2))\), with \(J\) the dimension of \(z\):

\[ - D_{KL}(q_\phi(z|x)||p_\theta(z)) = \frac{1}{2} \sum_{j=1}^J (1 + \log(\sigma_j^2) - \mu_j^2 - \sigma_j^2) \]
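A quick numerical check of this closed form against a plain MC estimate of the same KL (toy dimensions and parameter values, chosen only for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
J = 4                                     # latent dimension
mu = rng.normal(size=J)                   # encoder output mu for one data point x
log_var = rng.normal(scale=0.5, size=J)   # encoder output log(sigma^2)
sigma2 = np.exp(log_var)

# Closed form: -D_KL(q || p) = 1/2 * sum_j (1 + log sigma_j^2 - mu_j^2 - sigma_j^2)
neg_kl_analytic = 0.5 * np.sum(1 + log_var - mu**2 - sigma2)

# MC estimate of -D_KL = E_q[log p(z) - log q(z)] with z ~ q = N(mu, diag(sigma^2))
n = 200_000
z = mu + np.sqrt(sigma2) * rng.normal(size=(n, J))
log_p = -0.5 * np.sum(z**2 + np.log(2 * np.pi), axis=1)
log_q = -0.5 * np.sum((z - mu)**2 / sigma2 + np.log(2 * np.pi * sigma2), axis=1)
neg_kl_mc = np.mean(log_p - log_q)

print(neg_kl_analytic, neg_kl_mc)         # the two estimates should agree closely
```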

Note: For sufficiently large minibatch sizes in SGD (the paper uses \(M \approx 100\)), the number of MC samples per data point can be set to \(L\_{MC} = 1\).

AEVB:

Use SGD or Adagrad to optimize the variational bound \(\tilde{\mathcal L}\) jointly w.r.t. the generative parameters \(\theta\) and the variational parameters \(\phi\), estimated on minibatches.
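A minimal end-to-end sketch of AEVB, assuming a Gaussian encoder \(q\_\phi(z|x)\), a standard Gaussian prior, a Bernoulli decoder, \(L\_{MC}=1\), and the analytic KL term above; the architecture, layer sizes, and synthetic data are illustrative choices, not the paper's exact setup:

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
D, H, J = 20, 64, 4                        # data dim, hidden dim, latent dim

encoder = nn.Sequential(nn.Linear(D, H), nn.Tanh(), nn.Linear(H, 2 * J))  # -> (mu, log sigma^2)
decoder = nn.Sequential(nn.Linear(J, H), nn.Tanh(), nn.Linear(H, D))      # -> Bernoulli logits

def elbo(x):
    mu, log_var = encoder(x).chunk(2, dim=-1)
    # Reparameterization: z = mu + sigma * eps, eps ~ N(0, I)  (L_MC = 1 sample)
    z = mu + torch.exp(0.5 * log_var) * torch.randn_like(mu)
    logits = decoder(z)
    # Reconstruction term E_q[log p_theta(x|z)] for a Bernoulli decoder
    rec = -F.binary_cross_entropy_with_logits(logits, x, reduction="none").sum(-1)
    # Analytic -D_KL(q_phi(z|x) || N(0, I)), the closed form above
    neg_kl = 0.5 * (1 + log_var - mu**2 - log_var.exp()).sum(-1)
    return (rec + neg_kl).mean()

# Synthetic binary data, only to make the sketch runnable end to end
X = (torch.rand(512, D) < 0.3).float()

opt = torch.optim.Adagrad(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-2)
for step in range(200):
    idx = torch.randint(0, X.shape[0], (100,))   # minibatch of M = 100
    loss = -elbo(X[idx])                         # maximize the ELBO = minimize its negative
    opt.zero_grad(); loss.backward(); opt.step()
    if step % 50 == 0:
        print(step, float(-loss))                # minibatch estimate of the lower bound
```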