Behnam Neyshabur
I am a senior staff research scientist at Google. Before that, I was a postdoctoral researcher at New York University and a member of the Theoretical Machine Learning program at the Institute for Advanced Study (IAS) in Princeton. In summer 2017, I received a PhD in computer science from TTI-Chicago, where I was fortunate to be advised by Nati Srebro.
My current primary interest is the reasoning and algorithmic capabilities of large language models, but I have not lost my interest in the science of deep learning and (out-of-distribution) generalization.
Authored Publications
A Loss Curvature Perspective On Training Instability in Deep Learning
Justin Gilmer
Behrooz Ghorbani
Ankush Garg
David Cardoze
ICLR (2022)
In this work, we study the evolution of the loss Hessian across many classification tasks in order to understand the effect the curvature of the loss has on the training dynamics. Whereas prior work has focused on how different learning rates affect the loss Hessian observed during training, we also analyze the effects of model initialization, architectural choices, and common training heuristics such as gradient clipping and learning rate warmup. Our results demonstrate that successful model and hyperparameter choices allow the early optimization trajectory either to avoid regions of high curvature or to navigate out of them, into flatter regions that tolerate a higher learning rate. Our results suggest a unifying perspective on how disparate mitigation strategies for training instability ultimately address the same underlying failure mode of neural network optimization, namely poor conditioning. Inspired by the conditioning perspective, we show that learning rate warmup can improve training stability just as much as batch normalization, layer normalization, MetaInit, GradInit, and Fixup initialization.
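The link between curvature and a tolerable learning rate can be illustrated on a toy problem. The sketch below assumes a quadratic loss L(w) = ½ wᵀHw as a stand-in for a network's local loss surface (the paper measures real loss Hessians instead). On such a loss, gradient descent is stable only when the learning rate is below 2/λ_max, where λ_max is the top Hessian eigenvalue. Power iteration can estimate λ_max using only Hessian-vector products. The function names here are illustrative, not taken from the paper.

```python
import numpy as np

def top_eigenvalue(hvp, dim, iters=100, seed=0):
    """Estimate the largest Hessian eigenvalue by power iteration.

    Only Hessian-vector products are needed (hvp: v -> H v), which is how
    curvature is usually probed when the full Hessian is too large to form.
    Assumes H is symmetric positive semi-definite, so the dominant
    eigenvalue is the largest one.
    """
    rng = np.random.default_rng(seed)
    v = rng.normal(size=dim)
    v /= np.linalg.norm(v)
    lam = 0.0
    for _ in range(iters):
        hv = hvp(v)
        lam = float(v @ hv)              # Rayleigh quotient for the unit vector v
        v = hv / np.linalg.norm(hv)
    return lam

def gd_on_quadratic(H, w0, lr, steps):
    """Plain gradient descent on L(w) = 0.5 * w^T H w, whose gradient is H w."""
    w = w0.copy()
    for _ in range(steps):
        w = w - lr * (H @ w)
    return w

# Toy, poorly conditioned Hessian (condition number 100).
H = np.diag([10.0, 1.0, 0.1])
lam_max = top_eigenvalue(lambda v: H @ v, dim=3)         # close to 10
w0 = np.ones(3)
w_stable = gd_on_quadratic(H, w0, lr=0.19, steps=200)    # lr < 2 / lam_max: shrinks
w_unstable = gd_on_quadratic(H, w0, lr=0.21, steps=200)  # lr > 2 / lam_max: blows up
```

Poor conditioning shows up directly here. The sharpest direction caps the learning rate at 2/λ_max. Progress along the flattest direction (eigenvalue 0.1) is then slow, because each step shrinks that coordinate only by a factor of 1 − 0.1·lr.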
Understanding the loss landscape of deep neural networks has been the subject of many studies due to its close connections to optimization and generalization. Prior work has shown that there is often a performance barrier along the linear interpolation of the weights of two models trained with different initial seeds. In this work, we first empirically investigate how different model parameters and data distributions impact such performance barriers. Next, we consider the invariances in the function space of neural networks that arise from permutation of hidden units. We investigate this through extensive experiments and provide several pieces of evidence that if these invariances are taken into account, many of the barriers vanish.
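The two ingredients of this abstract are the barrier along a linear interpolation and the permutation invariance of hidden units. The toy NumPy sketch below illustrates both under several assumptions:

- the network is a one-hidden-layer ReLU MLP;
- the loss is squared error against the first network's outputs;
- the barrier is defined as the largest gap between the loss along the linear path and the straight line joining the two endpoint losses.

A permuted copy of one network stands in for a second, independently trained model. This is a deliberate simplification of the paper's setting.

```python
import numpy as np

def mlp(params, X):
    """One-hidden-layer ReLU network applied to a batch X of shape (n, d_in)."""
    W1, b1, W2, b2 = params
    hidden = np.maximum(X @ W1.T + b1, 0.0)
    return hidden @ W2.T + b2

def permute_hidden(params, perm):
    """Reorder hidden units: rows of W1 and b1, columns of W2. The function is unchanged."""
    W1, b1, W2, b2 = params
    return (W1[perm], b1[perm], W2[:, perm], b2)

def interpolate(pa, pb, alpha):
    """Linear interpolation (1 - alpha) * pa + alpha * pb, parameter by parameter."""
    return tuple((1 - alpha) * a + alpha * b for a, b in zip(pa, pb))

def barrier(loss_fn, pa, pb, num=21):
    """Largest excess of the loss along the path over the line between endpoint losses."""
    la, lb = loss_fn(pa), loss_fn(pb)
    gaps = [loss_fn(interpolate(pa, pb, t)) - ((1 - t) * la + t * lb)
            for t in np.linspace(0.0, 1.0, num)]
    return max(gaps)

rng = np.random.default_rng(0)
d_in, d_hidden, d_out = 4, 8, 2
params_a = (rng.normal(size=(d_hidden, d_in)), rng.normal(size=d_hidden),
            rng.normal(size=(d_out, d_hidden)), rng.normal(size=d_out))
perm = np.roll(np.arange(d_hidden), 1)         # a fixed non-identity permutation
params_b = permute_hidden(params_a, perm)      # same function, different weights

X = rng.normal(size=(32, d_in))
targets = mlp(params_a, X)
loss_fn = lambda p: float(np.mean((mlp(p, X) - targets) ** 2))

same_output = np.allclose(mlp(params_a, X), mlp(params_b, X))
barrier_unaligned = barrier(loss_fn, params_a, params_b)   # well above 0
aligned_b = permute_hidden(params_b, np.argsort(perm))      # undo the permutation
barrier_aligned = barrier(loss_fn, params_a, aligned_b)     # ~0
```

The two endpoints compute exactly the same function, yet the straight path between them has a barrier. The barrier disappears once the permutation is undone. For real networks trained from different seeds, the matching permutation is unknown and has to be searched for; this toy example sidesteps that step.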
Recent developments in large-scale machine learning have created a tempting picture suggesting that by properly scaling up data, model size, and training time, one can obtain a model that can be used successfully in few-shot settings on all downstream tasks. In this work, we investigate this premise empirically and provide a strong case against it. In particular, we consider image recognition with large-scale models (Vision Transformers) trained on the largest-scale data available (JFT). We show that as we improve upstream performance, either by scaling up or through hyperparameter and architectural choices, performance on many downstream tasks eventually plateaus. We also showcase an even more extreme scenario in which upstream and downstream performance are at odds, i.e., achieving better downstream performance requires hurting upstream accuracy. We then delve deeper into the reasons behind these phenomena by designing interventions and investigating different components of the models, which gives us crude yet useful insights into the mechanisms behind these observations.
Distribution shift is a prevalent problem in the real-world deployment of machine learning models. Typically, a mismatch between the source (training) and target (test) distributions leads to a gap between the model's source and target performance. In this work, we investigate methods that leverage only unlabeled target data to predict accuracy under distribution shift. We propose a simple and effective method called Average Thresholded Confidence (ATC). ATC learns a scalar threshold on model confidence using source data, and it predicts target accuracy as the fraction of unlabeled target examples whose confidence exceeds that threshold. ATC outperforms previous approaches across several model architectures and various types of distribution shift (e.g., synthetic corruptions, shifts due to dataset reproduction, or shifts due to novel subpopulations) applied to the FMoW-WILDS, ImageNet, CIFAR, and MNIST datasets. ATC estimates target performance up to 2 to 3 times more accurately than recently proposed methods. Finally, we theoretically analyze our proposed method on a toy distribution shift model with varying degrees of spurious correlation.
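A minimal pure-Python sketch of the ATC idea follows. It makes two assumptions. First, confidence is the maximum softmax probability. Second, the threshold is chosen so that the fraction of labeled source (validation) examples above it equals the source accuracy. The abstract says only that ATC "learns a scalar threshold", so the second point is an interpretation. The helper names and all numbers are hypothetical.

```python
import math

def max_softmax(logits):
    """Confidence score: the largest softmax probability (one common choice)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    return max(exps) / sum(exps)

def fit_atc_threshold(source_conf, source_correct):
    """Choose t so that the fraction of source examples with confidence > t
    equals source accuracy (assumes distinct confidence values)."""
    n = len(source_conf)
    k = sum(source_correct)               # number of correct source predictions
    ranked = sorted(source_conf)
    if k == n:
        return float("-inf")              # every example must lie above t
    return ranked[n - k - 1]              # exactly k values are strictly above this one

def predict_target_accuracy(threshold, target_conf):
    """ATC estimate: fraction of unlabeled target examples above the threshold."""
    return sum(c > threshold for c in target_conf) / len(target_conf)

# Made-up confidences, for illustration only.
source_conf = [0.95, 0.9, 0.8, 0.7, 0.6, 0.55, 0.5, 0.4]
source_correct = [1, 1, 1, 1, 0, 1, 0, 0]            # source accuracy 5/8
t = fit_atc_threshold(source_conf, source_correct)   # 0.55
target_conf = [0.9, 0.7, 0.5, 0.45, 0.3]
estimate = predict_target_accuracy(t, target_conf)   # 2 of 5 above t -> 0.4
```

No target labels are used at any point. The threshold is calibrated once on the source data, and the target estimate needs only the model's confidences on the unlabeled target examples.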
Exploring Length Generalization in Large Language Models
Cem Anil
Yuhuai Wu
Aitor Lewkowycz
Guy Gur-Ari
NeurIPS Oral (2022)
The ability to extrapolate from short problem instances to longer ones is an important form of out-of-distribution generalization in reasoning tasks, and is crucial when learning from datasets where longer problem instances are rare. These include theorem proving, solving quantitative mathematics problems, and reading/summarizing novels. In this paper, we run careful empirical studies exploring the length generalization capabilities of transformer-based language models. We first establish that naively finetuning transformers on length generalization tasks shows significant generalization deficiencies independent of model scale. We then show that combining pretrained large language models' in-context learning abilities with scratchpad prompting (asking the model to output solution steps before producing an answer) results in a dramatic improvement in length generalization. We run careful failure analyses on each of the learning modalities and identify common sources of mistakes that highlight opportunities in equipping language models with the ability to generalize to longer problems.
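Scratchpad prompting means asking the model to write out solution steps before the answer. The sketch below shows what a scratchpad-style few-shot prompt could look like for multi-digit addition, a typical length-generalization task. The step format and the helper names are hypothetical illustrations, not the prompts used in the paper.

```python
def addition_scratchpad(a: int, b: int) -> str:
    """Spell out column addition (least significant digit first), then the answer.
    Assumes non-negative integers."""
    xs, ys = str(a)[::-1], str(b)[::-1]
    lines = [f"Input: {a} + {b}"]
    digits, carry = [], 0
    for i in range(max(len(xs), len(ys))):
        x = int(xs[i]) if i < len(xs) else 0
        y = int(ys[i]) if i < len(ys) else 0
        total = x + y + carry
        lines.append(f"{x} + {y} + carry {carry} = {total}, write {total % 10}, carry {total // 10}")
        digits.append(total % 10)
        carry = total // 10
    if carry:
        digits.append(carry)
    lines.append("Answer: " + "".join(str(d) for d in reversed(digits)))
    return "\n".join(lines)

def few_shot_prompt(examples, query):
    """Concatenate worked scratchpad examples, then pose the query for the model to continue."""
    shots = "\n\n".join(addition_scratchpad(a, b) for a, b in examples)
    return f"{shots}\n\nInput: {query[0]} + {query[1]}\n"

print(addition_scratchpad(57, 68))
```

Running it prints one line per digit column, followed by `Answer: 125`. The number of intermediate steps grows with the number of digits. A longer problem therefore calls for more steps of the same kind rather than a single harder leap to the answer, which is the intuition for why scratchpads can help with length generalization.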
Solving Quantitative Reasoning Problems with Language Models
Aitor Lewkowycz
David Martin Dohan
Henryk Michalewski
Cem Anil
Imanol Schlag
Theo Gutman-Solo
Yuhuai Wu
Guy Gur-Ari
NeurIPS (2022)
Language models have achieved remarkable performance on a wide range of tasks that require natural language understanding. Nevertheless, state-of-the-art models have generally struggled with tasks that require quantitative reasoning, such as solving mathematics, science, and engineering problems at the college level. To help close this gap, we introduce Minerva, a large language model pretrained on general natural language data and further trained on technical content. The model achieves state-of-the-art performance on technical benchmarks without the use of external tools. We also evaluate our model on over two hundred undergraduate-level problems in physics, biology, chemistry, economics, and other sciences that require quantitative reasoning, and find that the model can correctly answer nearly a third of them.
We introduce the Block-Recurrent Transformer, which applies a transformer layer in a recurrent fashion along a sequence, and has linear complexity with respect to sequence length. Our recurrent cell operates on blocks of tokens rather than single tokens, and leverages parallel computation within a block in order to make efficient use of accelerator hardware. The cell itself is strikingly simple. It is merely a transformer layer: it uses self-attention and cross-attention to efficiently compute a recurrent function over a large set of state vectors and tokens. Our design was inspired in part by LSTM cells, and it uses LSTM-style gates, but it scales the typical LSTM cell up by several orders of magnitude.
Our implementation of recurrence has the same cost in both computation time and parameter count as a conventional transformer layer, but offers dramatically improved perplexity in language modeling tasks over very long sequences. Our model outperforms a long-range Transformer-XL baseline by a wide margin while running twice as fast. We demonstrate its effectiveness on PG19 (books), arXiv papers, and GitHub source code.
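The block-level recurrence can be sketched in NumPy under heavy simplifications:

- single-head attention with no learned projections;
- a randomly initialized state in place of a learned one;
- a fixed scalar gate instead of the paper's learned LSTM-style gates.

It is meant only to show how a transformer-style cell can process a sequence block by block while carrying a fixed-size set of state vectors. All names are hypothetical.

```python
import numpy as np

def attention(q, k, v, causal=False):
    """Single-head scaled dot-product attention without learned projections."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    if causal:
        # Position i attends only to positions <= i (q and k are the same block).
        mask = np.tril(np.ones(scores.shape, dtype=bool))
        scores = np.where(mask, scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

def block_recurrent(tokens, block_size, num_state, gate=0.5, seed=0):
    """Run a simplified recurrent cell over a sequence, one block at a time.

    Tokens causally self-attend within their block and cross-attend to the
    state carried over from earlier blocks. The state then cross-attends to
    the block and is mixed with its old value by a fixed scalar gate
    (a stand-in for learned LSTM-style gates)."""
    seq_len, d = tokens.shape
    state = np.random.default_rng(seed).normal(size=(num_state, d))
    outputs = []
    for start in range(0, seq_len, block_size):
        block = tokens[start:start + block_size]
        out = block + attention(block, block, block, causal=True) + attention(block, state, state)
        outputs.append(out)
        update = attention(state, block, block)
        state = gate * state + (1.0 - gate) * update
    return np.concatenate(outputs, axis=0), state

rng = np.random.default_rng(1)
tokens = rng.normal(size=(10, 4))
out, final_state = block_recurrent(tokens, block_size=4, num_state=3)

# Changing token 5 must not affect outputs at positions 0..4.
tokens2 = tokens.copy()
tokens2[5] += 1.0
out2, _ = block_recurrent(tokens2, block_size=4, num_state=3)
```

Each block attends only to itself and to `num_state` state vectors. The work per block is therefore fixed, and the total cost grows linearly with sequence length, in line with the linear complexity claimed in the abstract.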
The remarkable progress in deep learning in recent years has been largely driven by improvements in scale, where bigger models are trained on larger datasets for longer schedules. To predict the benefit of scale empirically, we argue for a more rigorous methodology based on the extrapolation loss, instead of reporting the best-fitting (interpolating) parameters. We then present a recipe for estimating scaling law parameters reliably from learning curves. We demonstrate that it extrapolates more accurately than previous methods across a wide range of architecture families in several domains, including image classification, neural machine translation (NMT), and language modeling, in addition to tasks from the BIG-Bench evaluation benchmark. Finally, we release a benchmark dataset comprising 90 evaluation tasks to facilitate research in this domain.
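The difference between judging a fit by interpolation and by extrapolation can be seen on synthetic curves. The pure-Python sketch below assumes a simple power law L(x) = β·x^(−c), fit by least squares in log-log space on the early part of a curve. It then scores the fit on the later points, which the fit never saw. This simple estimator is a stand-in for illustration, not the recipe proposed in the paper, and the curves are synthetic rather than measured results.

```python
import math

def fit_power_law(xs, losses):
    """Least-squares fit of log L = log(beta) - c * log(x); returns (beta, c)."""
    lx = [math.log(x) for x in xs]
    ly = [math.log(l) for l in losses]
    mx, my = sum(lx) / len(lx), sum(ly) / len(ly)
    slope = (sum((a - mx) * (b - my) for a, b in zip(lx, ly))
             / sum((a - mx) ** 2 for a in lx))
    return math.exp(my - slope * mx), -slope

def extrapolation_loss(xs, losses, split):
    """Fit on the first `split` points, then report the RMSE in log space
    on the remaining (larger-x) points the fit never saw."""
    beta, c = fit_power_law(xs[:split], losses[:split])
    errs = [(math.log(beta) - c * math.log(x) - math.log(l)) ** 2
            for x, l in zip(xs[split:], losses[split:])]
    return math.sqrt(sum(errs) / len(errs))

# Synthetic learning curves (not measured results); x could be training examples seen.
xs = [2 ** k for k in range(4, 14)]
pure = [5.0 * x ** -0.3 for x in xs]               # exact power law
saturating = [5.0 * x ** -0.3 + 0.5 for x in xs]   # power law plus an irreducible loss

beta, c = fit_power_law(xs, pure)                             # recovers (5.0, 0.3)
err_pure = extrapolation_loss(xs, pure, split=6)              # ~0
err_saturating = extrapolation_loss(xs, saturating, split=6)  # clearly > 0
```

The pure power law underestimates the later losses of the saturating curve, and the error grows the further it extrapolates. A good in-sample fit on the early points would not reveal this; scoring on held-out larger scales does.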
State space models have been shown to be effective for modeling long-range dependencies, specifically on sequence classification tasks. In this paper, we focus on autoregressive sequence modeling over natural language, GitHub code, and arXiv mathematics articles. Building on recent developments around the effectiveness of gated activation functions, we propose a new layer, named the Gated State Space (GSS) layer. We show that GSS trains significantly faster than the diagonal version of S4 (i.e., DSS) on TPUs, is simple to implement, and is fairly competitive with several well-tuned Transformer-based baselines. Finally, we show that interleaving traditional Transformer blocks with GSS layers improves performance even further.
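The sketch below shows a gated layer in the spirit of GSS: a diagonal state space branch multiplied elementwise by a gating branch. It rests on several assumptions, none of them the paper's exact parameterization:

- the projections are plain matrices;
- the activation is a tanh-approximated GELU;
- the recurrence is a real-valued diagonal linear recurrence.

Diagonal SSM layers are typically computed as a long causal convolution. The sketch includes that view and checks it against the equivalent step-by-step recurrence. All names are hypothetical.

```python
import numpy as np

def diagonal_ssm(u, a, b, c):
    """Real-valued diagonal state space recurrence, applied per channel.

    u: (L, D) inputs; a, b, c: (D, N) per-channel parameters with |a| < 1.
    h_t = a * h_{t-1} + b * u_t,   y_t = sum over N of c * h_t
    """
    L, D = u.shape
    h = np.zeros_like(a)
    ys = np.empty((L, D))
    for t in range(L):
        h = a * h + b * u[t][:, None]
        ys[t] = (c * h).sum(axis=1)
    return ys

def ssm_as_convolution(u, a, b, c):
    """The same map computed as a causal convolution with kernel K_k = sum_n c a^k b."""
    L, D = u.shape
    powers = a[None] ** np.arange(L)[:, None, None]   # (L, D, N)
    K = (c * b * powers).sum(axis=2)                   # (L, D)
    ys = np.zeros((L, D))
    for t in range(L):
        ys[t] = (K[: t + 1][::-1] * u[: t + 1]).sum(axis=0)
    return ys

def gelu(x):
    """tanh approximation of GELU."""
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x ** 3)))

def gated_ssm_layer(x, params):
    """SSM branch multiplied elementwise by a gating branch, then projected out."""
    W_u, W_v, W_o, a, b, c = params
    u = x @ W_u             # (L, D) input to the SSM branch
    v = gelu(x @ W_v)       # (L, D) gate
    return (diagonal_ssm(u, a, b, c) * v) @ W_o

rng = np.random.default_rng(0)
L, d_model, D, N = 12, 6, 6, 4
params = (rng.normal(size=(d_model, D)), rng.normal(size=(d_model, D)),
          rng.normal(size=(D, d_model)), rng.uniform(0.1, 0.9, size=(D, N)),
          rng.normal(size=(D, N)), rng.normal(size=(D, N)))
x = rng.normal(size=(L, d_model))
y = gated_ssm_layer(x, params)
```

The gate acts position-wise, so the layer stays causal: an output at position t depends only on inputs up to t. That property is what autoregressive language modeling requires.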
Avoiding Spurious Correlations: Bridging Theory and Practice
Thao Nguyen
Vaishnavh Nagarajan
NeurIPS 2021 Workshop on Distribution Shifts: Connecting Methods and Applications
Distribution shifts in the wild jeopardize the performance of machine learning models, as they tend to pick up spurious correlations during training. Recent work (Nagarajan et al., 2020) has characterized two specific failure modes of out-of-distribution (OOD) generalization; we extend this theoretical framework by interpreting existing algorithms as solutions to these failure modes. We then evaluate them on different image classification datasets, and in the process surface two issues that are central to existing robustness techniques. For techniques that rely on group annotations, we show that the group information in standard benchmark datasets does not fully capture the spurious correlations present. For those that do not require group annotations, the validation set used for model selection still carries assumptions that are unrealistic in real-world settings, and we show how the choice of shifts in the validation set can affect the performance of different OOD algorithms.