oi-VAE: output interpretable VAEs

We recently read oi-VAE: Output Interpretable VAEs for Nonlinear Group Factor Analysis, which was published at ICML in 2018 by Samuel Ainsworth, Nicholas Foti, Adrian Lee, and Emily Fox. This paper proposes a nonlinear group factor analysis model that adapts VAEs to data whose observed dimensions are partitioned into groups. The goal is to learn latent dimensions that each control only a small subset of the groups, so that the nonlinear interactions between groups can be read off from the model. This sparsity encourages latent representations that are disentangled across functional groups. A prominent example in the paper is motion capture data, where we want a generative model of human walking and the groups are sets of recorded joint angles.

Let us consider observations \mathbf{x} that we group into G groups \mathbf{x} = [\mathbf{x}^{(1)}, \cdots, \mathbf{x}^{(G)}].  We note that the paper does not discuss how to choose the groups and assumes that a grouping has already been specified. The generative model maps a set of latent variables to group-specific generator networks via group-specific mapping matrices \mathbf{W}^{(g)}  such that

\mathbf{z} \sim \mathcal{N}(0, \mathbf{I}) \\ \mathbf{x}^{(g)} \sim \mathcal{N}( f_{\theta_g}^{(g)} (\mathbf{W}^{(g)} \mathbf{z} ))

for each g.
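To make the generative process concrete, here is a minimal NumPy sketch of ancestral sampling from it. This is our own illustration, not the authors' code: the one-layer tanh generators, the fixed observation noise `noise_std`, and all dimensions are assumptions chosen only to keep the example small.

```python
import numpy as np

def sample_oi_vae(W_list, generators, rng, noise_std=0.1):
    """Draw one sample x = [x^(1), ..., x^(G)] from the generative model.

    W_list[g] is the mapping matrix W^(g); generators[g] plays the role of
    f^(g). The fixed isotropic noise is a simplifying assumption.
    """
    K = W_list[0].shape[1]                      # latent dimension
    z = rng.standard_normal(K)                  # z ~ N(0, I)
    groups = []
    for W, f in zip(W_list, generators):
        mean = f(W @ z)                         # f^(g)(W^(g) z)
        groups.append(mean + noise_std * rng.standard_normal(mean.shape))
    return z, groups

rng = np.random.default_rng(0)
K, p, dims = 4, 3, [2, 5]   # latent dim, generator input dim, group sizes
W_list = [rng.standard_normal((p, K)) for _ in dims]
# Hypothetical toy generators: a single tanh layer per group.
A_list = [rng.standard_normal((d, p)) for d in dims]
generators = [lambda h, A=A: np.tanh(A @ h) for A in A_list]

z, x_groups = sample_oi_vae(W_list, generators, rng)
```

Note that z is shared by all groups, and each group sees it only through its own \mathbf{W}^{(g)}.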

oi-VAE generative model (Ainsworth et al., 2018)

To learn interpretable sets of latents that control separate groups, the key feature of this approach is a sparsity-inducing prior on the columns of each \mathbf{W}^{(g)}. If the j-th column of \mathbf{W}^{(g)} is zero, then the latent dimension z_j has no influence on group g. The authors use a hierarchical Bayesian sparse prior that, when analytically marginalized, corresponds to optimizing a group lasso penalty on the columns of \mathbf{W}^{(g)}.

The model is trained in the standard VAE approach by optimizing the ELBO with an amortized inference network q_{\phi}(\mathbf{z} | \mathbf{x}), with the addition of the group lasso penalty and a prior on the parameters of the generator networks

\mathcal{L}(\phi, \theta, \mathcal{W}) = \mathbb{E}_{q_\phi(\mathbf{z} | \mathbf{x})} [ \log p(\mathbf{x} | \mathbf{z}, \mathcal{W}, \theta)] - D_{\mathrm{KL}} ( q_\phi(\mathbf{z} | \mathbf{x})  || p(\mathbf{z}) ) + \log p(\theta) - \lambda \sum_{g,j} \| w_j^{(g)} \|_2

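where w_j^{(g)} is the j-th column of \mathbf{W}^{(g)}, \mathcal{W} denotes the collection of all mapping matrices, and \lambda is a parameter controlling the sparsity. The prior p(\theta) fixes the scaling of the neural network parameters relative to the mapping matrices \mathbf{W}^{(g)}; without it, the penalty could be sidestepped by shrinking \mathbf{W}^{(g)} and scaling up the generator weights to compensate.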
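As a small worked example of the penalty term, the sketch below sums the Euclidean norms of the columns across all mapping matrices. The function name `group_lasso_penalty` and the toy matrices are ours, for illustration only.

```python
import numpy as np

def group_lasso_penalty(W_list, lam):
    """lam * sum over groups g and columns j of ||w_j^(g)||_2."""
    return lam * sum(np.linalg.norm(W, axis=0).sum() for W in W_list)

W1 = np.array([[3.0, 0.0],
               [4.0, 0.0]])   # column norms: 5 and 0
W2 = np.array([[0.0, 1.0]])   # column norms: 0 and 1
penalty = group_lasso_penalty([W1, W2], lam=0.5)   # 0.5 * (5 + 0 + 0 + 1) = 3.0
```

Because the norm is taken per column rather than per entry, the penalty favors zeroing out whole columns, which is what removes a latent dimension from a group.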

The above objective consists of a differentiable term (the ELBO plus the log prior on \theta) and a convex but non-differentiable term (the group lasso penalty). The authors therefore optimize it with a proximal gradient method. First, they update all parameters using the gradients of the ELBO plus log prior with respect to \phi, \theta, and \mathcal{W}. Then, they apply the proximal operator

w_{j}^{(g)} \leftarrow \frac{ w_{j}^{(g)} } { \| w_{j}^{(g)} \|_2 } ( \| w_{j}^{(g)} \|_2 - \eta \lambda )_+

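to each column of each \mathbf{W}^{(g)} to account for the group lasso penalty, where \eta is a step size. Columns whose norm falls below \eta \lambda are set exactly to zero, and the remaining columns are shrunk toward zero. The authors fixed \lambda for all of their experiments fitting oi-VAE, so one question I had while reading was how the authors determined \lambda and how varying \lambda affects the results.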
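The proximal step is easy to write down directly. Below is a minimal NumPy sketch of the column-wise update for a single mapping matrix; the function name `prox_group_lasso` and the small `eps` guard against division by zero are our additions, not something taken from the paper.

```python
import numpy as np

def prox_group_lasso(W, step, lam, eps=1e-12):
    """Column-wise group soft-thresholding.

    Each column w_j is rescaled by (||w_j||_2 - step * lam)_+ / ||w_j||_2,
    so columns with norm at most step * lam become exactly zero.
    """
    norms = np.linalg.norm(W, axis=0, keepdims=True)        # shape (1, K)
    scale = np.maximum(norms - step * lam, 0.0) / np.maximum(norms, eps)
    return W * scale

W = np.array([[3.0, 0.3],
              [4.0, 0.4]])   # column norms: 5.0 and 0.5
W_new = prox_group_lasso(W, step=1.0, lam=1.0)
# First column is shrunk by a factor (5 - 1) / 5 = 0.8;
# second column has norm 0.5 < 1 and is zeroed out.
```

In a training loop this would be called on every \mathbf{W}^{(g)} right after the gradient step on the differentiable part of the objective.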

The authors validate the approach on a toy example. They generated synthetic bars data, where one row of a square matrix was sampled from a Gaussian with non-zero mean while the rest of the matrix was sampled from zero-mean noise. The authors fit oi-VAE with each group set to a row of observations, and found that the model learned the appropriate sparse mapping, in which each latent component mapped to one of the rows. This improved on a standard VAE, whose latent space showed no discernible structure. Importantly, oi-VAE still identified the correct number of latent components (8) and the sparse mapping even when the model was fit with double the number of components.
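For readers who want to try something similar, here is a rough sketch of how such bars data could be generated. The matrix size of 8 is inferred from the number of latent components above; the bar mean, standard deviations, and sample count are our guesses and not the values used in the paper.

```python
import numpy as np

def sample_bars(n, size=8, bar_mean=1.0, bar_std=0.1, noise_std=0.05, seed=0):
    """Generate n size-by-size matrices, each with one randomly chosen 'bar' row.

    All numeric defaults are illustrative assumptions.
    """
    rng = np.random.default_rng(seed)
    X = noise_std * rng.standard_normal((n, size, size))   # zero-mean noise
    rows = rng.integers(0, size, size=n)                   # active row per sample
    X[np.arange(n), rows, :] = bar_mean + bar_std * rng.standard_normal((n, size))
    return X, rows

X, rows = sample_bars(n=1000)
# One group per row of the matrix, matching the grouping described above.
groups = [X[:, g, :] for g in range(X.shape[1])]
```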

 

(a) Synthetic bar data example and (b) reconstruction from oi-VAE. The learned oi-VAE latent-group mappings (c) match the true structure, while a VAE (d) does not learn a sparse structure. 

 

After validating the approach, the authors applied it to motion capture data. Here, the output groups were different joint angles. They trained the oi-VAE model on walking motion capture data. The learned latent dimensions corresponded to intuitive groups of joint angles, such as the left leg (left foot, left lower leg, left upper leg). The imposed structure also helped the model generate more realistic unconditional samples of walking than the VAE, presumably because the inductive bias allowed oi-VAE to better learn which combinations of joint angles are invalid.

Unconditional walking pose samples from the VAE and oi-VAE models.

These results suggest the oi-VAE is a useful model for discovering nonlinear interactions between groups. In particular, I liked the approach of adding structure in the generative model to gain interpretability, and hypothesize that adding other forms of structure to VAE generative models is a good way to encourage disentangled representations (see a recent example of this in Dieng et al., 2019).

Two open questions when using the approach are how to choose \lambda and how to choose the grouping of the data, as this work assumes a grouping has already been chosen. In some data we may have prior knowledge about a natural grouping structure, but that will not always be the case. However, even without multiple groups, the approach could be useful for learning the number of latent dimensions needed to explain the data. Finally, we point the reader to factVAE, where the authors further develop this idea to simultaneously learn complementary sparse structure in the inference network and generative model.
