Kronecker-factored Curvature Approximations for Recurrent Neural Networks

James Martens
Jimmy Lei Ba
Matthew Johnson
ICLR (2018)

Abstract

Kronecker-factored Approximate Curvature (K-FAC; Martens & Grosse, 2015) is
a 2nd-order optimization method which has been shown to give state-of-the-art
performance on large-scale neural network optimization tasks (Ba et al., 2017). It
is based on an approximation to the Fisher information matrix (FIM) that makes
assumptions about the particular structure of the network and the way it is parameterized.
The original K-FAC method was applicable only to fully-connected
networks, although it was recently extended by Grosse & Martens (2016)
to handle convolutional networks as well. In this work we extend the method
to recurrent neural networks (RNNs) by introducing a novel approximation to the FIM for such networks.
This approximation works by modelling the statistical structure between the gradient
contributions at different time-steps using a chain-structured linear Gaussian
graphical model, summing the various cross-moments, and computing the inverse
in closed form. We demonstrate in experiments that our method significantly outperforms
general-purpose state-of-the-art optimizers such as SGD with momentum
and Adam on several challenging RNN training tasks.
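To make the underlying idea concrete, below is a minimal NumPy sketch of the basic (non-recurrent) K-FAC approximation for a single weight matrix: the Fisher block is approximated as a Kronecker product of second-moment matrices, so its inverse can be applied cheaply in factored form. This is only an illustration of the core K-FAC idea; it does not reproduce the paper's RNN-specific treatment of cross-time-step correlations, and all variable names are illustrative.

```python
import numpy as np

# Sketch of the core K-FAC idea for one weight matrix W (d_out x d_in),
# assuming hypothetical per-example layer inputs `a` and gradients `g`
# of the loss w.r.t. the layer's pre-activation outputs.
rng = np.random.default_rng(0)
batch, d_in, d_out = 32, 8, 4
a = rng.standard_normal((batch, d_in))
g = rng.standard_normal((batch, d_out))

# Kronecker factors: second moments of inputs and of output gradients.
A = a.T @ a / batch          # (d_in, d_in)
G = g.T @ g / batch          # (d_out, d_out)

# K-FAC approximates the Fisher block for W as A kron G, so its inverse is
# A^{-1} kron G^{-1}; it can be applied to a gradient dW without ever
# forming the full (d_in*d_out) x (d_in*d_out) matrix.
damping = 1e-3
A_inv = np.linalg.inv(A + damping * np.eye(d_in))
G_inv = np.linalg.inv(G + damping * np.eye(d_out))

dW = (g[:, :, None] * a[:, None, :]).mean(0)   # (d_out, d_in) mini-batch gradient
precond_dW = G_inv @ dW @ A_inv                # equals (A kron G)^{-1} vec(dW), reshaped
```

In the recurrent setting addressed by the paper, each time-step contributes its own gradient term for the shared weights; the proposed approximation models the statistics of these contributions across time-steps (via a chain-structured linear Gaussian model) rather than treating them independently as the sketch above does.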
