In lab meeting this week, we read Deep Neural Networks as Gaussian Processes by Lee, Bahri, Novak, Schoenholz, Pennington and Sohl-Dickstein, which appeared at ICLR 2018. The paper extends a result derived by Neal (1994), showing that there is a correspondence between deep neural networks and Gaussian processes. After developing an efficient method to evaluate the associated kernel, the authors compared the performance of their Gaussian process model with that of finite width neural networks (trained with SGD) on image classification tasks (MNIST, CIFAR-10). They found that the performance of the finite width networks approached that of the Gaussian process as the width increased, and that the uncertainty captured by the Gaussian process was correlated with the mean squared prediction error. Overall, this paper hints at new connections between Gaussian processes and neural networks; it remains to be seen whether future work can harness this connection to extend Gaussian process inference to larger datasets, or to endow neural networks with the ability to capture uncertainty. We look forward to following progress in this field.
Single Layer Neural Networks as Gaussian Processes – Neal 1994
Let us consider a neural network with a single hidden layer. We can write the $i$th output of the network, $z_i(x)$, as

$$ z_i(x) = b_i^1 + \sum_{j=1}^{N_1} W_{ij}^1\, x_j^1(x), \qquad x_j^1(x) = \phi\Big( b_j^0 + \sum_{k=1}^{d_{\mathrm{in}}} W_{jk}^0\, x_k \Big), $$

where $x_j^1(x)$ is the post-activation of the $j$th neuron in the hidden layer; $\phi$ is some nonlinearity, and $x_k$ is the $k$th input to the network.
If we now assume that the weights for each layer in the network are sampled i.i.d. from a Gaussian distribution: $W_{ij}^1 \sim \mathcal{N}(0, \sigma_w^2 / N_1)$, $W_{jk}^0 \sim \mathcal{N}(0, \sigma_w^2 / d_{\mathrm{in}})$; and that the biases are similarly sampled: $b_i^1 \sim \mathcal{N}(0, \sigma_b^2)$ and $b_j^0 \sim \mathcal{N}(0, \sigma_b^2)$; then it is possible to show that, in the limit of $N_1 \to \infty$, $z_i \sim \mathcal{GP}(0, K)$, for a kernel $K$ which depends on the nonlinearity. In particular, this follows from application of the Central Limit Theorem: for a fixed input to the network $x$,

$$ z_i(x) \;\to\; \mathcal{N}\big(0,\; \sigma_b^2 + \sigma_w^2\, C(x, x)\big) \quad \text{as} \quad N_1 \to \infty, $$

where $C(x, x) = \mathbb{E}\big[ x_j^1(x)^2 \big]$ (which is the same for all $j$).
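To make this concrete, here is a minimal numerical sketch (our own illustration, not from the paper): it samples many random single-hidden-layer networks with the priors above and checks that the output at a fixed input is approximately Gaussian with variance $\sigma_b^2 + \sigma_w^2 C(x, x)$. The variable names (`sigma_w`, `width`, `n_networks`, ...) and the hyperparameter values are our own choices.

```python
import numpy as np

# Sketch (not the authors' code): sample many random single-hidden-layer
# networks with the Gaussian priors above, and compare the empirical variance
# of z(x) at a fixed input x to sigma_b^2 + sigma_w^2 * C(x, x).
rng = np.random.default_rng(0)
sigma_w, sigma_b = 1.5, 0.1              # prior scales (arbitrary choices)
d_in, width, n_networks = 5, 1000, 2000
phi = np.tanh                            # the nonlinearity

x = rng.normal(size=d_in)                # a fixed input

# Sample weights and biases for all networks at once.
W0 = rng.normal(0, sigma_w / np.sqrt(d_in), size=(n_networks, width, d_in))
b0 = rng.normal(0, sigma_b, size=(n_networks, width))
W1 = rng.normal(0, sigma_w / np.sqrt(width), size=(n_networks, width))
b1 = rng.normal(0, sigma_b, size=(n_networks,))

post_act = phi(W0 @ x + b0)                        # hidden post-activations, (n_networks, width)
z = np.einsum('nj,nj->n', W1, post_act) + b1       # network outputs z(x)

C_xx = np.mean(post_act**2)                        # Monte Carlo estimate of C(x, x)
print("empirical Var[z(x)]:        ", z.var())
print("sigma_b^2 + sigma_w^2 C(x,x):", sigma_b**2 + sigma_w**2 * C_xx)
```

A histogram of `z` should also look Gaussian, and the agreement improves as `width` is increased.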
We can now apply a similar argument to the above in order to examine the distribution of the $i$th output of the network for a collection of inputs $x^{(1)}, \dots, x^{(n)}$: that is, we can examine the joint distribution of $z_i(x^{(1)}), \dots, z_i(x^{(n)})$. Application of the Multidimensional Central Limit Theorem tells us that, in the limit of $N_1 \to \infty$,

$$ \big( z_i(x^{(1)}), \dots, z_i(x^{(n)}) \big) \;\sim\; \mathcal{N}(0, K), $$

where $K_{pq} = \sigma_b^2 + \sigma_w^2\, C(x^{(p)}, x^{(q)})$ and $C(x, x') = \mathbb{E}\big[ x_j^1(x)\, x_j^1(x') \big]$.

Since we get a joint Gaussian distribution of this form for any finite collection of inputs to the network, we can write that $z_i \sim \mathcal{GP}(0, K)$, as this is the very definition of a Gaussian process.
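Extending the sketch above to a second input (again our own check, reusing the hypothetical variables defined there), the empirical covariance between $z(x)$ and $z(x')$ across random networks should match $\sigma_b^2 + \sigma_w^2 C(x, x')$:

```python
# Continuation of the sketch above: check an off-diagonal kernel entry.
x_prime = rng.normal(size=d_in)                         # a second fixed input
post_act_p = phi(W0 @ x_prime + b0)                     # post-activations at x'
z_prime = np.einsum('nj,nj->n', W1, post_act_p) + b1    # outputs z(x')

C_xxp = np.mean(post_act * post_act_p)                  # Monte Carlo estimate of C(x, x')
print("empirical Cov[z(x), z(x')]:   ", np.cov(z, z_prime)[0, 1])
print("sigma_b^2 + sigma_w^2 C(x,x'):", sigma_b**2 + sigma_w**2 * C_xxp)
```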
This result was shown in Neal (1994); and the precise form of the kernel was derived for the error function (a form of sigmoidal activation function) and Gaussian nonlinearities in Williams (1997).
Deep Neural Networks as Gaussian Processes
Lee et al. use similar arguments to those presented in Neal (1994) to show that the $i$th output of the $l$th layer of a network with a Gaussian prior over all of the weights and biases is a sample from a Gaussian process in the limit of $N_l \to \infty$ for each hidden layer. They use an inductive argument of the form: suppose that $z_j^{l-1} \sim \mathcal{GP}(0, K^{l-1})$ (the $j$th output of the $(l-1)$th layer of the network is sampled from a Gaussian process). Then $z_i^l(x)$ is Gaussian distributed as $N_{l-1} \to \infty$, and any finite collection of $z_i^l(x^{(1)}), \dots, z_i^l(x^{(n)})$ will have a joint multivariate Gaussian distribution, i.e., $z_i^l \sim \mathcal{GP}(0, K^l)$, where

$$ K^l(x, x') = \sigma_b^2 + \sigma_w^2\, \mathbb{E}_{z_i^{l-1} \sim \mathcal{GP}(0, K^{l-1})}\big[ \phi(z_i^{l-1}(x))\, \phi(z_i^{l-1}(x')) \big]. $$
If we assume a base kernel of the form $K^0(x, x') = \sigma_b^2 + \sigma_w^2 \left( \frac{x \cdot x'}{d_{\mathrm{in}}} \right)$, these recurrence relations can be solved in analytic form for the ReLU nonlinearity (as was demonstrated in Cho and Saul (2009)), and they can be solved numerically for other nonlinearities (and Lee et al. give a method for finding the numerical solution efficiently).
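As an illustration of the recurrence, here is a minimal sketch of the analytic ReLU case (our own code, not the authors'; the function name `nngp_relu_kernel` and the hyperparameter values are hypothetical choices). It propagates the base kernel through a given number of layers using the Cho and Saul (2009) expression for $\mathbb{E}[\phi(u)\phi(v)]$ under a bivariate Gaussian:

```python
import numpy as np

def nngp_relu_kernel(X1, X2, depth, sigma_w2=1.6, sigma_b2=0.1):
    """Sketch of the deep ReLU (arccosine) kernel recursion.

    X1: (n1, d) and X2: (n2, d) arrays of inputs; `depth` is the number of
    hidden layers. Returns the (n1, n2) kernel matrix. The default
    hyperparameter values are arbitrary choices.
    """
    d = X1.shape[1]
    # Base kernel K^0(x, x') = sigma_b^2 + sigma_w^2 * (x . x') / d_in
    K = sigma_b2 + sigma_w2 * (X1 @ X2.T) / d
    K1 = sigma_b2 + sigma_w2 * np.sum(X1**2, axis=1) / d   # K^0(x, x) for rows of X1
    K2 = sigma_b2 + sigma_w2 * np.sum(X2**2, axis=1) / d   # K^0(x', x') for rows of X2

    for _ in range(depth):
        norms = np.sqrt(np.outer(K1, K2))
        cos_theta = np.clip(K / norms, -1.0, 1.0)
        theta = np.arccos(cos_theta)
        # E[ReLU(u) ReLU(v)] for (u, v) ~ N(0, [[K1, K], [K, K2]]):
        #   (1 / 2pi) * sqrt(K1 * K2) * (sin(theta) + (pi - theta) * cos(theta))
        E = norms * (np.sin(theta) + (np.pi - theta) * cos_theta) / (2 * np.pi)
        K = sigma_b2 + sigma_w2 * E
        # Diagonal terms: theta = 0, so E[ReLU(u)^2] = K_diag / 2
        K1 = sigma_b2 + sigma_w2 * K1 / 2
        K2 = sigma_b2 + sigma_w2 * K2 / 2

    return K
```

As a sanity check, calling the function with `depth=0` simply returns the base kernel.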
Comparison: Gaussian Processes and Finite Width Neural Networks
Lee et al. went on to compare predictions made via Gaussian process regression, using the kernels obtained by solving the above recurrence relations (for the ReLU and tanh nonlinearities), with the predictions obtained from finite width neural networks trained with SGD. The task was classification (reformulated as a regression problem) of MNIST digits and CIFAR-10 images. Overall, they found that their “NNGP” often outperformed finite width neural networks with the same number of layers on this task, and they also found that the performance of the finite width networks often approached that of the NNGP as the width of these networks was increased.
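To make the comparison concrete, here is a minimal sketch of how such a kernel would be used for prediction (our own illustration of standard GP regression with the kernel sketched above, not the authors' pipeline); `noise` is a hypothetical jitter / observation-noise term, and class labels are encoded as regression targets:

```python
import numpy as np

def nngp_regression(X_train, Y_train, X_test, depth=3, noise=1e-3):
    """GP regression posterior mean with the deep ReLU kernel sketched above.

    Y_train is an (n_train, n_classes) array of class targets encoded for
    regression; predicted labels are the argmax of the posterior mean.
    """
    K_train = nngp_relu_kernel(X_train, X_train, depth)
    K_cross = nngp_relu_kernel(X_test, X_train, depth)
    n = X_train.shape[0]
    # Posterior mean: K_* (K + noise * I)^{-1} Y
    alpha = np.linalg.solve(K_train + noise * np.eye(n), Y_train)
    mean = K_cross @ alpha
    return mean.argmax(axis=1)
```

Note that this requires solving a linear system in the number of training points, which is what limits the approach to moderately sized datasets.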
