Deep Neural Networks as Gaussian Processes

In lab meeting this week, we read Deep Neural Networks as Gaussian Processes by Lee, Bahri, Novak, Schoenholz, Pennington and Sohl-Dickstein, which appeared at ICLR 2018. The paper extends a result derived by Neal (1994): the authors show that there is a correspondence between deep neural networks with infinitely wide layers and Gaussian processes. After developing an efficient method to evaluate the associated kernel, the authors compared the performance of their Gaussian process model with finite width neural networks (trained with SGD) on image classification tasks (MNIST and CIFAR-10). They found that the performance of the finite width networks approached that of the Gaussian process as the width increased, and that the uncertainty captured by the Gaussian process was correlated with the mean squared prediction error. Overall, this paper hints at new connections between Gaussian processes and neural networks, and it remains to be seen whether future work can harness this connection to extend Gaussian process inference to larger datasets, or to endow neural networks with the ability to capture uncertainty. We look forward to following progress in this field.

Single Layer Neural Networks as Gaussian Processes – Neal 1994

Let us consider a neural network with a single hidden layer. We can write the ith output of the network, z_{i}^{1}, as

z_{i}^{1}(\vec{x}) = b_{i}^{1} + \sum_{j=1}^{N_{1}}W_{ij}^{1}x_{j}^{1}(\vec{x})

where x_{j}^{1}(\vec{x}) = \phi(b_{j}^{0} + \sum_{k=1}^{d_{in}}W_{jk}^{0}x_{k}) is the post-activation of the jth neuron in the hidden layer, \phi is a pointwise nonlinearity, and x_{k} is the kth input to the network.

If we now assume that the weights in each layer are sampled i.i.d. from zero-mean Gaussian distributions, W_{ij}^{1} \sim \mathcal{N}(0, \dfrac{\sigma_{w}^{2}}{N_{1}}) and W_{jk}^{0} \sim \mathcal{N}(0, \dfrac{\sigma_{w}^{2}}{d_{in}}), and that the biases are similarly sampled, b_{i}^{1} \sim \mathcal{N}(0, \sigma_{b}^{2}) and b_{j}^{0} \sim \mathcal{N}(0, \sigma_{b}^{2}), then it is possible to show that, in the limit N_{1} \rightarrow \infty, z_{i}^{1} \sim \mathcal{GP}(0, K_{\phi}) for a kernel K_{\phi} that depends on the nonlinearity. This follows from the Central Limit Theorem: for a fixed input \vec{x}, the post-activations x_{j}^{1}(\vec{x}) are i.i.d. across hidden units j (each depends only on its own weights and bias), so the N_{1} terms W_{ij}^{1}x_{j}^{1}(\vec{x}) in the sum are i.i.d. with zero mean and variance \dfrac{\sigma_{w}^{2}}{N_{1}}V_{\phi}(x^{1}(\vec{x})). Hence z_{i}^{1}(\vec{x}) \rightarrow \mathcal{N}(0, \sigma_{b}^{2}+\sigma_{w}^{2}V_{\phi}(x^{1}(\vec{x}))) as N_{1} \rightarrow \infty, where V_{\phi}(x^{1}(\vec{x})) \equiv \mathbb{E}[(x^{1}_{j}(\vec{x}))^{2}] (which is the same for all j).
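The CLT argument can be checked numerically. The sketch below is a hypothetical illustration (not code from Neal or Lee et al.): it draws many independent single-hidden-layer networks from the prior above, evaluates z_{1}^{1} at one fixed input, and compares the empirical variance with \sigma_{b}^{2}+\sigma_{w}^{2}V_{\phi}. It assumes a ReLU nonlinearity because V_{\phi} then has the simple form q/2, where q = \sigma_{b}^{2} + \sigma_{w}^{2}\lVert\vec{x}\rVert^{2}/d_{in} is the variance of a hidden pre-activation; the input, the width and the values of sigma_w and sigma_b are arbitrary choices.

```python
import math
import random

def relu(u):
    return max(0.0, u)

def sample_output(x, n_hidden, sigma_w, sigma_b, rng):
    """Draw one network from the prior and return its output z_1^1 at input x."""
    d_in = len(x)
    w0_std = sigma_w / math.sqrt(d_in)        # W^0 ~ N(0, sigma_w^2 / d_in)
    w1_std = sigma_w / math.sqrt(n_hidden)    # W^1 ~ N(0, sigma_w^2 / N_1)
    z = rng.gauss(0.0, sigma_b)               # output bias b^1
    for _ in range(n_hidden):
        # Fresh W^0 row and b^0 for every hidden unit, so the units are i.i.d.
        pre = rng.gauss(0.0, sigma_b) + sum(rng.gauss(0.0, w0_std) * xk for xk in x)
        z += rng.gauss(0.0, w1_std) * relu(pre)
    return z

def empirical_variance(samples):
    mean = sum(samples) / len(samples)
    return sum((s - mean) ** 2 for s in samples) / (len(samples) - 1)

rng = random.Random(0)
x = [1.0, -0.5, 2.0]
sigma_w, sigma_b = 1.5, 0.1
samples = [sample_output(x, 100, sigma_w, sigma_b, rng) for _ in range(2000)]

# For ReLU, V_phi = E[relu(u)^2] = q / 2, where q is the variance of a hidden pre-activation.
q = sigma_b**2 + sigma_w**2 * sum(xk * xk for xk in x) / len(x)
predicted = sigma_b**2 + sigma_w**2 * q / 2
within_one_sd = sum(abs(s) < math.sqrt(predicted) for s in samples) / len(samples)
print(empirical_variance(samples), predicted, within_one_sd)
```

The second moment actually matches at any width; what the limit N_{1} \rightarrow \infty adds is that the distribution becomes Gaussian, which the last statistic probes: roughly 68% of the draws should fall within one predicted standard deviation of zero.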

We can now apply a similar argument to examine the distribution of the ith output of the network for a collection of inputs; that is, we can examine the joint distribution of \{z_{i}^{1}(\vec{x}^{\alpha = 1}), z_{i}^{1}(\vec{x}^{\alpha = 2}), ..., z_{i}^{1}(\vec{x}^{\alpha = k})\}. Application of the multidimensional Central Limit Theorem tells us that, in the limit N_{1} \rightarrow \infty,

\{z_{i}^{1}(\vec{x}^{\alpha = 1}), z_{i}^{1}(\vec{x}^{\alpha = 2}), ...,z_{i}^{1}(\vec{x}^{\alpha = k})\} \sim \mathcal{N}(0, K_{\phi}),

where K_{\phi} \in \mathbb{R}^{k \times k} has entries K_{\phi}(\vec{x}^{\alpha}, \vec{x}^{\beta}), with K_{\phi}(\vec{x}, \vec{x}') \equiv \sigma_{b}^{2} + \sigma_{w}^{2}C_{\phi}(\vec{x}, \vec{x}') and C_{\phi}(\vec{x}, \vec{x}') \equiv \mathbb{E}[x_{j}^{1}(\vec{x})x_{j}^{1}(\vec{x}')].
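Since C_{\phi} is an expectation over a single hidden unit's random weights and bias, the k \times k matrix K_{\phi} can be estimated by Monte Carlo. The sketch below assumes tanh as the nonlinearity and uses a hypothetical helper mc_kernel_matrix; it samples one hidden unit at a time and evaluates it on every input, so all entries share the same weight draws, which keeps the estimate symmetric and positive semi-definite.

```python
import math
import random

def mc_kernel_matrix(xs, phi, sigma_w, sigma_b, n_units=20000, seed=0):
    """Monte Carlo estimate of K_phi(x^a, x^b) = sigma_b^2 + sigma_w^2 * C_phi(x^a, x^b).

    Each draw is one hidden unit: a row of W^0 and a bias b^0 sampled from the prior,
    evaluated on every input in xs.
    """
    rng = random.Random(seed)
    d_in = len(xs[0])
    w_std = sigma_w / math.sqrt(d_in)
    k = len(xs)
    acc = [[0.0] * k for _ in range(k)]
    for _ in range(n_units):
        w = [rng.gauss(0.0, w_std) for _ in range(d_in)]
        b = rng.gauss(0.0, sigma_b)
        post = [phi(b + sum(wi * xi for wi, xi in zip(w, x))) for x in xs]
        for a in range(k):
            for c in range(k):
                acc[a][c] += post[a] * post[c]
    return [[sigma_b**2 + sigma_w**2 * acc[a][c] / n_units for c in range(k)]
            for a in range(k)]

xs = [[1.0, 0.0], [0.6, 0.8], [-1.0, 0.0]]
K = mc_kernel_matrix(xs, math.tanh, sigma_w=1.2, sigma_b=0.3)
for row in K:
    print([round(v, 3) for v in row])
```

With \phi(u) = u (pass `lambda u: u`), C_{\phi}(\vec{x}, \vec{x}') reduces to \sigma_{b}^{2} + \sigma_{w}^{2}\vec{x}\cdot\vec{x}'/d_{in}, which gives a quick sanity check on the estimator.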

Since we get a joint distribution of this form for any finite collection of inputs to the network, we can write that z_{i}^{1} \sim \mathcal{GP}(0, K_{\phi}), as this is the very definition of a Gaussian process.

This result was shown by Neal (1994), and the precise form of the kernel K_{\phi} was derived by Williams (1997) for the error function (a sigmoidal activation function) and for Gaussian nonlinearities.
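For the error function, the expectation inside C_{\phi} has a closed form. The sketch below writes it in terms of the covariance of the hidden pre-activations (u, v), whose entries are \sigma_{b}^{2} + \sigma_{w}^{2}\vec{x}\cdot\vec{x}'/d_{in}; this is our reading of Williams' (1997) result rather than a transcription of the paper, so the sketch also compares the formula against a Monte Carlo estimate. Both function names are hypothetical, and the kernel entry is then K_{\phi}(\vec{x}, \vec{x}') = \sigma_{b}^{2} + \sigma_{w}^{2} times the returned value.

```python
import math
import random

def erf_kernel_expectation(k_xx, k_xy, k_yy):
    """E[erf(u) erf(v)] for zero-mean Gaussian (u, v) with covariance [[k_xx, k_xy], [k_xy, k_yy]]."""
    ratio = 2.0 * k_xy / math.sqrt((1.0 + 2.0 * k_xx) * (1.0 + 2.0 * k_yy))
    return (2.0 / math.pi) * math.asin(ratio)

def mc_erf_expectation(k_xx, k_xy, k_yy, n=200000, seed=1):
    """Monte Carlo check: draw (u, v) through a 2x2 Cholesky factor and average erf(u) erf(v)."""
    rng = random.Random(seed)
    l11 = math.sqrt(k_xx)
    l21 = k_xy / l11
    l22 = math.sqrt(k_yy - l21 * l21)
    total = 0.0
    for _ in range(n):
        e1, e2 = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
        u = l11 * e1
        v = l21 * e1 + l22 * e2
        total += math.erf(u) * math.erf(v)
    return total / n

print(erf_kernel_expectation(1.0, 0.5, 2.0), mc_erf_expectation(1.0, 0.5, 2.0))
```

Two limits are easy to verify by hand: for very small pre-activation variances erf(u) \approx 2u/\sqrt{\pi}, so the expectation tends to (4/\pi)k_{xy}; for identical inputs with very large variance it tends to 1, since erf saturates at \pm 1.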

Deep Neural Networks as Gaussian Processes

Lee et al. use arguments similar to those in Neal (1994) to show that the ith output of the lth layer of a network with a Gaussian prior over all of its weights and biases is a sample from a Gaussian process in the limit of infinitely wide hidden layers (N_{l} \rightarrow \infty). They use an inductive argument of the following form: suppose that z_{j}^{l-1} \sim \mathcal{GP}(0, K_{\phi}^{l-1}), independently for each j (the jth output of the (l-1)th layer of the network is sampled from a Gaussian process). Then

z_{i}^{l} \equiv b_{i}^{l} + \sum_{j=1}^{N_{l}}W_{ij}^{l}x_{j}^{l}(\vec{x})

is Gaussian distributed as N_{l} \rightarrow \infty (here x_{j}^{l}(\vec{x}) \equiv \phi(z_{j}^{l-1}(\vec{x}))), and any finite collection \{z_{i}^{l}(\vec{x}^{\alpha=1}), ..., z_{i}^{l}(\vec{x}^{\alpha=k})\} has a joint multivariate Gaussian distribution; that is, z_{i}^{l} \sim \mathcal{GP}(0, K_{\phi}^{l}), where

K_{\phi}^{l}(\vec{x}, \vec{x'}) \equiv \mathbb{E}[z_{i}^{l}(\vec{x})z_{i}^{l}(\vec{x'})] = \sigma_{b}^{2} + \sigma_{w}^{2} \mathbb{E}_{z_{i}^{l-1}\sim \mathcal{GP}(0, K_{\phi}^{l-1})}[\phi(z_{i}^{l-1}(\vec{x})) \phi(z_{i}^{l-1}(\vec{x'}))].
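Because z_{i}^{l-1}(\vec{x}) and z_{i}^{l-1}(\vec{x'}) are jointly Gaussian under the induction hypothesis, the expectation in this recurrence is a two-dimensional Gaussian integral that involves only three entries of the previous layer's kernel. Writing it out explicitly (a restatement of the recurrence above, not an additional assumption):

K_{\phi}^{l}(\vec{x}, \vec{x'}) = \sigma_{b}^{2} + \sigma_{w}^{2} \int\int \phi(u)\phi(v)\, \mathcal{N}\left(\begin{bmatrix} u \\ v \end{bmatrix}; \vec{0}, \begin{bmatrix} K_{\phi}^{l-1}(\vec{x}, \vec{x}) & K_{\phi}^{l-1}(\vec{x}, \vec{x'}) \\ K_{\phi}^{l-1}(\vec{x'}, \vec{x}) & K_{\phi}^{l-1}(\vec{x'}, \vec{x'}) \end{bmatrix}\right) du\, dv

This is why the kernel can be propagated layer by layer using only the previous kernel, without reference to any particular sample of weights.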

Starting from the base kernel K^{0}(\vec{x}, \vec{x'}) \equiv \sigma_{b}^{2} + \sigma_{w}^{2}(\dfrac{\vec{x}\cdot \vec{x'}}{d_{in}}), which is the covariance of the first-layer pre-activations, these recurrence relations can be solved analytically for the ReLU nonlinearity (as demonstrated by Cho and Saul (2009)) and numerically for other nonlinearities; Lee et al. give a method for computing the numerical solution efficiently.
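A minimal sketch of the ReLU case, assuming the closed form for \mathbb{E}[\mathrm{relu}(u)\mathrm{relu}(v)] that follows from the arc-cosine kernel of Cho and Saul (2009): with \theta = \arccos(K(\vec{x},\vec{x'})/\sqrt{K(\vec{x},\vec{x})K(\vec{x'},\vec{x'})}), the expectation is \sqrt{K(\vec{x},\vec{x})K(\vec{x'},\vec{x'})}(\sin\theta + (\pi - \theta)\cos\theta)/(2\pi). The function name relu_nngp_kernel and its arguments are hypothetical, and `depth` counts how many ReLU hidden layers are applied on top of the base kernel.

```python
import math

def relu_nngp_kernel(x1, x2, depth, sigma_w, sigma_b):
    """Return the 2x2 NNGP kernel for inputs x1, x2 after `depth` ReLU layers."""
    d_in = len(x1)
    dot = lambda a, b: sum(ai * bi for ai, bi in zip(a, b))
    # Base kernel K^0: covariance of the first-layer pre-activations.
    k11 = sigma_b**2 + sigma_w**2 * dot(x1, x1) / d_in
    k22 = sigma_b**2 + sigma_w**2 * dot(x2, x2) / d_in
    k12 = sigma_b**2 + sigma_w**2 * dot(x1, x2) / d_in
    for _ in range(depth):
        norm = math.sqrt(k11 * k22)
        # Clamp so rounding cannot push the cosine slightly outside [-1, 1].
        cos_theta = max(-1.0, min(1.0, k12 / norm))
        theta = math.acos(cos_theta)
        # Off-diagonal: sqrt(k11 k22) / (2 pi) * (sin theta + (pi - theta) cos theta).
        k12 = sigma_b**2 + sigma_w**2 * norm / (2 * math.pi) * (
            math.sin(theta) + (math.pi - theta) * cos_theta)
        # Diagonal: theta = 0, so E[relu(u)^2] = k / 2.
        k11 = sigma_b**2 + sigma_w**2 * k11 / 2
        k22 = sigma_b**2 + sigma_w**2 * k22 / 2
    return [[k11, k12], [k12, k22]]

print(relu_nngp_kernel([1.0, 0.0], [0.6, 0.8], depth=3, sigma_w=1.4, sigma_b=0.1))
```

All three entries are updated from the previous layer's values, which is why `norm` is computed before `k12` is overwritten. For identical inputs the off-diagonal entry equals the diagonal at every depth, and on the diagonal the recursion reduces to K^{l}(\vec{x}, \vec{x}) = \sigma_{b}^{2} + \sigma_{w}^{2}K^{l-1}(\vec{x}, \vec{x})/2.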

Comparison: Gaussian Processes and Finite Width Neural Networks

Lee et al. went on to compare the predictions of Gaussian process regression, using the kernels obtained by solving the above recurrence relations for the ReLU and tanh nonlinearities, with the predictions of finite width neural networks trained with SGD. The task was classification of MNIST digits and CIFAR-10 images, reformulated as a regression problem. Overall, they found that their “NNGP” often outperformed finite width neural networks with the same number of layers on this task, and that the performance of the finite width networks often approached that of the NNGP as the width of these networks was increased:

Figure 1 of Lee et al. The authors compare the performance of their NNGP to finite width neural networks of the same depth and find that, for many tasks, the NNGP outperforms the finite width networks and that the performance of the finite width networks approaches that of the NNGP as the width is increased.
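To make the prediction step concrete, here is a minimal pure-Python sketch of Gaussian process regression with a zero-mean prior: the posterior mean is K_{*D}(K_{DD} + \sigma_{\epsilon}^{2}I)^{-1}\vec{t} and the posterior variance is K_{**} - K_{*D}(K_{DD} + \sigma_{\epsilon}^{2}I)^{-1}K_{D*}, the latter being the kind of uncertainty estimate the post refers to. The kernel is the one-hidden-layer ReLU NNGP kernel; the toy 1-D data, the noise level and the sigma_w, sigma_b defaults are arbitrary illustrative choices, not the settings used by Lee et al., and all function names are hypothetical. For classification, one would regress on class-indicator targets and predict the class with the largest posterior mean.

```python
import math

def relu_nngp_1layer(x, y, sigma_w=1.4, sigma_b=1.0):
    """One-hidden-layer ReLU NNGP kernel (hyperparameter defaults are arbitrary)."""
    d = len(x)
    kxx = sigma_b**2 + sigma_w**2 * sum(a * a for a in x) / d
    kyy = sigma_b**2 + sigma_w**2 * sum(b * b for b in y) / d
    kxy = sigma_b**2 + sigma_w**2 * sum(a * b for a, b in zip(x, y)) / d
    norm = math.sqrt(kxx * kyy)
    c = max(-1.0, min(1.0, kxy / norm))
    t = math.acos(c)
    return sigma_b**2 + sigma_w**2 * norm / (2 * math.pi) * (math.sin(t) + (math.pi - t) * c)

def cholesky(A):
    """Lower-triangular L with A = L L^T (A must be symmetric positive definite)."""
    n = len(A)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = A[i][j] - sum(L[i][k] * L[j][k] for k in range(j))
            L[i][j] = math.sqrt(s) if i == j else s / L[j][j]
    return L

def cho_solve(L, b):
    """Solve (L L^T) x = b by forward substitution, then back substitution."""
    n = len(L)
    y = [0.0] * n
    for i in range(n):
        y[i] = (b[i] - sum(L[i][k] * y[k] for k in range(i))) / L[i][i]
    x = [0.0] * n
    for i in reversed(range(n)):
        x[i] = (y[i] - sum(L[k][i] * x[k] for k in range(i + 1, n))) / L[i][i]
    return x

def gp_predict(train_x, train_y, test_x, kernel, noise):
    """Posterior mean and (noise-free) posterior variance at each test input."""
    K = [[kernel(a, b) + (noise if i == j else 0.0) for j, b in enumerate(train_x)]
         for i, a in enumerate(train_x)]
    L = cholesky(K)
    alpha = cho_solve(L, train_y)
    means, variances = [], []
    for x in test_x:
        k_star = [kernel(x, a) for a in train_x]
        means.append(sum(k * a for k, a in zip(k_star, alpha)))
        v = cho_solve(L, k_star)
        variances.append(kernel(x, x) - sum(k * vi for k, vi in zip(k_star, v)))
    return means, variances

train_x = [[-1.0], [-0.5], [0.0], [0.5], [1.0]]
train_y = [math.sin(2.0 * p[0]) for p in train_x]   # toy targets for illustration
means, variances = gp_predict(train_x, train_y, [[0.0], [0.25], [5.0]],
                              relu_nngp_1layer, noise=1e-3)
print(means, variances)
```

The Cholesky factor is reused for both the mean and every variance, which avoids forming an explicit matrix inverse; the posterior variance is small at a training input and grows for the far-away test point at 5.0.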
