Machine Theory of Mind

In lab meeting last week, we read Machine Theory of Mind, a recent paper from Neil Rabinowitz and his collaborators at DeepMind & Google Brain (a trimmed version of the paper was presented at ICML 2018). Here, Theory of Mind (ToM) is broadly defined as the ability to represent the mental states of others. This paper aims to demonstrate ToM in an artificial agent. While designing & training such an agent constitutes one challenge, the authors must first devise a scenario in which ToM can be convincingly shown. Inspired by the Sally-Anne test — a classic test of ToM from developmental psychology that evaluates whether a child understands that others can hold false beliefs — the authors construct an analogous test, then train an agent to successfully pass it.

The paper is composed of a series of experiments that build in complexity to this final test. Within each experiment are three key parts. First, the environment: a simple 11×11 grid-world containing walls and 4 colored boxes that are all randomly located within each new world. Second, the agents: an individual agent belongs to a particular “species” according to its policy for acting within an environment. Agents can behave randomly, algorithmically, or with a learned policy (via deep RL). The trajectory of a particular agent within a particular environment constitutes an episode. Reward within an episode is generally maximized by navigating to a box of a particular color as fast as possible. However, limitations on the sightedness and statefulness of the agents, as well as the inclusion of more complex subgoals, are adjusted per the needs of each experiment. Finally, the observer: a meta-learning agent, called ToMnet, that parses the episodes of many agents in many environments so as to learn a prior over the behavior of an agent species. At test time, ToMnet uses a novel agent’s recent episodes & its trajectory on a current episode thus far to infer a posterior and make predictions regarding the agent’s future behavior.

From Figure 4 in the paper. (a) An example environment. The red arrows represent the trajectory of a goal-driven agent on a previous episode. (b) A new environment, but the same agent as in (a). Using a prior learned over the agent’s species (e.g., agents of this species seek to consume a box of a specific color) in combination with this agent’s past episodes (e.g., this agent consumed a green box), ToMnet is tasked with predicting the agent’s future behavior. (c) For the new environment in (b), ToMnet outputs the next predicted action (go right) and the predicted color of the box that will eventually be consumed (green). (d) ToMnet also predicts the agent’s successor representation.

To probe ToM in ToMnet, the authors introduce a species of agent with both a limited field of view and a subgoal. For example, an agent that can only see the squares adjacent to it must first navigate to a “star” in the grid-world before finally navigating to the blue box to achieve maximum reward. In certain environments, the agent passes the blue box early in its initial search and so knows directly where to go after finding the star, even if the blue box is not visible to the agent from the star’s location. The test comes when the experimenter swaps the locations of the boxes while the agent is on the star and the boxes are out of view. While the agent is blind to the swap, ToMnet is not. And so, the analogous Sally-Anne test arises: Will ToMnet fail to account for the fact that the swap occurred outside the agent’s field of view, and thus mistakenly predict that the agent will move toward the new location of the blue box? Or will ToMnet recognize that the agent maintains the false belief that no swap has occurred, and thus correctly predict that the agent will move toward the old location of the blue box?

ToMnet predicts behavior reflecting the agent’s false belief, successfully passing the test. Importantly, this finding is supplemented with results that show that ToMnet is sensitive to how different fields of view make an agent “differentially vulnerable to acquire false beliefs,” and that ToMnet still passes the test even if it had never seen swap events during training. Thus, ToMnet “learns that agents can act based on false beliefs,” providing a compelling proof-of-concept for Machine Theory of Mind.


Insights on representational similarity in neural networks with canonical correlation

For this week’s journal club, we covered “Insights on representational similarity in neural networks with canonical correlation” by Morcos, Raghu, and Bengio, NeurIPS, 2018.  To date, many different convolutional neural networks (CNNs) have been proposed to tackle the object recognition problem, including Inception (Szegedy et al., 2015), ResNet (He et al., 2016), and VGG (Simonyan and Zisserman, 2015). These networks have vastly different architectures but all achieve high accuracy. How can this be the case? One possibility is that although the architectures vary, the representations (i.e., the way these networks encode information about the objects of natural images) are very similar. 

To test this, we first need a metric of similarity. One approach has been “representational similarity analysis” (RSA) (Kriegeskorte et al., 2008), which relies on distance matrices to test if two representations are similar. One potential problem with RSA is that some dimensions of the representations may be “noisy” (i.e., dimensions that do not pertain to encoding the input information). For example, during training, some dimensions of the activity of CNN neurons may vary substantially across epochs but are not relevant to encoding object information. These dimensions could mask the signal of relevant dimensions when analyzing a distance matrix.
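
To make the RSA idea concrete, here is a toy sketch (not from the paper): build a representational dissimilarity matrix (RDM) for each of two simulated networks responding to the same stimuli, then correlate the RDMs. All names, sizes, and data here are invented for illustration.

```python
import numpy as np
from scipy.spatial.distance import pdist
from scipy.stats import spearmanr

rng = np.random.default_rng(1)
stimuli = rng.normal(size=(50, 20))          # 50 stimuli, 20 input features
W1 = rng.normal(size=(20, 100))              # random "network 1" with 100 units
W2 = rng.normal(size=(20, 80))               # random "network 2" with 80 units
resp1 = np.tanh(stimuli @ W1)                # responses of network 1 to the stimuli
resp2 = np.tanh(stimuli @ W2)                # responses of network 2 to the stimuli

# RDM: pairwise dissimilarity between stimulus responses, per network
rdm1 = pdist(resp1, metric="correlation")
rdm2 = pdist(resp2, metric="correlation")
rho, _ = spearmanr(rdm1, rdm2)
print("RSA similarity (Spearman rho between RDMs):", rho)
```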

One way to avoid this is to try to directly identify the relevant dimensions, allowing us to ignore the noisy dimensions. The authors relied on an old but trusted method called canonical correlation analysis (CCA), which was developed way back in the 1930s (Hotelling, 1936)! CCA has been a handy tool in computational neuroscience, relating the activity of neurons across two populations (Semedo et al., 2014) as well as relating population activity to the output of model neurons (Sussillo et al., 2015). Newer methods have been developed that are more appropriate for various problems. These include partial least squares (Höskuldsson, 1988), kernel CCA (Hardoon et al., 2004), as well as a method I developed for my own work called distance covariance analysis (DCA) (Cowley et al., 2017). The common thread among all of these methods is that they identify dimensions that encode similar information among two or more datasets.

Overview of CCA. CCA is a close relative of linear regression, but whereas linear regression aims at prediction, CCA focuses on correlation, and thus is most suitable for cases in which the investigator seeks intuition about the data. Given two datasets (e.g., \mathbf{X} \in \mathcal{R}^{k \times N} \textrm{ and } \mathbf{Y} \in \mathcal{R}^{p \times N}, both centered, where N is the number of samples), CCA seeks a pair of dimensions \mathbf{u} \in \mathcal{R}^k \textrm{ and } \mathbf{v} \in \mathcal{R}^p such that the Pearson correlation between the projections \mathbf{u}^T \mathbf{X} \textrm{ and } \mathbf{v}^T \mathbf{Y} is as large as possible. In other words, CCA identifies linear combinations of the variables in \mathbf{X} \textrm{ and } \mathbf{Y} that are the most linearly related. CCA need not stop there: it can identify further pairs of dimensions, ordered by decreasing correlation. In this way, we can ignore the dimensions with the smallest correlations (which are likely spurious). One fun fact about CCA is that any two identified dimensions in \mathbf{X} yield uncorrelated projections: \textrm{corr}(\mathbf{u}_i^T \mathbf{X}, \mathbf{u}_j^T \mathbf{X}) = 0 \textrm{ for } i \neq j (and the same for \mathbf{v}_i, \mathbf{v}_j). This is different from PCA, whose identified dimensions are both uncorrelated and orthogonal. The uncorrelatedness of CCA dimensions ensures that we do not include dimensions that contain redundant information. (Implementation detail: CCA is solved with a singular-value decomposition, but be sure to use a regularized form akin to ridge regression; it was unclear whether the authors used regularization.)
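
For concreteness, here is a bare-bones regularized CCA via SVD, written as a sketch of the idea rather than the paper’s actual pipeline (which adds more machinery, e.g., projection weighting); the regularization constant and the toy data are made up.

```python
import numpy as np

def cca(X, Y, reg=1e-3):
    """X: k x N, Y: p x N. Returns canonical dimensions and correlations."""
    X = X - X.mean(axis=1, keepdims=True)
    Y = Y - Y.mean(axis=1, keepdims=True)
    N = X.shape[1]
    Cxx = X @ X.T / N + reg * np.eye(X.shape[0])   # ridge-like regularization
    Cyy = Y @ Y.T / N + reg * np.eye(Y.shape[0])
    Cxy = X @ Y.T / N
    # whiten each dataset, then take the SVD of the whitened cross-covariance
    Wx = np.linalg.inv(np.linalg.cholesky(Cxx))
    Wy = np.linalg.inv(np.linalg.cholesky(Cyy))
    U, s, Vt = np.linalg.svd(Wx @ Cxy @ Wy.T)
    u = Wx.T @ U          # canonical dimensions for X (columns)
    v = Wy.T @ Vt.T       # canonical dimensions for Y (columns)
    return u, v, s        # s holds the canonical correlations, largest first

# toy usage: two "views" of a shared 2-dimensional latent signal
rng = np.random.default_rng(0)
z = rng.normal(size=(2, 1000))
X = rng.normal(size=(5, 2)) @ z + 0.5 * rng.normal(size=(5, 1000))
Y = rng.normal(size=(4, 2)) @ z + 0.5 * rng.normal(size=(4, 1000))
u, v, corrs = cca(X, Y)
print(np.round(corrs, 2))   # the first two correlations should be large
```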

Figure 1. Generalizing networks converge to more similar solutions than memorizing networks.

Onto the results. The authors proposed a CCA-based distance metric to uncover some intuitive characteristics of deep neural networks. First, they found that different initializations of generalizing networks (i.e., networks trained on natural images with their true labels) were more similar to one another than different initializations of memorizing networks (i.e., networks trained on the same dataset with randomly-shuffled labels). This is expected, as the natural labels likely constrain generalizing networks toward similar solutions. Interestingly, when comparing generalizing and memorizing networks (Fig. 1, yellow line, ‘Inter’), they found that a generalizing network and a memorizing network were only about as similar as two different memorizing networks trained on the same fixed dataset. This suggests that overfitted networks converge on very different solutions for the same problem. Also interesting was that the earlier layers of both generalizing and memorizing networks seemed to converge on similar solutions, while the later layers diverged. This suggests that earlier layers rely more on the structure of natural images while the later layers rely more on the structure of the labels. Second, they found that wider networks (i.e., networks with more filters per layer) converge to more similar solutions than narrower networks do. They argue that this supports the “lottery-ticket” hypothesis that wider networks are more likely to contain a sub-network that fits the desired function. Finally, they found that networks trained with different initializations and learning rates on the same problem converge to different groups of solutions. This highlights the need to try different initializations when training neural networks.

This paper left me thinking a lot about representation in the visual cortex of the brain. Does visual cortical population activity have stable and “noisy” dimensions?  If we reduced the number of visual cortical neurons per visual cortical area (either via lesion or pharmacological intervention) in a developing animal, would these animals have severe perceptual deficits (i.e., their visual system did not have the right lottery ticket when developing)?  Lastly, it seems plausible that humans start out with different initializations of their visual cortices—does that suggest different humans have converged on different solutions to solving visual perception?  If so, it suggests that inter-subject variability may be larger than previously thought.

The Loss-Calibrated Bayesian

By Farhan Damani

In lab meeting this week, we discussed loss-calibrated approximate inference in the context of Bayesian decision theory (Lacoste-Julien et al., 2011; Cobb et al., 2018). For many applications, the cost of an incorrect prediction can vary depending on the nature of the mistake. Suppose you are in charge of controlling a nuclear power plant with an unknown temperature \theta. We observe indirect measurements of the temperature, D, and we use Bayesian inference to infer a posterior distribution over the temperature given the observations, p(\theta|D). The plant is in danger of over-heating, and as the operator, you can either keep the plant running or shut it down. Keeping the plant running while the plant’s temperature exceeds a critical threshold T_{\text{critic}} will cause a nuclear meltdown, incurring a huge loss L(\theta > T_{\text{critic}}, \text{'on'}), while shutting off the plant at benign temperatures incurs a minor loss L(\theta < T_{\text{critic}}, \text{'off'}).

In Figure 1 we observe that the true posterior p(\theta|D) is multi-modal. Standard approximate inference techniques characterize general properties of the posterior, attempting to match either the first or second moment of p. Both strategies underestimate the posterior mass in the safety-critical region. The dash-dotted line, by contrast, fails to capture typical properties of the posterior, yet it results in the same decision as the true posterior because it is optimized for task-specific utility. The point is that the “best” approximate posterior is subjective, and therefore we should tailor our inferential resources to find an approximation that is well suited to the decision task at hand.

Bayesian decision theory extends the Bayesian paradigm by including a task-specific utility function U(\theta, a), which tells us the utility of taking action a \in \mathcal{A} when the world is in state \theta. According to this view, the optimal action minimizes the posterior risk: \underset{a}{\arg \min} \text{ } \mathcal{R}(a) = \mathbb{E}_{p(\theta|D)}[L(\theta, a)], where L is the loss (i.e., negative utility). Typically, this is computed using a 2-step procedure: first approximate the posterior p(\theta|D) with a q(\theta|D), and then minimize the risk under q. This approach, however, assumes that our approximate q captures the properties of the posterior that we care about. Knowing which properties those are by definition requires the utility function, and therefore we should jointly optimize the approximate posterior with the action that minimizes the posterior risk. Cobb et al. (2018) show how to derive a variational lower bound that depends on a task-specific utility function. In their setup, they show that minimizing the KL divergence between an approximate posterior q and a calibrated posterior scaled by the utility function results in the standard ELBO loss plus an additional utility-dependent regularization term. This formulation is amenable to stochastic optimization, allowing for the practical deployment of this framework to supervised learning.
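
To make the 2-step procedure and its failure mode concrete, here is a toy numerical sketch (all numbers invented, not from the paper): grid the temperature, form a multi-modal “true” posterior and a mode-matched approximation, and compare the risk-minimizing action under each.

```python
import numpy as np

theta = np.linspace(0.0, 12.0, 1201)
T_critic = 7.0

# a multi-modal "true" posterior with a small, dangerous high-temperature mode
true_post = (0.8 * np.exp(-0.5 * ((theta - 3) / 0.7) ** 2)
             + 0.2 * np.exp(-0.5 * ((theta - 8) / 0.5) ** 2))
true_post /= np.trapz(true_post, theta)
# an approximation matched to the dominant mode only (misses the dangerous mode)
approx_post = np.exp(-0.5 * ((theta - 3) / 0.7) ** 2)
approx_post /= np.trapz(approx_post, theta)

def posterior_risk(post, action):
    # large loss for running past the critical temperature,
    # small loss for shutting off a benign plant
    loss = (np.where(theta > T_critic, 100.0, 0.0) if action == "on"
            else np.where(theta <= T_critic, 1.0, 0.0))
    return np.trapz(loss * post, theta)

for name, post in [("true", true_post), ("approx", approx_post)]:
    risks = {a: posterior_risk(post, a) for a in ("on", "off")}
    print(name, {a: round(r, 2) for a, r in risks.items()},
          "-> best action:", min(risks, key=risks.get))
# the approximation drops the dangerous mode, so it keeps the plant running
```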

An orderly single-trial organization of population dynamics in premotor cortex predicts behavioral variability

This week we read some new work from Shaul Druckmann and Karel Svoboda’s groups (https://www.biorxiv.org/content/early/2018/07/25/376830). They analyzed simultaneously recorded activity from the anterior lateral motor cortex (ALM; perhaps homologous to a premotor area in primates) from mice performing a delayed discrimination task (either a somatosensory pole detection task or an auditory tone discrimination task). They analyzed 55 sessions with 6-31 units in each session. Given the strong task-epoch-dependent responsiveness of most cells in the population, they fit a switching linear dynamical system (sLDS) to the data using expectation-maximization. However, in their model, the switch times were dictated by the task structure, making the model significantly easier to fit than an sLDS with unconstrained switch times. They called their model an epoch-dependent linear dynamical system (EDLDS).

They used leave-one-neuron-out cross validation (i.e., compute the posterior over the latents on test trials without including one neuron’s activity, and then predict that neuron’s activity from the posterior) to test the model fit. They found that the EDLDS often fit the data about as well as an sLDS that could flexibly assign the timing of the same number of switch events, and that it significantly outperformed Gaussian process factor analysis.
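
As a rough sketch of what leave-one-neuron-out prediction looks like, here is a simplified stand-in (not the paper’s EDLDS): a static linear-Gaussian latent model in which we infer the latents from all neurons except one and then predict the held-out neuron. All variable names, dimensions, and data are invented.

```python
import numpy as np

rng = np.random.default_rng(0)
N, K, T = 10, 3, 200                       # neurons, latents, time bins
C = rng.normal(size=(N, K))                # loadings
d = rng.normal(size=N)                     # baselines
R = np.diag(rng.uniform(0.5, 1.5, N))      # observation noise covariance (diagonal)
z = rng.normal(size=(K, T))                # latents, z ~ N(0, I)
Y = C @ z + d[:, None] + rng.normal(size=(N, T)) * np.sqrt(np.diag(R))[:, None]

def predict_held_out(Y, C, d, R, i):
    """Infer latents from all neurons except i, then predict neuron i."""
    keep = np.arange(N) != i
    Ck, dk, Rk = C[keep], d[keep], R[np.ix_(keep, keep)]
    # posterior mean of z given y_{-i} under the prior z ~ N(0, I)
    prec = Ck.T @ np.linalg.inv(Rk) @ Ck + np.eye(K)
    z_hat = np.linalg.solve(prec, Ck.T @ np.linalg.inv(Rk) @ (Y[keep] - dk[:, None]))
    return C[i] @ z_hat + d[i]

pred = predict_held_out(Y, C, d, R, 0)
print("held-out R^2 for neuron 0:", 1 - np.var(Y[0] - pred) / np.var(Y[0]))
```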

The model defines a low-dimensional latent space to which they apply several analyses. First, they applied linear discriminant analysis (LDA) to decode the animal’s choice on each trial and showed that it outperforms LDA applied to the activity of the full neural population, even when regularization is included.

Next, they applied principal components analysis (PCA) to the latent activity and the full neural population activity to visualize the dominant temporal trajectories within each space. PC projections of the full population activity showed sharp temporal transitions between task epochs and a random spatial ordering across trials, while PC projections of the latent activity showed smooth temporal transitions and strongly ordered dynamics.

Figure 4a from the paper.

They quantified the “orderliness” of each representation by computing the consistency across time of the trial-ranked value of the LDA projection, confirming greater orderliness within the latent space than within the full neural activity. They also found that decoding the previous trial’s outcome, or the choice on error trials, from the latent activity outperformed the same analyses applied to the full neural activity.

In summary, the dynamical nature of the EDLDS provides a smooth, de-noised portrait of temporal dynamics that are present in the data but might not easily reveal themselves with standard analyses.

Motor Cortex Embeds Muscle-like Commands in an Untangled Population Response

This week we discussed a recent publication by Abigail Russo from Mark Churchland’s lab. The authors examined primary motor cortex (M1) population responses and EMG activity from primates performing a novel cycling task. The task required the animal to rotate a pedal (like riding a bicycle, except with one’s arm) a fixed number of rotations. There were several task conditions, including pedaling forward and backward.

The authors found that while M1 neural activity contained components of muscle activity (e.g., trial-averaged EMG activity could be accurately predicted by linear combinations of the trial-averaged neural activity), the dominant structure present in the neural population response was not muscle-like. They came to this conclusion by examining the top principal component projections of the neural activity and the EMG activity, where they found that the former co-rotated for forward and backward pedaling while the latter counter-rotated. This discrepancy in rotation direction is inconsistent with the notion that neural activity encodes force or kinematic commands.

Based on this observation, the authors proposed a novel hypothesis: the dominant, non-muscle-like activity patterns in M1 exist so as to “untangle” the representation of the muscle-like activity patterns. A rough analogy would be something like a phonograph, whose dominant dynamics are rotations, but which serve only to lay out a coding direction (normal to the rotation) that can be “read out” in a simple way by the phonograph needle. The authors show with network models that an “untangled” response has the desirable property of noise-robustness.

The following toy-model from the paper illustrates the idea of “tangled-ness.” Imagine that a population of neurons must generate output 1 and output 2 depicted below. If the population represented those signals directly (depicted in the leftmost phase portrait) in a 2-dimensional space, the trajectory would trace out a “figure-8,” a highly tangled trajectory and one that cannot be generated by an autonomous dynamical system (which the authors assume more-or-less accurately caricatures the dynamical properties of M1). In order to untangle the neural representation (depicted in the rightmost phase portrait), the neural activity needs to add an extra, third dimension which resides in the null space of the output. Now, these dynamics can be generated autonomously and a linear projection of them can generate the output.


The authors directly compute a measure of tangling within the neural data and the EMG data. The metric is the following:

Q(t) = \text{max}_{t^\prime} \frac{|| \dot x_t - \dot x_{t^\prime}||}{|| x_t - x_{t^\prime} || + \epsilon}.

It can be summarized in the following way: identify two moments in time where the state is very similar but the derivative of the state is very different. Such points are exemplified by the intersection of the “figure-8” trajectory above, since the intersection consists of two (nearly) identical states with very different derivatives. Across multiple animals, species, and motor tasks, the authors found a consistent relationship: neural activity is less tangled than EMG activity (as shown below). The authors note that a tangled EMG response is acceptable or perhaps even desirable, since EMG reflects incoming commands and therefore does not need to abide by the requirements that an autonomous dynamical system (like M1) does.
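
Here is a rough sketch of how one might compute the tangling metric, following the formula above directly (the dimensionality, time step, and the constant \epsilon are placeholders; the paper’s exact preprocessing is not reproduced). The toy comparison contrasts a figure-8 trajectory with a circle.

```python
import numpy as np

def tangling(X, dt=1.0, eps=0.1):
    """X: T x D array of states (e.g., top PCs of neural or EMG activity)."""
    dX = np.gradient(X, dt, axis=0)                            # state derivatives
    dist_x = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=-1)
    dist_dx = np.linalg.norm(dX[:, None, :] - dX[None, :, :], axis=-1)
    return np.max(dist_dx / (dist_x + eps), axis=1)            # Q(t) for every t

# toy usage: a figure-8 (tangled) vs. a circle (untangled)
t = np.linspace(0, 2 * np.pi, 400)
fig8 = np.stack([np.sin(t), np.sin(t) * np.cos(t)], axis=1)
circle = np.stack([np.cos(t), np.sin(t)], axis=1)
print("figure-8 peak tangling:", tangling(fig8).max())
print("circle   peak tangling:", tangling(circle).max())
```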


Based on these analyses, the authors conclude that the dominant signals present in M1 population activity principally perform a computational role, by untangling the representation of muscle-like signals that can be read-out approximately linearly by the muscles.

Automatic Differentiation in Machine Learning: a Survey


In lab meeting this week we discussed Automatic Differentiation in Machine Learning: a Survey, which addresses the general technique of autodifferentiation for functions expressed in computer programs.

The paper initially outlines three main approaches to the computation of derivatives in computer programs: 1) manually working out derivatives by hand and coding the result; 2) numerically approximating derivatives by evaluating the function at small discrete steps; and 3) computer manipulation of symbolic mathematical expressions that automatically generates derivative expressions via standard calculus rules. The limitation of approach 1 is that derivatives can involve a large number of terms (expression swell) that are both cumbersome to work out by hand and prone to algebraic mistakes. The disadvantage of approach 2 is that a finite-difference approach to numerical differentiation inevitably involves round-off errors and is sensitive to step size. The disadvantage of approach 3 is that it is computationally cumbersome, and the mechanically applied calculus rules often produce forms of the derivative that still need to be manually reduced.
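
As a quick illustration of approach 2’s weakness (a made-up example, not from the paper): central differences trade off truncation error against floating-point round-off as the step size shrinks.

```python
import numpy as np

# Central-difference approximation of d/dx sin(x) at x0, for several step sizes.
# Large h suffers truncation error; very small h suffers round-off error.
f, x0 = np.sin, 1.0
exact = np.cos(x0)
for h in (1e-1, 1e-4, 1e-8, 1e-12):
    approx = (f(x0 + h) - f(x0 - h)) / (2 * h)
    print(f"h = {h:.0e}   abs error = {abs(approx - exact):.2e}")
```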

In contrast to these methods, the paper introduces autodifferentiation as a set of techniques that operates at the level of elementary computations via step-wise application of the chain rule. The evaluation of a function in a computer program involves an execution of procedures that can be represented by a computational graph, where each node in the graph is an intermediary temporary variable. Computers execute this trace of elementary operations via intermediary variables, and this procedure can be exploited to calculate a derivative concurrently and efficiently.

The paper outlines two general approaches to autodifferentiation: forward mode and backward mode. In forward mode, a derivative trace is evaluated alongside the computer’s execution of the functional trace, with intermediary values from the functional evaluation trace used in concert with the derivative trace calculation. This execution can be conveniently realized by extending the representation of the real numbers to include a dual component. Analogous to imaginary numbers, this dual component is denoted with an \epsilon, where \epsilon^2 = 0. That is, each number in the execution of the function on a computer is extended to a tuple that carries this additional dual component, which is evaluated alongside the function using the same intermediate variables. Using this simple rule of the dual number (\epsilon^2 = 0), the evaluation implements the chain rule of differentiation, automatically calculating derivatives at every computational step along the evaluation trace.
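
A minimal sketch of forward mode with dual numbers (illustrative only; real autodiff libraries overload many more operations): each value carries a (value, derivative) pair, and the elementary operations propagate both, so one pass through the function also yields a directional derivative.

```python
import math

# Each value carries (val, dot) = (value, derivative along the seed direction);
# elementary operations propagate both, implementing the chain rule step-wise.
class Dual:
    def __init__(self, val, dot=0.0):
        self.val, self.dot = val, dot

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.dot + other.dot)
    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # product rule: (uv)' = u'v + uv'
        return Dual(self.val * other.val,
                    self.dot * other.val + self.val * other.dot)
    __rmul__ = __mul__

def sin(x):
    # chain rule for an elementary function: d sin(u) = cos(u) du
    return Dual(math.sin(x.val), math.cos(x.val) * x.dot)

# f(x, y) = x*y + sin(x); seeding dx = 1, dy = 0 gives df/dx at (x, y) = (2, 3)
x, y = Dual(2.0, 1.0), Dual(3.0, 0.0)
out = x * y + sin(x)
print(out.val, out.dot)   # out.dot = y + cos(x) = 3 + cos(2)
```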

In executing forward-mode autodifferentiation, the derivative trace is initially ‘seeded’ with an input derivative vector as the function is evaluated at an initial input. Whereas symbolic differentiation can be thought of as the mapping of f(x) to J_{f}(x), forward autodifferentiation maps an evaluation point c and a seed vector \vec{r} to the Jacobian-vector product J_{f}(c) \cdot \vec{r}. As such, for a function mapping from dimension n to dimension m, forward mode requires n trace evaluations to calculate the entire Jacobian, with the seed vector set each time to a different standard basis vector of the n input dimensions. Thus, the calculation of the full Jacobian can be completed in approximately n times the total number of operations in an evaluation of f(c).

Backward mode, by contrast, maps an evaluation point c and a seed vector \vec{y} to the vector-Jacobian product \vec{y} \cdot J_{f}(c). That is, one evaluation of the derivative trace in backward mode similarly computes a Jacobian-vector product, in this case extracting a row of the Jacobian (as opposed to a column as in forward mode). In this case, the calculation of the full Jacobian can be completed in approximately m times the total number of operations in an evaluation of f(c). A disadvantage of backward-mode autodifferentiation is that each intermediate variable along the function evaluation trace must be stored in memory for use during the backward execution of the derivative trace.
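
As a small illustration of this trade-off (using JAX, assuming it is available; the function f here is made up), jacfwd builds the Jacobian column by column with forward passes, while jacrev builds it row by row with backward passes.

```python
import jax.numpy as jnp
from jax import jacfwd, jacrev

def f(x):                       # f: R^3 -> R^2
    return jnp.array([x[0] * x[1], jnp.sin(x[2]) + x[0]])

x = jnp.array([1.0, 2.0, 0.5])
print(jacfwd(f)(x))             # n = 3 forward evaluations (one per input dimension)
print(jacrev(f)(x))             # m = 2 backward evaluations (one per output dimension)
```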

Stochastic variational learning in recurrent spiking networks

This week we discussed Stochastic variational learning in recurrent spiking networks by Danilo Rezende and Wolfram Gerstner.

Introduction

This paper brings together variational inference (VI) and biophysical networks of spiking neurons. The authors show:

  1. variational learning can be implemented by networks of spiking neurons to learn generative models of data,
  2. learning takes the form of a biologically plausible learning rule, where local synaptic learning signals are augmented with a global “novelty” signal.

One potential application the authors mention is to use this method to identify functional networks from experimental data. Through the course of the paper, some bedrock calculations relevant to computational neuroscience and variational inference are performed. These include computing the log likelihood of a population of spiking neurons with Poisson noise (including deriving the continuum limit from discrete time) and derivation of the score function estimator. I’ve filled in some of the gaps in these derivations in this blog post (plus some helpful references I consulted) for anyone seeing this stuff for the first time.

Neuron model and data log likelihood

The neuron model used in this paper is the spike response model which the authors note (and we discussed at length) is basically a GLM. The membrane potential of each unit in the network is described by the following equation:

\mathbf{u}(t) = \mathbf{w} \mathbf{\phi}(t) + \mathbf{\eta}(t)

where \mathbf{u} is an N-dimensional vector, \mathbf{\phi}(t) are exponentially filtered synaptic currents from the other neurons, \mathbf{w} is an N \times N matrix of connections, and \mathbf{\eta}(t) is an adaptation potential that mediates the voltage reset when a neuron spikes (this can be thought of as an autapse).

Figure 1 from the paper.

Spikes are generated by defining an instantaneous firing rate \rho(t) = \rho_0 \text{exp}[\frac{\mathbf{u} - \theta}{\Delta u}], where \theta, \Delta u and \rho_0 are physical constants. The history of all spikes from all neurons is denoted by \mathbf{X}. We can define the probability of the i^{th} neuron producing a spike in the interval [t,t+\Delta t], conditioned on the past activity of the entire network \mathbf{X}(0...t), as P_i(t_i^f \in [t,t+\Delta t] | \mathbf{X}(0...t)) \approx \rho_i(t) \Delta t, and the probability of not producing a spike as P_i(t_i^f \notin [t,t+\Delta t] | \mathbf{X}(0...t)) \approx 1 - \rho_i(t) \Delta t.

Aside: in future sections, the activity of some neurons will be observed (visible) and denoted by a super- or subscript \mathcal{V} and the activity of other neurons will be hidden and similarly denoted by \mathcal{H}. 

We can define the joint probability of the entire set of spikes as: 

P(X(0...T)) \approx \Pi_{i \in \mathcal{V} \cup \mathcal{H}} \Pi_{k_i^s} [\rho_i(t^f_{k_i^s})\Delta t] \Pi_{k_i^{ns}} [1 - \rho_i(t^f_{k_i^{ns}})\Delta t]

The authors re-express this in the continuum limit. A detailed explanation of how to do this can be found in Dayan and Abbott, Chapter 1, Appendix C. The key is to expand the log of the “no spike” term into a Taylor series, truncate it at the first term, and then exponentiate:

\text{exp} \hspace{1mm} \text{log} (\Pi_{k_i^{ns}} [1 - \rho_i(t^f_{k_i^{ns}})\Delta t]) = \text{exp} \sum_{k_i^{ns}} \text{log} [1 - \Delta t \rho_i(t_{k_i^{ns}}^f)] \approx \text{exp} \sum_{k_i^{ns}} - \Delta t \rho_i(t_{k_i^{ns}}^f).

As \Delta t \rightarrow 0, this approximation becomes exact and the sum across bins with no spikes becomes an integral across time, 

P(\mathbf{X}(0...T)) = \Pi_{i \in \mathcal{V} \cup \mathcal{H}} [\Pi_{t_i^f}\rho_i(t_i^f) \Delta t] \text{exp}(- \int_0^T dt \rho_i(t))

from which we can compute the log likelihood,

\text{log} P(\mathbf{X}(0...T)) = \sum_{i \in \mathcal{V} \cup \mathcal{H}} \int_0^T d\tau [\text{log} \rho_i(\tau) \mathbf{X}_i(\tau) - \rho_i(\tau)]

where we use \mathbf{X}_i(\tau) to identify the spike times.

They note that this equation is not a sum of independent terms, since \rho depends on the entire past activity of all the other neurons.
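
As a concrete (and heavily simplified) sketch of evaluating this log likelihood on binned data: here the rates are just given arrays rather than being computed recursively from the network’s past spiking, and the dt-dependent constant from binning is kept inside the log.

```python
import numpy as np

def spike_log_likelihood(rho, X, dt):
    """rho: N x T firing rates (Hz); X: N x T binary spike array; dt: bin width (s)."""
    rho = np.clip(rho, 1e-12, None)                  # avoid log(0)
    # sum of log(rho * dt) at spike times, minus the integral of rho over time
    return np.sum(X * np.log(rho * dt)) - np.sum(rho) * dt

# toy usage: simulate spikes from a known rate, then compare likelihoods
rng = np.random.default_rng(0)
N, T, dt = 5, 2000, 0.001
rho_true = 5 + 4 * np.sin(np.linspace(0, 8 * np.pi, T)) * np.ones((N, 1))
X = (rng.random((N, T)) < rho_true * dt).astype(float)   # Bernoulli approximation to Poisson
print(spike_log_likelihood(rho_true, X, dt))              # higher for the true rates
print(spike_log_likelihood(np.full((N, T), 5.0), X, dt))  # lower for a mismatched, flat rate
```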


Figure 2 from the paper shows the relevant network structures we will focus on. Panel C shows the intra- and inter-network connectivity between and among the hidden and visible units. Panel D illustrates the connectivity for the “inference” \mathcal{Q} network and the “generative” \mathcal{M} network. This structure is similar to the Helmholtz machine of Dayan (2000), and the learning algorithm will be very close to the wake-sleep algorithm used there.

Variational Inference with stochastic gradients

From here, they follow a pretty straightforward application of VI. I will pepper my post with terms we’ve used/seen in the past to make these connections as clear as possible. They construct a recurrent network of spiking neurons where the spiking data of a subset of the neurons (the visible neurons, or “the observed data”) can be explained by the activity of a disjoint subset of unobserved neurons (a “latent variable,” as in a VAE). Like standard VI, they want to approximate the posterior distribution over the spiking patterns of the hidden neurons (just as one would approximate the posterior over a latent variable in a VAE) by minimizing the KL divergence between the true posterior and an approximate posterior q:

KL(q;p) = \int \mathcal{D} \mathcal{X_H} q(\mathcal{X_H} | \mathcal{X_V}) \text{log} \frac{q(\mathcal{X_H} | \mathcal{X_V})}{p(\mathcal{X_H} | \mathcal{X_V})}

= \langle \text{log} q(\mathcal{X_H} | \mathcal{X_V}) - \text{log} p(\mathcal{X_H,X_V}) \rangle_{q(\mathcal{X_H} | \mathcal{X_V})} + \text{log} p(\mathcal{X_V})

= \langle \mathcal{L^Q} - \mathcal{L^M} \rangle_{q(\mathcal{X_H} | \mathcal{X_V})} + \text{log} p(\mathcal{X_V}).

The second term is the data log likelihood. The first term, \mathcal{F}, is the Helmholtz free energy, and as always in VI it represents an upper bound on the negative log likelihood. We can therefore change our optimization problem to minimize this function with respect to the parameters of q (the approximate posterior) and p (the generative model). We do this by computing the gradients of \mathcal{F} with respect to the \mathcal{Q} (inference) network and the \mathcal{M} (generative) network. First, the \mathcal{M} network, since it’s easier:

\dot{w_{ij}^\mathcal{M}} = -\mu^\mathcal{M} \nabla_{w_{ij}^\mathcal{M}} \mathcal{F} = -\mu^\mathcal{M} \nabla_{w_{ij}^\mathcal{M}} \langle \mathcal{L^Q} - \mathcal{L^M} \rangle_q = \mu^\mathcal{M} \langle \nabla_{w_{ij}^\mathcal{M}} \mathcal{L^M} \rangle_q \approx \mu^\mathcal{M} \nabla_{w_{ij}^\mathcal{M}} \hat{\mathcal{L}}^\mathcal{M}

where \hat{\mathcal{L}}^\mathcal{M} is a point estimate of the complete data log likelihood of the generative model. They will compute this with a Monte Carlo estimate. The gradient of the complete data log likelihood with respect to the connections is:

\nabla_{w_{ij}^\mathcal{M}} \hat{\mathcal{L}}^\mathcal{M} = \nabla_{w_{ij}^\mathcal{M}} \text{log} p(\mathcal{X_H, X_V}) = \sum_{k \in \mathcal{V} \cup \mathcal{H}} \int_0^T d\tau \frac{\partial \text{log} \rho_k(\tau)}{\partial w_{ij}^\mathcal{M}} [\mathbf{X}_k(\tau) - \rho_k(\tau)].

Here they used the handy identity \frac{\partial f(x)}{\partial x} = \frac{\partial[\text{log} f(x)]}{\partial x} f(x).

The derivative of the firing rate function can be computed with the chain rule, \frac{\partial [\text{log} \rho_k(\tau)]}{\partial w_{ij}^\mathcal{M}} = \delta_{ki} \frac{g^\prime(u_k(\tau))}{g(u_k(\tau))} \phi_j(\tau).

This equation for updating the weights using gradient ascent is purely local, taking the form of a product between a presynaptic component, \phi_j(\tau), and a postsynaptic term \frac{g^\prime(u_k(\tau))}{g(u_k(\tau))} [\mathbf{X}_k(\tau) - \rho_k(\tau)].

They also compute the gradient of \mathcal{F} with respect to the \mathcal{Q} network, -\mu^\mathcal{Q} \nabla_{w_{ij}^\mathcal{Q}} \mathcal{F}. To do this, I revisited the 2014 Kingma and Welling paper, where I think they were particularly clear about how to compute gradients of expectations (i.e. the score function estimator). In section 2.2 they note that:

\nabla_\phi \langle f(z) \rangle_{q_\phi (z)}= \langle f(z) \nabla_\phi \text{log} q_\phi (z) \rangle .

A cute proof of this can be found here. This comes in handy when computing the gradient of \mathcal{F} with respect to the connections of the \mathcal{Q} network:

\nabla_{w_{ij}^\mathcal{Q}} \mathcal{F} = \langle \mathcal{F} \nabla_{w_{ij}^\mathcal{Q}} \text{log} q(\mathcal{X_H} | \mathcal{X_V}) \rangle = \langle \mathcal{F} \nabla_{w_{ij}^\mathcal{Q}} \mathcal{L^Q} \rangle \approx \hat{\mathcal{F}} \nabla_{w_{ij}^\mathcal{Q}} \hat{\mathcal{L}}^\mathcal{Q}.

Here again we compute Monte Carlo estimators of \mathcal{F} and \mathcal{L^Q}. \nabla_{w_{ij}^\mathcal{Q}} \hat{\mathcal{L}}^\mathcal{Q} takes the exact same form as for the \mathcal{M} network, but the neat thing is that \nabla_{w_{ij}^\mathcal{Q}} \mathcal{F} contains a term in front of the gradient of the log likelihood estimate: \hat{\mathcal{F}}. This is a global signal (as opposed to the local signals that are present in \hat{\mathcal{L}}^\mathcal{Q}) that they interpret as a novelty or surprise signal.

Reducing gradient estimation variance

The authors note that the stochastic gradient they introduced has been used extensively in reinforcement learning and that its variance is prohibitively high. To deal with this (presumably following the approach developed in RL, or vice versa), they adopt a simple baseline-removal approach. They subtract the mean \bar{\mathcal{F}} of the free energy estimate, calculated as a moving average across several previous batches of length T, from the current value \hat{\mathcal{F}}(T). That is, they replace the free energy in the gradient for the \mathcal{Q} network with a free energy error signal, \hat{\mathcal{F}}(T) - \bar{\mathcal{F}}. Below, the log likelihood of the generated data when this procedure is used is plotted against that of a naively trained network, showing that this procedure works better than the naive rule.
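
Here is a generic, self-contained sketch of why subtracting a baseline helps (a stand-in problem, not the paper’s network): a score-function estimate of \nabla_\mu \langle f(z) \rangle_{q_\mu(z)} for a Gaussian q, with and without a baseline subtracted from f. The baseline leaves the estimate unbiased but shrinks its variance.

```python
import numpy as np

rng = np.random.default_rng(0)
mu = 0.5
f = lambda z: (z - 2.0) ** 2            # stand-in for the per-sample free energy
true_grad = 2 * (mu - 2.0)              # analytic d/dmu of E_{N(mu,1)}[f(z)]

def score_function_grads(n, baseline):
    z = mu + rng.normal(size=n)          # samples from q = N(mu, 1)
    score = z - mu                       # d/dmu log N(z; mu, 1)
    return (f(z) - baseline) * score     # per-sample gradient estimates

naive = score_function_grads(10_000, baseline=0.0)
b = f(mu + rng.normal(size=10_000)).mean()        # running-mean-style baseline
centered = score_function_grads(10_000, baseline=b)

print("analytic gradient:", true_grad)
print("naive     mean / std:", naive.mean(), naive.std())
print("baselined mean / std:", centered.mean(), centered.std())
```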

Figure 5a from the paper.

Numerical results

Details of their numerical simulations:

  • Training data is binary arrays of spike data.
  • Training data comes in batches of 200 ms with 500 batches sequentially shown to the network (100 s of data).
  • During learning, visible neurons are forced to spike like the training data.
  • Log likelihood of test data was estimated with importance sampling. Given a generative model with density p(x_v,x_h), importance sampling allows us to estimate the density of p(x_v):

p(x_v) = \langle p(x_v | x_h) \rangle_{p(x_h)} = \langle p(x_v | x_h) \frac{p(x_h)}{q(x_h|x_v)} \rangle_{q(x_h|x_v)}

= \langle \text{exp}[\text{log} p(x_v,x_h) - \text{log} q(x_h|x_v)] \rangle_q = \langle \text{exp}[-\hat{\mathcal{F}}(x_v,x_h)]\rangle_{q(x_h|x_v)}

Using this equation, they estimate the log likelihood of the observed spike trains by sampling several times from the \mathcal{Q} network and computing the average of the free energy. They use 500 samples of duration 100 from the \mathcal{Q} network to compute this estimate.
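
A compact sketch of that estimator (a generic log-mean-exp over invented free-energy values, not the paper’s numbers): average \text{exp}(-\hat{\mathcal{F}}) over samples from the \mathcal{Q} network, working in log space to avoid underflow.

```python
import numpy as np

def estimate_log_px(neg_free_energies):
    """neg_free_energies: array of -F(x_v, x_h^(s)) for samples x_h^(s) ~ q(x_h | x_v)."""
    m = np.max(neg_free_energies)                     # log-mean-exp for numerical stability
    return m + np.log(np.mean(np.exp(neg_free_energies - m)))

# toy usage with invented free-energy samples
F_samples = np.array([95.0, 102.0, 98.5, 110.2, 97.1])
print("estimated log p(x_v):", estimate_log_px(-F_samples))
```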

Here is an example of training with this method with 50 hidden units using the “stairs” dataset. Panel C shows that during the “sleep phase” (running in generative mode) the network forms a latent representation of the stairs in the hidden layers. When the network runs in “inference mode” (wake, in the wake-sleep parlance), using the \mathcal{Q} network synapses, the model is capable of performing inference on the causes of the incoming data (the visible neurons are driven by the data).

Figure 3 from the paper.

Role of the novelty signal

To examine the role of the novelty signal, they train a network to perform a maze task. Each maze contains 16 rooms, where each room is a 28×28-pixel greyscale image of an MNIST digit. Each room is only accessible from a neighboring room. Pixel values were converted into firing rates from 0.01 to 9 Hz. In the test maze (or control maze), some of the rooms of the training maze were changed. The network had 28×28 visible units and 30 hidden units; these were recurrent binary units. Data were generated from random trajectories of 100 time steps in the target maze. Each learning epoch was 500 presentations of the data batches.

Below (bottom left), they plotted the slow moving average of the free energy, \bar{\mathcal{F}}, as a function of the amount of observed data for the target maze (blue) and for the same model when it was “teleported” to the control maze every 500 s. At the beginning of learning, the free energies are the same, so the model cannot distinguish between the two mazes. As learning proceeds, the model identifies the test maze as unfamiliar (higher free energy).

The bottom right shows the free energy error signal for the sample trajectory in A. It fluctuates near zero for the learned maze but deviates strongly for the test maze. We can see the free energy error signal jump up at room (3,3), meaning that the model identifies this room as different from the target maze.

To conclude, the authors speculate that a neural correlate of this free energy error signal should look like an activity burst when an animal traverses unexpected situations. They also expect to see a substantial increase in the variance of the changes in synaptic weights when moving from a learned to an unfamiliar maze, due to the change in the baseline level of surprise.

Figure 6 from the paper.