MetaInit: Initializing learning by learning to initialize
Abstract
Deep learning models frequently trade handcrafted features for deep features
learned with much less human intervention using gradient descent. While this
paradigm has been enormously successful, deep networks are often difficult to
train and performance can depend crucially on the initial choice of parameters. In
this work, we introduce an algorithm called MetaInit as a step towards automating
the search for good initializations using meta-learning. Our approach is based on
a hypothesis that good initializations make gradient descent easier by starting in
regions that look locally linear, with minimal second-order effects. We formalize
this notion via a quantity that we call the gradient quotient, which can be computed
with any architecture or dataset. MetaInit minimizes this quantity efficiently
by using gradient descent to tune the norms of the initial weight matrices. We
conduct experiments on plain and residual networks and show that the algorithm
can automatically recover from a class of bad initializations. MetaInit allows us
to train networks and achieve performance competitive with the state of the art
without batch normalization or residual connections. In particular, we find that
this approach outperforms normalization for networks without skip connections on
CIFAR-10 and can scale to ResNet-50 models on ImageNet.
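The abstract describes two ingredients: a gradient quotient that measures how far the loss deviates from looking locally linear around the initial parameters, and a meta-optimization that reduces this quotient by rescaling the initial weight tensors. The sketch below is one way such a procedure could look in JAX; the exact formula, the function names (`gradient_quotient`, `meta_init_scales`), and the use of plain gradient descent on the scales are illustrative assumptions, not the paper's precise algorithm.

```python
# Illustrative sketch only: a gradient-quotient-style objective and a
# scale-tuning loop in the spirit of the abstract. The concrete formula and
# update rule are assumptions, not the paper's exact method.
import jax
import jax.numpy as jnp

def gradient_quotient(loss_fn, params, eps=1e-5):
    """How much does the gradient change after one full gradient step?

    `loss_fn` maps a parameter pytree to a scalar loss on some fixed batch
    (hypothetical interface). A small value suggests the loss looks locally
    linear around `params`, i.e. second-order effects are weak.
    """
    g1 = jax.grad(loss_fn)(params)
    stepped = jax.tree_util.tree_map(lambda p, g: p - g, params, g1)
    g2 = jax.grad(loss_fn)(stepped)
    # Relative change of each gradient coordinate, averaged over all parameters.
    ratios = jax.tree_util.tree_map(
        lambda a, b: jnp.abs(b / (a + eps * jnp.where(a >= 0, 1.0, -1.0)) - 1.0),
        g1, g2)
    leaves = jax.tree_util.tree_leaves(ratios)
    total = sum(jnp.sum(r) for r in leaves)
    count = sum(r.size for r in leaves)
    return total / count

def meta_init_scales(loss_fn, init_params, steps=100, lr=0.1):
    """Tune one multiplicative scale per weight tensor by gradient descent on
    the gradient quotient, leaving the random directions of the init untouched."""
    scales = jax.tree_util.tree_map(lambda p: jnp.ones(()), init_params)

    def gq_of_scales(scales):
        scaled = jax.tree_util.tree_map(lambda s, p: s * p, scales, init_params)
        return gradient_quotient(loss_fn, scaled)

    gq_grad = jax.grad(gq_of_scales)
    for _ in range(steps):
        g = gq_grad(scales)
        scales = jax.tree_util.tree_map(lambda s, gs: s - lr * gs, scales, g)
    # Return the rescaled initialization, ready for ordinary training.
    return jax.tree_util.tree_map(lambda s, p: s * p, scales, init_params)
```

In this sketch only the per-tensor scales are optimized, so the random directions of the initial weights are preserved and the procedure only adjusts their norms, matching the abstract's description of tuning the norms of the initial weight matrices.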