Sashank Reddi

Authored Publications
    This paper reveals a curious observation that modern large-scale machine learning models with Transformer architectures have sparse activation maps. By activation map we refer to the intermediate output of the multi-layer perceptrons (MLPs) after a ReLU activation function, and by "sparse" we mean that on average very few entries (e.g., 3.0% for T5-Base and 6.3% for ViT-B16) are nonzero for each input to the MLP. Through extensive experiments we demonstrate that the emergence of sparsity is a prevalent phenomenon that occurs for both natural language processing and vision tasks, on both training and evaluation data, for Transformers of various configurations, at layers of all depth levels, etc. Moreover, larger Transformers with more layers and higher MLP hidden dimensions are sparser as measured by the percentage of nonzero entries. To probe why sparsity emerges, we design experiments with random labels, random images, and infinite data, and find that sparsity may be due primarily to optimization and has little to do with the properties of the training dataset. We discuss how sparsity immediately implies a means for significantly reducing the FLOP count and improving efficiency for Transformers. Moreover, we demonstrate, perhaps surprisingly, that explicitly enforcing an even sparser activation via Top-K thresholding with a small value of k brings a collection of desired but missing properties for Transformers, namely less sensitivity to noisy training data, more robustness to input corruptions, and better calibration for their prediction confidence.
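The Top-K thresholding mentioned in the abstract above can be illustrated with a minimal sketch. It assumes the simplest reading: after the ReLU, keep the k largest entries of the MLP hidden vector and zero the rest. The function names and the toy input are hypothetical and are not taken from the paper.

```python
def relu(values):
    """Element-wise ReLU on a list of floats."""
    return [max(0.0, v) for v in values]


def top_k_threshold(activations, k):
    """Keep the k largest entries and set every other entry to zero."""
    if k <= 0:
        return [0.0] * len(activations)
    if k >= len(activations):
        return list(activations)
    # Indices of the k largest activations (ties resolved by position).
    order = sorted(range(len(activations)),
                   key=lambda i: activations[i], reverse=True)
    keep = set(order[:k])
    return [a if i in keep else 0.0 for i, a in enumerate(activations)]


def nonzero_fraction(activations):
    """Fraction of entries that are nonzero (the sparsity measure above)."""
    return sum(1 for a in activations if a != 0.0) / len(activations)


# Toy pre-activation vector (illustrative values only).
pre_activations = [-1.2, 0.4, 2.5, -0.3, 0.0, 1.1, -2.0, 0.7]
hidden = relu(pre_activations)
sparse_hidden = top_k_threshold(hidden, k=2)

print(nonzero_fraction(hidden), nonzero_fraction(sparse_hidden))
```

In a real model the same operation would be applied per token to the MLP hidden layer; this sketch only shows the thresholding step itself.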
    Large deep learning models have achieved state-of-the-art performance across various natural language processing (NLP) tasks and demonstrated remarkable few-shot learning performance. However, training them is often challenging and resource-intensive. In this paper, we study an efficient approach to train language models using few-shot learners. We show that, by leveraging the fast learning nature of few-shot learners, one can train language models efficiently in a stagewise manner. Our main insight is that stacking a good few-shot learner on a good small language model provides a good initializer for a larger language model. Using this insight and building upon progressive stacking approaches, we develop novel approaches for training such networks in a stagewise manner. Furthermore, we also provide a theoretical framework and accompanying empirical studies to support our insights, thereby creating a theoretical foundation for progressive stacking. Finally, we provide empirical results to demonstrate the effectiveness of our approach in reducing the training time of few-shot learners.
    Transformer-based models such as BERT have proven successful in information retrieval problems, which seek to identify relevant documents for a given query. There are two broad flavours of such models: cross-attention (CA) models, which learn a joint embedding for the query and document, and dual-encoder (DE) models, which learn separate embeddings for the query and document. Empirically, CA models are often found to be more accurate, which has motivated a series of works seeking to bridge this gap. However, a more fundamental question remains less explored: does this performance gap reflect an inherent limitation in the capacity of DE models, or a limitation in the training of such models? And does such an understanding suggest a principled means of improving DE models? In this paper, we study these questions, with three contributions. First, we establish theoretically that with a sufficiently large embedding dimension, DE models have the capacity to model a broad class of score distributions. Second, we show empirically that on real-world problems, DE models may overfit to spurious correlations in the training set, and thus under-perform on test samples. To mitigate this behaviour, we propose a novel distillation strategy that leverages confidence margins, and confirm its practical efficacy on the MSMARCO-Passage benchmark.
    A Field Guide to Federated Optimization
    Jianyu Wang
    Gauri Joshi
    Maruan Al-Shedivat
    Galen Andrew
    A. Salman Avestimehr
    Katharine Daly
    Deepesh Data
    Suhas Diggavi
    Hubert Eichner
    Advait Gadhikar
    Antonious M. Girgis
    Filip Hanzely
    Chaoyang He
    Samuel Horvath
    Zhouyuan Huo
    Martin Jaggi
    Tara Javidi
    Peter Kairouz
    Satyen Chandrakant Kale
    Sai Praneeth Karimireddy
    Jakub Konečný
    Sanmi Koyejo
    Tian Li
    Peter Richtarik
    Karan Singhal
    Virginia Smith
    Mahdi Soltanolkotabi
    Weikang Song
    Ananda Theertha Suresh
    Sebastian Stich
    Ameet Talwalkar
    Hongyi Wang
    Blake Woodworth
    Shanshan Wu
    Honglin Yuan
    Mi Zhang
    Tong Zhang
    Chunxiang (Jake) Zheng
    Chen Zhu
    Federated learning and analytics are a distributed approach for collaboratively learning models (or statistics) from decentralized data, motivated by and designed for privacy protection. The distributed learning process can be formulated as solving federated optimization problems, which emphasize communication efficiency, data heterogeneity, compatibility with privacy and system requirements, and other constraints that are not primary considerations in other problem settings. This paper provides recommendations and guidelines on formulating, designing, evaluating and analyzing federated optimization algorithms through concrete examples and practical implementation, with a focus on conducting effective simulations to infer real-world performance. The goal of this work is not to survey the current literature, but to inspire researchers and practitioners to design federated learning algorithms that can be used in various practical applications.
    Knowledge distillation is an approach to improve the performance of a student model by using the knowledge of a complex teacher. Despite its success in several deep learning applications, the study of distillation is mostly confined to classification settings. In particular, the use of distillation in top-k ranking settings, where the goal is to rank the k most relevant items correctly, remains largely unexplored. In this paper, we study such ranking problems through the lens of distillation. We present a framework for distillation for top-k ranking and establish connections with the existing ranking methods. The core idea of this framework is to preserve the ranking at the top by matching the k largest scores of student and teacher while penalizing large scores for items ranked low by the teacher. Building on our framework, we develop a novel distillation approach, RankDistil, specifically catered towards ranking problems with a large number of items to rank. Finally, we conduct experiments which demonstrate that RankDistil yields benefits over commonly used baselines for ranking problems.
    A statistical perspective on distillation
    Aditya Krishna Menon
    International Conference on Machine Learning (ICML) 2021 (to appear)
    Knowledge distillation is a technique for improving a "student" model by replacing its one-hot training labels with a label distribution obtained from a "teacher" model. Despite its broad success, several basic questions (e.g., why does distillation help? why do more accurate teachers not necessarily distill better?) have received limited formal study. In this paper, we present a statistical perspective on distillation which provides an answer to these questions. Our core observation is that a "Bayes teacher" providing the true class-probabilities can lower the variance of the student objective, and thus improve performance. We then establish a bias-variance tradeoff that quantifies the value of teachers that approximate the Bayes class-probabilities. This provides a formal criterion as to what constitutes a "good" teacher, namely, the quality of its probability estimates. Finally, we illustrate how our statistical perspective facilitates novel applications of distillation to bipartite ranking and multiclass retrieval.
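The label replacement described in the abstract above can be sketched as follows. This is a minimal illustration, assuming the student is trained with cross-entropy and that the only change under distillation is the target distribution. The probability vectors are made-up numbers, not results from the paper.

```python
import math


def cross_entropy(target_probs, predicted_probs):
    """Cross-entropy between a target distribution and a prediction."""
    return -sum(t * math.log(p)
                for t, p in zip(target_probs, predicted_probs)
                if t > 0.0)


# Student prediction over three classes (illustrative values).
student_probs = [0.7, 0.2, 0.1]

# Standard training target: a one-hot label for class 0.
one_hot_label = [1.0, 0.0, 0.0]

# Distillation target: a hypothetical teacher's class-probability estimate.
teacher_probs = [0.6, 0.3, 0.1]

one_hot_loss = cross_entropy(one_hot_label, student_probs)
distilled_loss = cross_entropy(teacher_probs, student_probs)

print(one_hot_loss, distilled_loss)
```

The distilled objective is minimised when the student matches the teacher's distribution, which is why the quality of the teacher's probability estimates matters in this view.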
    Factorized models, such as two tower neural network models, are widely used for scoring (query, document) pairs in information retrieval tasks. These models are typically trained by optimizing the model parameters to score relevant "positive" pairs higher than the irrelevant "negative" ones. While a large set of negatives typically improves the model performance, limited computation and memory budgets place constraints on the number of negatives used during training. In this paper, we develop a novel negative sampling technique for accelerating training with softmax cross-entropy loss. By using cached (possibly stale) item embeddings, our technique enables training with a large pool of negatives with reduced memory and computation. We also develop a streaming variant of our algorithm geared towards very large datasets. Furthermore, we establish a theoretical basis for our approach by showing that updating a very small fraction of the cache at each iteration can still ensure fast convergence. Finally, we experimentally validate our approach and show that it is efficient and compares favorably with more complex, state-of-the-art approaches.
    Negative sampling is a widely adopted technique to enable efficient training in settings with a large number of classes. Typically, negative sampling approaches aim at approximating the value or gradient of the computationally expensive loss function that takes all the negative labels into account. In this work, we study the connection between negative sampling approaches and loss modification techniques for countering label imbalance. We show that different (bias) correction strategies that accompany negative sampling approaches can have unintended consequences on the model's performance on various data sub-populations. We then propose a unified approach to tackle both sampling bias, arising from working with a subset of all negative classes, and labeling bias, which is inherently present in the data due to label-imbalance. Finally, we verify our analysis and demonstrate the utility of our unified approach through empirical evaluation on standard image classification and retrieval benchmarks.
    Federated learning is a distributed machine learning paradigm in which a large number of clients coordinate with a central server to learn a model without sharing their own training data. Due to the heterogeneity of the client datasets, standard federated optimization methods such as Federated Averaging (FedAvg) are often difficult to tune and exhibit unfavorable convergence behavior. In non-federated settings, adaptive optimization methods have had notable success in combating such issues. In this work, we propose federated versions of adaptive optimizers, including Adagrad, Yogi and Adam, and analyze their convergence in the presence of heterogeneous data for general nonconvex settings. Our results highlight the interplay between client heterogeneity and communication efficiency. We also perform extensive experiments on these methods and show that the use of adaptive optimizers can improve the performance of federated learning.
    Mime: Mimicking Centralized Stochastic Algorithms in Federated Learning
    Ananda Theertha Suresh
    Martin Jaggi
    Sai Praneeth Karimireddy
    Satyen Kale
    Sebastian Stich
    ICML 2021(2020)
    Federated learning (FL) is a challenging setting for optimization due to the heterogeneity of the data across different clients, which gives rise to the client drift phenomenon. In this work, we propose a general algorithmic framework, Mime, which i) mitigates client drift and ii) adapts arbitrary centralized optimization algorithms such as SGD and Adam to the federated learning setting. Mime uses a combination of control-variates and server-level statistics (e.g. momentum) at every client-update step to ensure that each local update mimics that of the centralized method run on iid data. We prove a reduction result showing that Mime can translate the convergence of a generic algorithm in the centralized setting into convergence in the federated setting. Further, we show for the first time that multiple local steps can lead to faster convergence in the cross-device FL setting. Our thorough theoretical and empirical analyses establish Mime's superiority over other baselines.
    Transformer networks use pairwise attention to compute contextual embeddings of their inputs, and have achieved state-of-the-art performance in many NLP tasks. However, these models suffer from quadratic computational cost in the input sequence length $n$ to compute attention in each layer. This has prompted recent research into faster attention models, with a predominant approach involving sparsifying the connections in the attention layers. While empirically promising for long sequences, several fundamental questions remain unanswered: Can sparse transformers approximate any arbitrary sequence-to-sequence function, similar to their dense counterparts? How do the sparsity pattern and the sparsity level affect their performance? In this paper, we provide a unifying framework that captures existing sparse attention models. Our analysis proposes sufficient conditions under which we show that a sparse attention model can provably universally approximate any sequence-to-sequence function. Surprisingly, our results show the existence of attention models with only $O(n)$ connections per attention layer that can approximate the same function class as the dense model with $n^2$ connections. Lastly, we present experiments comparing different patterns and levels of sparsity on standard NLP tasks.
    The attention-based Transformer architecture has enabled significant advances in the field of natural language processing. In addition to new pre-training techniques, recent improvements crucially rely on working with a relatively larger embedding dimension for tokens. Unfortunately, this leads to models that are prohibitively large to be employed in downstream tasks. In this paper we identify one of the important factors contributing to the large embedding size requirement. In particular, our analysis highlights that the scaling between the number of heads and the size of each head in the current architecture gives rise to a low-rank bottleneck in attention heads, causing this limitation, which we further validate with our experiments. As a solution we propose to set the head size of an attention unit to the input sequence length, independent of the number of heads, resulting in multi-head attention layers with provably more expressive power. We empirically show that this allows us to train models with a relatively smaller embedding dimension and with better performance scaling.
    Why are Adaptive Methods Good for Attention Models?
    Jingzhao Zhang
    Sai Praneeth Karimireddy
    Suvrit Sra
    Advances in Neural Information Processing Systems (NeurIPS)(2020)
    While stochastic gradient descent (SGD) is still the de facto algorithm in deep learning, adaptive methods like Clipped SGD/Adam have been observed to outperform SGD across important tasks, such as attention models. The settings under which SGD performs poorly in comparison to adaptive methods are not well understood yet. In this paper, we provide empirical and theoretical evidence that a heavy-tailed distribution of the noise in stochastic gradients is one cause of SGD's poor performance. We provide the first tight upper and lower convergence bounds for adaptive gradient methods under heavy-tailed noise. Further, we demonstrate how gradient clipping plays a key role in addressing heavy-tailed gradient noise. Subsequently, we show how clipping can be applied in practice by developing an adaptive coordinate-wise clipping algorithm (ACClip) and demonstrate its superior performance on BERT pretraining and finetuning tasks.
    Can gradient clipping mitigate label noise?
    Aditya Krishna Menon
    International Conference on Learning Representations (ICLR)(2020)
    Gradient clipping is a widely-used technique in the training of deep networks, and is generally motivated from an optimisation lens: informally, it controls the dynamics of iterates, thus enhancing the rate of convergence to a local minimum. This intuition has been made precise in a line of recent works, which show that suitable clipping can yield significantly faster convergence than vanilla gradient descent. In this paper, we study gradient clipping from a robustness lens: informally, one expects clipping to provide robustness to noise, since one does not overly trust any single sample. Surprisingly, we prove that gradient clipping does not in general provide robustness to label noise. On the other hand, we show that robustness is achieved by a form of loss clipping. This yields a simple, noise-robust alternative to the standard cross-entropy loss which performs well empirically.
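For readers unfamiliar with the operation discussed above, here is a minimal sketch of gradient clipping. The abstract does not say which clipping rule is analysed, so this assumes the common norm-based form, in which a gradient is rescaled whenever its Euclidean norm exceeds a threshold. The name `clip_by_norm` and the threshold are illustrative. The loss-clipping alternative that the paper proposes is not shown here.

```python
import math


def clip_by_norm(gradient, max_norm):
    """Rescale the gradient so that its Euclidean norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in gradient))
    if norm <= max_norm:
        # Small gradients pass through unchanged.
        return list(gradient)
    scale = max_norm / norm
    return [g * scale for g in gradient]


# A large gradient (norm 5) is shrunk to norm 1, keeping its direction.
clipped = clip_by_norm([3.0, 4.0], max_norm=1.0)

# A small gradient (norm 0.5) is left as it is.
unchanged = clip_by_norm([0.3, 0.4], max_norm=1.0)

print(clipped, unchanged)
```

Clipping bounds the influence of any single update, which is the informal robustness intuition that the abstract goes on to qualify.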
    Despite the widespread adoption of Transformer models for NLP tasks, the expressive power of these models is not well-understood. In this paper, we establish that Transformer models are universal approximators of continuous permutation equivariant sequence-to-sequence functions with compact support, which is quite surprising given the amount of shared parameters in these models. Furthermore, using positional encodings, we circumvent the restriction of permutation equivariance, and show that Transformer models can universally approximate arbitrary continuous sequence-to-sequence functions on a compact domain. Interestingly, our proof techniques clearly highlight the different roles of the self-attention and the feed-forward layers in Transformers. In particular, we prove that fixed width self-attention layers can compute contextual mappings of the input sequences, playing a key role in the universal approximation property of Transformers. Based on this insight from our analysis, we consider other architectures that can compute contextual mappings and empirically evaluate them.
    Multilabel reductions: what is my loss optimising?
    Aditya Krishna Menon
    Advances in Neural Information Processing Systems (NeurIPS)(2019)
    Multilabel classification is a challenging problem arising in applications ranging from information retrieval to image tagging. A popular approach to this problem is to employ a reduction to a suitable series of binary or multiclass problems (e.g., computing a softmax based cross-entropy over the relevant labels). While such methods have seen empirical success, less is understood about how well they approximate two fundamental performance measures: the precision and recall@k. In this paper, we study three commonly used reductions, and two new reductions based on a normalised loss function, wherein the contribution of each instance is normalised by the number of relevant labels. A surprising outcome of our study is that each reduction is provably consistent with respect to either precision or recall, but not both. Further, we explicate that the probability scores obtained from reductions focussed on precision must be interpreted with caution. We empirically validate our results on real-world datasets, showing in particular that our normalised loss function yields recall gains over existing reductions.
    We consider the problem of retrieving the most relevant labels for a given input when the size of the output space is very large. Retrieval methods are modeled as set-valued classifiers which output a small set of classes for each input, and a mistake is made if the label is not in the output set. Despite its practical importance, a statistically principled, yet practical solution to this problem is largely missing. To this end, we first define a family of surrogate losses and show that they are calibrated and convex under certain conditions on the loss parameters and data distribution, thereby establishing a statistical and analytical basis for using these losses. Furthermore, we identify a particularly intuitive class of loss functions in the aforementioned family and show that they are amenable to practical implementation in the large output space setting (i.e. computation is possible without evaluating scores of all labels) by developing a technique called Stochastic Negative Mining. We also provide generalization error bounds for the losses in the family. Finally, we conduct experiments which demonstrate that Stochastic Negative Mining yields benefits over commonly used negative sampling approaches.
    Breaking the Glass Ceiling for Embedding-Based Classifiers for Large Output Spaces
    Chuan Guo
    Ali Mousavi
    Xiang Wu
    Dan Holtmann-Rice
    Satyen Kale
    NeurIPS(2019) (to appear)
    In extreme classification settings, embedding-based neural network models are currently not competitive with sparse linear and tree-based methods in terms of accuracy. Most prior works attribute this poor performance to the low-dimensional bottleneck in embedding-based methods. In this paper, we demonstrate that theoretically there is no limitation to using low-dimensional embedding-based methods, and provide experimental evidence that overfitting is the root cause of the poor performance of embedding-based methods. These findings motivate us to investigate novel data augmentation and regularization techniques to mitigate overfitting. To this end, we propose GLaS, a new regularizer for embedding-based neural network approaches. It is a natural generalization from the graph Laplacian and spread-out regularizers, and empirically it addresses the drawback of each regularizer alone when applied to the extreme classification setup. With the proposed techniques, we attain or improve upon the state-of-the-art on most widely tested public extreme classification datasets with hundreds of thousands of labels.
    On the convergence of Adam and Beyond
    Satyen Kale
    International Conference on Learning Representations(2018)
    Several recently proposed stochastic optimization methods that have been successfully used in training deep networks, such as RMSProp, Adam, Adadelta, and Nadam, are based on using gradient updates scaled by square roots of exponential moving averages of squared past gradients. In many applications, e.g. learning with large output spaces, it has been empirically observed that these algorithms fail to converge to an optimal solution (or a critical point in nonconvex settings). We show that one cause for such failures is the exponential moving average used in the algorithms. We provide an explicit example of a simple convex optimization setting where Adam does not converge to the optimal solution, and describe the precise problems with the previous analysis of the Adam algorithm. Our analysis suggests that the convergence issues can be fixed by endowing such algorithms with "long-term memory" of past gradients, and we propose new variants of the Adam algorithm which not only fix the convergence issues but often also lead to improved empirical performance.
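The "long-term memory" idea in the abstract above can be sketched on a scalar parameter. The abstract does not spell out the mechanism, so this assumes the variant commonly known as AMSGrad, which divides by a running maximum of the second-moment estimate instead of the estimate itself. The function names, hyperparameter defaults, and gradient sequence are illustrative.

```python
import math


def new_state():
    """Optimizer state: first moment, second moment, and its running max."""
    return {"m": 0.0, "v": 0.0, "v_hat": 0.0}


def adam_like_step(x, grad, state, lr=0.1, beta1=0.9, beta2=0.99,
                   eps=1e-8, long_term_memory=False):
    """One scalar update; long_term_memory=True uses the running maximum."""
    state["m"] = beta1 * state["m"] + (1 - beta1) * grad
    state["v"] = beta2 * state["v"] + (1 - beta2) * grad * grad
    if long_term_memory:
        # The denominator can never shrink, so a rare large gradient
        # is not forgotten by the exponential moving average.
        state["v_hat"] = max(state["v_hat"], state["v"])
        denom = math.sqrt(state["v_hat"]) + eps
    else:
        denom = math.sqrt(state["v"]) + eps
    return x - lr * state["m"] / denom


# One large gradient followed by small ones (illustrative sequence).
state = new_state()
x = 1.0
v_history, v_hat_history = [], []
for grad in [10.0, 0.1, 0.1, 0.1]:
    x = adam_like_step(x, grad, state, long_term_memory=True)
    v_history.append(state["v"])
    v_hat_history.append(state["v_hat"])

print(v_history)      # the moving average decays after the large gradient
print(v_hat_history)  # the running maximum stays at its peak
```

Bias correction is left out to keep the sketch short; it does not affect the comparison between the two denominators.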
    Adaptive gradient methods that rely on scaling gradients down by the square root of exponential moving averages of past squared gradients, such as RMSPROP, ADAM, and ADADELTA, have found wide application in optimizing the non-convex problems that arise in deep learning. However, it has been recently demonstrated that such methods can fail to converge even in simple convex optimization settings. In this work, we provide a new analysis of such methods applied to nonconvex stochastic optimization problems, characterizing the effect of increasing minibatch size. Our analysis shows that under this scenario such methods do converge to stationarity up to the statistical limit of variance in the stochastic gradients (scaled by a constant factor). In particular, our result implies that increasing minibatch sizes enables convergence, thus providing a way to circumvent the non-convergence issues. Furthermore, we provide a new adaptive optimization algorithm, YOGI, which controls the increase in effective learning rate, leading to even better performance with similar theoretical guarantees on convergence. Extensive experiments show that YOGI with very little hyperparameter tuning outperforms methods such as ADAM in several challenging machine learning tasks.
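How YOGI "controls the increase in effective learning rate" can be seen by comparing second-moment updates. The sketch below assumes the commonly published form of the YOGI update, in which the change to the second-moment estimate is additive and depends only on the sign of the gap to the squared gradient; the abstract itself does not give the formula. Function names and input values are illustrative.

```python
def adam_second_moment(v, grad, beta2=0.999):
    """ADAM-style update: move v toward grad^2 by a fraction of the gap."""
    g2 = grad * grad
    return v - (1 - beta2) * (v - g2)


def yogi_second_moment(v, grad, beta2=0.999):
    """YOGI-style update (assumed form): step size set by grad^2 alone."""
    g2 = grad * grad
    diff = v - g2
    sign = (diff > 0) - (diff < 0)  # -1, 0, or +1
    return v - (1 - beta2) * sign * g2


# Case 1: the estimate is large and the new gradient is tiny.
# ADAM shrinks v in proportion to v; YOGI shrinks it far more slowly,
# so the effective learning rate (about lr / sqrt(v)) rises more gently.
v_adam = adam_second_moment(1.0, 0.01)
v_yogi = yogi_second_moment(1.0, 0.01)

# Case 2: the estimate is zero and the gradient is large.
# Both rules increase v by the same amount.
v_adam_up = adam_second_moment(0.0, 1.0)
v_yogi_up = yogi_second_moment(0.0, 1.0)

print(v_adam, v_yogi, v_adam_up, v_yogi_up)
```

Only the second-moment rule is shown; in this assumed form the rest of the update is the same as in ADAM.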
