A Mean Field Theory of Batch Normalization

Greg Yang
Jascha Sohl-Dickstein
Jeffrey Pennington
Sam Schoenholz
Vinay Rao
ICLR 2019 (2019)

Abstract

We develop a mean field theory for batch normalization in fully-connected feedforward neural networks, providing a precise characterization of the dynamics of signal propagation and gradient backpropagation in wide batch-normalized networks at initialization. We find that gradient signals grow exponentially in depth and that these exploding gradients cannot be eliminated by tuning the initialization hyperparameters or by adjusting the nonlinear activation function. Indeed, batch normalization itself is the cause of exploding gradients. As a result, vanilla batch-normalized networks without skip connections are not trainable at large depths, a prediction that we verify with a variety of empirical simulations. While gradient explosion cannot be eliminated, we show that it can be ameliorated by pushing the network closer to the linear regime, allowing deep networks without residual connections to be trained.
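The following is a minimal sketch (not the authors' code) illustrating the central claim of the abstract: in a wide, batch-normalized, fully-connected network at initialization, gradients grow rapidly with depth, and pushing the activation toward the linear regime reduces (but does not eliminate) this growth. The width, batch size, depths, tanh nonlinearity, variance-preserving initialization, the `alpha` knob controlling how linear the activation is, and the quadratic readout loss are all illustrative assumptions, not details taken from the paper.

```python
import jax
import jax.numpy as jnp


def batch_norm(x, eps=1e-5):
    # Normalize each feature over the batch dimension (no learned scale/shift).
    mean = x.mean(axis=0, keepdims=True)
    var = x.var(axis=0, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def forward(x0, weights, alpha=1.0):
    # phi(x) = tanh(alpha * x) / alpha approaches the identity as alpha -> 0,
    # i.e. smaller alpha pushes the network closer to the linear regime.
    h = x0
    for W in weights:
        h = jnp.tanh(alpha * batch_norm(h @ W)) / alpha
    return h


def grad_norm(depth, width=256, batch=64, alpha=1.0, seed=0):
    key = jax.random.PRNGKey(seed)
    keys = jax.random.split(key, depth + 1)
    x0 = jax.random.normal(keys[0], (batch, width))
    # Variance-preserving initialization: W_ij ~ N(0, 1/width).
    weights = [jax.random.normal(k, (width, width)) / jnp.sqrt(width)
               for k in keys[1:]]
    # Norm of the gradient of a simple readout loss with respect to the input;
    # it measures how strongly signals are amplified while backpropagating
    # through all `depth` batch-normalized layers.
    loss = lambda x: 0.5 * jnp.sum(forward(x, weights, alpha) ** 2)
    return float(jnp.linalg.norm(jax.grad(loss)(x0)))


for depth in (10, 20, 40, 80):
    print(depth,
          grad_norm(depth, alpha=1.0),   # batch norm + tanh: gradients explode with depth
          grad_norm(depth, alpha=0.2))   # nearly linear activation: much milder growth
```

Running the loop prints gradient norms that grow by orders of magnitude as depth increases for `alpha=1.0`, while the near-linear `alpha=0.2` network grows far more slowly, mirroring the qualitative picture in the abstract.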
