Reductions in representation learning with rate-distortion theory

In lab meeting this week, we discussed unsupervised learning in the context of deep generative models, namely \beta-variational auto-encoders (\beta-VAEs), drawing from the original paper (Higgins et al. 2017, ICLR) and its follow-up (Burgess et al. 2018). The classic VAE represents a clever approach to learning highly expressive generative models, defined by a deep neural network that transforms samples from a standard normal distribution to some distribution of interest (e.g., natural images). Technically, VAE training seeks to maximize a lower bound on the likelihood p_\theta(x) = \int p_\theta(x\mid z) p(z) dz, where p_\theta(x\mid z) defines the generative mapping from latents z to data x. This “evidence lower bound” (ELBO) depends on a variational approximation to the posterior, q_\phi(z\mid x), which is also parametrized by a deep neural network (the so-called “encoder”).
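
Written out, that bound is

\log p_\theta(x)\;\ge\;\mathbb{E}_{q_\phi(z\mid x)}\left[\log p_\theta(x\mid z)\right]-D_{KL}\left(q_\phi(z\mid x)\,\Vert\, p(z)\right)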

A crucial drawback to the classic VAE, however, is that the learned latent representations tend to lack interpretability. The \beta-VAE seeks to overcome this limitation by learning “disentangled” representations, in which single latents are sensitive to single generative factors in the data and relatively invariant to others (Bengio et al. 2013). I would call these “intuitively robust” — rotating an apple (orientation) shouldn’t make its latent representation any less red (color) or any less fruity (type). To encourage such representations, \beta-VAEs optimize a modified ELBO:

\underset{\theta,\phi}{\text{maximize}}\:\:\mathbb{E}_{q_\phi(z\mid x)}\left[\log p_\theta(x\mid z)\right]-\beta D_{KL}(q_\phi(z\mid x)\Vert\, p(z))

with standard VAEs corresponding to \beta=1. The new hyperparameter \beta controls the tension in this optimization between maximizing the data likelihood and limiting the expressiveness of the variational posterior relative to a fixed latent prior p(z)=\mathcal{N}(0,I).
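
To make the objective concrete, here is a minimal numerical sketch, assuming a diagonal-Gaussian encoder, a Bernoulli decoder, and a single reparameterized sample for the reconstruction term; the function name and numpy interface are my own choices, not anything from the papers.

```python
import numpy as np

def beta_vae_objective(x, x_recon, mu, log_var, beta=4.0):
    """Batch-averaged beta-VAE objective (to be maximized).

    Assumes a diagonal-Gaussian encoder q_phi(z|x) = N(mu, diag(exp(log_var))),
    a Bernoulli decoder whose means are x_recon, and the prior p(z) = N(0, I).
    Returns (objective, reconstruction, kl).
    """
    eps = 1e-7  # numerical safety inside the logs
    # E_q[log p_theta(x|z)], estimated with the single decoded sample x_recon.
    reconstruction = np.sum(
        x * np.log(x_recon + eps) + (1.0 - x) * np.log(1.0 - x_recon + eps),
        axis=-1,
    ).mean()
    # Analytic D_KL(q_phi(z|x) || N(0, I)) for a diagonal Gaussian posterior.
    kl = (0.5 * np.sum(np.exp(log_var) + mu**2 - 1.0 - log_var, axis=-1)).mean()
    # beta = 1 recovers the standard ELBO; beta > 1 penalizes the KL term more heavily.
    return reconstruction - beta * kl, reconstruction, kl
```

The second and third return values are worth keeping separate; they reappear below as the coordinates of the rate-distortion picture.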

Recent work has explored shaping the latent representations of deep generative models (Adversarial Autoencoders (Makhzani et al. 2016), InfoGANs (Chen et al. 2016), Total Correlation VAEs (Chen et al. 2018), among others), but the generalization used by \beta-VAEs in particular looked oddly familiar to me. This is because \beta-VAEs recapitulate the classical rate-distortion theory problem. This was also observed briefly in recent work by Alemi et al. 2018, but I would like to elaborate and show explicitly how \beta-VAEs are reducible to a distortion-rate minimization using deep generative models.

Rate-distortion theory is the classical framework for lossy data compression through a capacity-limited channel. This fundamental problem in information theory balances the amount of information (in bits) transmitted across the channel, the “rate”, against the corruption of the original signal, a penalty measured by a “distortion” function d(x,z). Our terminology changes, but the fundamental problem is the same; I made that comparison as obvious as possible in the figure below.
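
In the notation of the derivation below (with z playing the role of the compressed code), the rate-distortion function makes this trade-off explicit:

R(D)=\underset{q(z\mid x):\ \mathbb{E}[d(x,z)]\le D}{\text{min}}\:\:I(x,z)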

Derivation. Given a dataset \mathcal{D} with distribution p^*(x), define any statistical mapping q_\phi(z\mid x) that encodes x into a code z. Note that q_\phi is just an encoder; together with p^*(x) it induces a joint distribution p(x,z)=q_\phi(z\mid x)p^*(x) with marginal p(z)=\int dx\, q_\phi(z\mid x)p^*(x). The distortion-rate optimization minimizes the expected distortion d(\cdot,\cdot) subject to a maximum rate R, i.e.

\underset{q_\phi(z\mid x),p(z)}{\text{minimize}}\:\:\mathbb{E}_{p(x,z)}[d(x,z)]\:\:\text{subject to}\:\: I(x,z)\le R

\Longrightarrow \underset{q_\phi(z\mid x),p(z)}{\text{minimize}}\:\:\mathbb{E}_{p(x,z)}[d(x,z)]+\beta\, I(x,z)

where \beta\ge 0 plays the role of a Lagrange multiplier enforcing the rate constraint. Consider first the mutual information; we can replace it with a more tractable upper bound:

I(x,z)=\int dx\,p^*(x)\int dz\,q_\phi(z\mid x)\log\frac{q_\phi(z\mid x)}{p(z)}\le \int dx\,p^*(x)\int dz\,q_\phi(z\mid x)\log\frac{q_\phi(z\mid x)}{m(z)}

\text{since}\:\: D_{KL}\left(p(z)\Vert\, m(z)\right)\ge 0 \Longrightarrow -\int dz\,p(z)\log p(z) \le -\int dz\,p(z)\log m(z)
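
In fact, the slack in this bound is exactly the mismatch between the induced marginal p(z) and our replacement m(z):

\mathbb{E}_{x\sim\mathcal{D}}\left[D_{KL}(q_\phi(z\mid x)\Vert\, m(z))\right]=I(x,z)+D_{KL}\left(p(z)\Vert\, m(z)\right)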

We’ve replaced the marginal p(z) induced by our choice of encoder q_\phi with another distribution m(z) that makes the optimization more tractable, e.g. \mathcal{N}(0,I) in the VAE. Our objective can be rewritten as

\underset{q_\phi(z\mid x),m(z)}{\text{minimize}}\:\:\mathbb{E}_{p(x,z)}[d(x,z)]+\beta\, \mathbb{E}_{x\sim\mathcal{D}}\left[D_{KL}(q_\phi(z\mid x)\Vert\, m(z))\right]

Suppose the distortion of interest is the decoder’s reconstruction error, d(x,z)=-\log p_\theta(x\mid z). Such a function penalizes representations z from which we cannot regenerate an observed data vector x through the decoding network p_\theta with high probability. A typical distortion-rate problem would fix the distortion function, but here we choose to learn this decoder. Since \mathbb{E}_{p(x,z)}[d(x,z)]=\mathbb{E}_{x\sim\mathcal{D}}\left[\mathbb{E}_{q_\phi(z\mid x)}\left[-\log p_\theta(x\mid z)\right]\right], we can optimize the objective for each x, drop the outer expectation over the data \mathcal{D}, fix m(z)=\mathcal{N}(0,I), and recover the \beta-VAE objective precisely:

\underset{q_\phi(z\mid x),p_\theta(x\mid z)}{\text{maximize}}\:\:\mathbb{E}_{q_\phi(z\mid x)}\left[\log p_\theta(x\mid z)\right]-\beta\, D_{KL}(q_\phi(z\mid x)\Vert\, m(z))

When \beta>1, our optimization prioritizes minimizing the second term (the rate) over maximizing the first (the reconstruction term, i.e., keeping the distortion low). In this sense, the authors’ argument for large \beta can be reinterpreted as an argument for higher-distortion, lower-rate codes (read: latent representations) to encourage interpretability. I edited a figure below from Alemi et al. 2018 to clarify this.

Distortion (D) vs. Rate (R) as a function of free parameters in the rate-distortion problem (and \beta-VAEs) — the proposed method privileges solutions in the top-left quadrant (adapted from Alemi et al. 2018).
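
To make the figure concrete: the reconstruction and KL terms returned by the sketch above are (up to sign) a model’s distortion and rate, so logging them places any trained model on this plane. A purely hypothetical usage, with encode and decode standing in for a trained \beta-VAE:

```python
import numpy as np  # reuses beta_vae_objective from the sketch above

rng = np.random.default_rng(0)
x_batch = rng.integers(0, 2, size=(128, 784)).astype(float)  # dummy binarized "images"

def encode(x):
    # Stand-in for a trained encoder returning (mu, log_var); here it just returns the prior.
    return np.zeros((x.shape[0], 10)), np.zeros((x.shape[0], 10))

def decode(z):
    # Stand-in for a trained decoder returning Bernoulli means.
    return np.full((z.shape[0], 784), 0.5)

mu, log_var = encode(x_batch)
z = mu + np.exp(0.5 * log_var) * rng.standard_normal(mu.shape)  # reparameterization trick
_, reconstruction, kl = beta_vae_objective(x_batch, decode(z), mu, log_var)
distortion, rate = -reconstruction, kl  # this model's (D, R) coordinates on the plane
# Repeating this for models trained at different beta values traces out the trade-off above.
```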

Information-theoretic hypotheses abound. Perhaps enforcing optimization in this region could discourage solutions that depend on learning an ultra-powerful decoder (VAE: generator) p_\theta(x\mid z), in other words solutions that depend on a good decode, not necessarily a good code. Does eliminating this possibility simply make room to fish out an ad-hoc interpretable representation, or is there a more sophisticated explanation waiting to be found? We’ll see.