Roy Frostig
Authored Publications
Abstract
Automatic differentiation (AD) is conventionally understood as a family of distinct algorithms, rooted in two "modes" -- forward and reverse -- which are typically presented (and implemented) separately. Can there be only one? Following up on the AD systems developed in the JAX and Dex projects, we formalize a decomposition of reverse-mode AD into (i) forward-mode AD followed by (ii) unzipping the linear and non-linear parts and then (iii) transposition of the linear part.
To that end, we define a (substructurally) linear type system that can prove a class of functions are (algebraically) linear. Our main results are that forward-mode AD produces such linear functions, and that we can unzip and transpose any such linear function, conserving cost, size, and linearity. Composing these three transformations recovers reverse-mode AD. This decomposition also sheds light on checkpointing, which emerges naturally from a free choice in unzipping `let` expressions. As a corollary, checkpointing techniques are applicable to general-purpose partial evaluation, not just AD.
We hope that our formalization will lead to a deeper understanding of automatic differentiation and that it will simplify implementations, by separating the concerns of differentiation proper from the concerns of gaining efficiency (namely, separating the derivative computation from the act of running it backward).
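As a rough illustration of the decomposition (not the paper's formal development), the sketch below uses JAX's public API, where `jax.linearize` exposes steps (i)-(ii) and `jax.linear_transpose` exposes step (iii); the toy function `f` is invented for the example.

```python
import jax
import jax.numpy as jnp

def f(x):
    # A small nonlinear function R^3 -> scalar, made up for illustration.
    return jnp.sum(jnp.sin(x) * x ** 2)

x = jnp.array([1.0, 2.0, 3.0])

# (i) forward-mode AD, (ii) unzipped into the primal value and a function
# that is linear in its tangent argument.
y, f_lin = jax.linearize(f, x)

# (iii) transpose the linear part: map output cotangents to input cotangents.
f_t = jax.linear_transpose(f_lin, x)
(grad_via_transpose,) = f_t(jnp.ones_like(y))

# Composing the three steps agrees with reverse-mode AD.
assert jnp.allclose(grad_via_transpose, jax.grad(f)(x))
```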
Abstract
Excessive reuse of holdout data can lead to overfitting. However, there is little concrete evidence of significant overfitting due to holdout reuse in popular multiclass benchmarks today.
Known results show that, in the worst-case, revealing the accuracy of $k$ adaptively chosen classifiers on a data set of size $n$ allows one to create a classifier with bias of $\Theta(\sqrt{k/n})$ for any binary prediction problem. We show a new upper bound of $\tilde O(\max\{\sqrt{k\log(n)/(mn)},k/n\})$ on the worst-case bias that any attack can achieve in a prediction problem with $m$ classes. Moreover, we present an efficient attack that achieves a bias of $\Omega(\sqrt{k/(m^2 n)})$ and improves on previous work for the binary setting ($m=2$). We also present an inefficient attack that achieves a bias of $\tilde\Omega(k/n)$. Complementing our theoretical work, we give new practical attacks to stress-test multiclass benchmarks by aiming to create as large a bias as possible with a given number of queries. Our experiments show that the additional uncertainty of prediction with a large number of classes indeed mitigates the effect of our best attacks.
Our work extends developments in understanding overfitting due to adaptive data analysis to multiclass prediction problems. It also bears out the surprising fact that multiclass prediction problems are significantly more robust to overfitting under test (or holdout) set reuse. This offers an explanation for why popular multiclass prediction benchmarks, such as ImageNet, may enjoy a longer lifespan than intuition from the literature on binary classification suggests.
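For intuition about the adaptive-reuse effect, here is a minimal NumPy simulation of the classic boosting-style attack on a binary holdout, not the multiclass attacks developed in the paper and with invented sizes: query the holdout accuracy of k random classifiers, then majority-vote those that happened to score above chance. The aggregate's apparent accuracy exceeds 1/2 by roughly the $\sqrt{k/n}$ scale, even though every query was an independent coin flip.

```python
import numpy as np

rng = np.random.default_rng(0)

n = 2_000   # holdout size
k = 500     # number of adaptively chosen accuracy queries

# Labels are pure noise, so no classifier can beat 50% accuracy on fresh data.
labels = rng.integers(0, 2, size=n)

# Query the holdout accuracy of k random classifiers, keeping the predictions
# of those that happened to land above chance.
lucky = []
for _ in range(k):
    preds = rng.integers(0, 2, size=n)
    if np.mean(preds == labels) > 0.5:
        lucky.append(preds)

# Majority-vote the lucky classifiers: the aggregate overfits the holdout,
# with apparent accuracy above 1/2 on the order of sqrt(k/n).
vote = (np.mean(lucky, axis=0) > 0.5).astype(int)
print(f"apparent holdout accuracy:    {np.mean(vote == labels):.3f}")
print(f"sqrt(k/n), worst-case scale:  {np.sqrt(k / n):.3f}")
```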
Measuring the Effects of Data Parallelism on Neural Network Training
Chris Shallue
Jaehoon Lee
Jascha Sohl-Dickstein
Journal of Machine Learning Research (JMLR) (2018)
Abstract
Recent hardware developments have made unprecedented amounts of data parallelism available for accelerating neural network training. Among the simplest ways to harness next-generation accelerators is to increase the batch size in standard mini-batch neural network training algorithms. In this work, we aim to experimentally characterize the effects of increasing the batch size on training time, as measured by the number of steps necessary to reach a goal out-of-sample error. Eventually, increasing the batch size will no longer reduce the number of training steps required, but the exact relationship between batch size and the number of training steps required is of critical importance to practitioners, researchers, and hardware designers alike. We study how this relationship varies with the training algorithm, model, and data set and find extremely large variation between workloads. Along the way, we reconcile disagreements in the literature on whether batch size affects model quality. Finally, we discuss the implications of our results for efforts to train neural networks much faster in the future.
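As a sketch of the measurement protocol only (a toy linear-regression workload with made-up sizes, not the models or data sets studied in the paper), the following counts how many mini-batch SGD steps are needed to reach a fixed validation error at several batch sizes; a real study would also re-tune metaparameters such as the learning rate at each batch size.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy linear-regression workload standing in for a real model and data set.
d, n_train, n_val = 20, 50_000, 10_000
w_true = rng.normal(size=d)

def make_split(n):
    X = rng.normal(size=(n, d))
    return X, X @ w_true + 0.5 * rng.normal(size=n)

X_train, y_train = make_split(n_train)
X_val, y_val = make_split(n_val)

def steps_to_target(batch_size, lr=0.05, target_val_mse=0.3, max_steps=10_000):
    """Run plain mini-batch SGD; return the step count at which the
    validation error first drops below the goal (checked every 5 steps)."""
    w = np.zeros(d)
    for step in range(1, max_steps + 1):
        idx = rng.integers(0, n_train, size=batch_size)
        grad = X_train[idx].T @ (X_train[idx] @ w - y_train[idx]) / batch_size
        w -= lr * grad
        if step % 5 == 0 and np.mean((X_val @ w - y_val) ** 2) < target_val_mse:
            return step
    return max_steps

# A single fixed learning rate is reused here for brevity; a real study
# would tune metaparameters separately at every batch size.
for bs in [8, 32, 128, 512, 2048]:
    print(f"batch size {bs:5d}: {steps_to_target(bs):6d} steps to goal val error")
```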
Abstract
We describe JAX, a domain-specific tracing JIT compiler for generating high-performance accelerator code from pure Python and Numpy machine learning programs. JAX uses the XLA compiler infrastructure to generate optimized code for the program subroutines that are most favorable for acceleration, and these optimized subroutines can be called and orchestrated by arbitrary Python. Because the system is fully compatible with Autograd, it allows forward- and reverse-mode automatic differentiation of Python functions to arbitrary order. Because JAX supports structured control flow, it can generate code for sophisticated machine learning algorithms while maintaining high performance. We show that by combining JAX with Autograd and Numpy we get an easily programmable and highly performant ML system that targets CPUs, GPUs, and TPUs, capable of scaling to multi-core Cloud TPUs.
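A minimal example of the pieces mentioned above, using JAX's public API (the toy loss and data are invented for illustration): `jit` staging NumPy-style Python out to XLA, `grad` composing to higher order, and a training loop written with the structured control-flow primitive `lax.scan`.

```python
import jax
import jax.numpy as jnp
from jax import grad, jit, lax

def loss(w, x, y):
    # Pure Python + NumPy-style code, traced and compiled by JAX via XLA.
    return jnp.mean((jnp.tanh(x @ w) - y) ** 2)

# Derivatives compose to arbitrary order (here, the third derivative of sin),
# and jit stages the computation out to XLA for CPU/GPU/TPU.
third_deriv = jit(grad(grad(grad(jnp.sin))))
print(third_deriv(1.0))

# Structured control flow: a whole SGD loop expressed with lax.scan compiles
# into a single accelerator program.
@jit
def train(w, x, y):
    def step(w, _):
        return w - 0.1 * grad(loss)(w, x, y), None
    w, _ = lax.scan(step, w, None, length=200)
    return w

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (64, 3))
y = jnp.sin(x[:, 0])
print(loss(train(jnp.zeros(3), x, y), x, y))
```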
Abstract
We describe and analyze a simple random feature scheme (RFS) from prescribed compositional kernels. The compositional kernels we use are inspired by the structure of convolutional neural networks and kernels. The resulting scheme yields sparse and efficiently computable features. Each random feature can be represented as an algebraic expression over a small number of (random) paths in a composition tree. Thus, compositional random features can be stored compactly. The discrete nature of the generation process enables de-duplication of repeated features, further compacting the representation and increasing the diversity of the embeddings. Our approach complements and can be combined with previous random feature schemes.
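As a loose sketch of the storage idea only (an invented toy, not the scheme derived in the paper): each "compositional" feature below is just a few random coordinate indices and signs combined by a product, so it can be stored in a handful of integers and de-duplicated by comparing those tuples.

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_feature(dim, depth=2):
    """Sample a toy 'compositional' random feature: a few random input
    coordinates with random signs, combined by a product. Storing it only
    requires the index/sign tuples, i.e. a handful of small integers."""
    idx = rng.integers(0, dim, size=depth)
    sign = rng.choice([-1.0, 1.0], size=depth)
    return tuple(idx), tuple(sign)

def featurize(features, x):
    """Embed an input x using the sampled features."""
    return np.array([np.prod([s * x[i] for i, s in zip(idx, sign)])
                     for idx, sign in features]) / np.sqrt(len(features))

dim, num_features = 10, 500
features = [sample_feature(dim) for _ in range(num_features)]

# The discrete description makes de-duplication easy: identical draws
# describe the same feature and can be collapsed.
print(f"{len(set(features))} distinct features out of {num_features} sampled")

x = rng.normal(size=dim)
phi = featurize(features, x)
print(phi.shape, phi[:3])
```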