Ankit Singh Rawat
Authored Publications
DistillSpec: Improving speculative decoding via knowledge distillation
Yongchao Zhou
Kaifeng Lyu
Aditya Menon
Jean-François Kagy
International Conference on Learning Representations (ICLR) (2024)
Speculative decoding proves highly effective in expediting Large Language Model inference by employing a smaller draft model for token generation and a larger model for parallel token verification. Nonetheless, identifying an accurate and compact draft model aligned with the target model presents challenges. To address this, we propose leveraging white-box knowledge distillation, significantly improving draft model alignment with the larger target model, thereby enhancing speculative decoding. Our findings underscore the pivotal role of on-policy data generation and a suitable divergence function tailored to the task and decoding scheme for successful distillation. In practice, our refined distillation approach yields 20% speedup over standard speculative decoding across five distinct tasks, using both greedy decoding and temperature sampling. Furthermore, we extend the concept of lossless speculative decoding to incorporate a lenience factor in the rejection sampling step, offering fine-grained control over the trade-off between quality and latency in lossy decoding. Finally, adopting a strategy of "distilling for performance first and distillation for speculative decoding second" enables a remarkable 8x reduction in latency with minimal performance compromise, compared to no distillation and speculative decoding baseline.
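The rejection sampling step mentioned in this abstract can be illustrated with a minimal sketch. It is not the paper's implementation: the function names are hypothetical, and the way the lenience factor enters (scaling the draft probability in the acceptance ratio) is an assumed form chosen for illustration. With `lenience=1.0` the rule reduces to the standard lossless acceptance test, where a draft token x is kept with probability min(1, p(x)/q(x)).

```python
import random


def accept_probability(p_x, q_x, lenience=1.0):
    """Probability of keeping a draft token x.

    p_x: probability of x under the target model.
    q_x: probability of x under the draft model (positive, since x was drafted).
    lenience: value in (0, 1]; 1.0 gives the lossless rule. Scaling q_x by
    the lenience is an assumption made for this sketch.
    """
    if not 0.0 < lenience <= 1.0:
        raise ValueError("lenience must be in (0, 1]")
    return min(1.0, p_x / (lenience * q_x))


def residual_distribution(p, q):
    """Distribution to resample from after a rejection (lossless case).

    It is proportional to max(0, p - q), computed per token.
    """
    diff = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    total = sum(diff)
    if total == 0.0:
        # p and q coincide, so a rejection cannot occur; fall back to p.
        return list(p)
    return [d / total for d in diff]


def verify_token(x, p, q, lenience=1.0, rng=random):
    """Keep draft token x or replace it with a sample from the residual.

    For simplicity the lossless residual is reused when lenience < 1,
    which is a simplification rather than a prescribed choice.
    """
    if rng.random() < accept_probability(p[x], q[x], lenience):
        return x
    weights = residual_distribution(p, q)
    return rng.choices(range(len(p)), weights=weights)[0]
```

Lowering `lenience` raises the acceptance probability, so more draft tokens are kept per verification pass at the cost of no longer sampling exactly from the target distribution.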
Dual-Encoders for Extreme Multi-label Classification
Nilesh Gupta
Devvrit Khatri
Inderjit Dhillon
International Conference on Learning Representations (ICLR) (2024)
Dual-encoder models have demonstrated significant success in dense retrieval tasks for open-domain question answering, which mostly involve zero-shot and few-shot scenarios. However, their performance in many-shot retrieval problems, such as extreme classification, remains largely unexplored. State-of-the-art extreme classification techniques like NGAME use a combination of dual-encoders and a learnable classification head for each class to excel on these tasks. Existing empirical evidence shows that, for such problems, the accuracy of the dual-encoder method lags behind that of the state-of-the-art extreme classification methods, which grow the number of learnable parameters with the number of classes. In this work, we investigate the potential reasons behind this observed gap, such as the intrinsic capacity limit due to a fixed model size for dual-encoder models that is independent of the number of classes, training, loss formulation, negative sampling, etc. We methodically experiment on these different axes and find that model size is not the main bottleneck, but rather the training and loss formulation. When trained correctly, even small dual-encoders can outperform state-of-the-art extreme classification methods by up to 2% in precision on million-label-scale extreme classification datasets, while being 20x smaller in terms of the number of trainable parameters. We further propose a differentiable top-k error-based loss function, which can be used to specifically optimize for recall@k metrics.
Understanding Self-Attention through Prompt-Conditioned Markov Chains
Muhammed Emrullah Ildiz
Yixiao Huang
Yingcong Li
Samet Oymak
International Conference on Machine Learning (ICML) (2024)
Modern language models rely on the transformer architecture and self-attention mechanism to perform language understanding and text generation. In this work, we study learning a 1-layer self-attention model from a set of prompts and associated output data sampled according to ground-truth weights. As our main contribution, we establish a precise mapping between a self-attention model and a Markov chain through a convex problem formulation: inputting a prompt to the model samples the output token according to a prompt-conditioned Markov chain which weights the transitions of a base chain. Additionally, incorporating positional encoding results in position-dependent scaling of the chain transitions. Building on this formalism, we develop identifiability/coverage conditions for the data distribution that guarantee consistent estimation and establish sample complexity guarantees under IID sampled data. Finally, we study the challenging problem of learning from a single dependent trajectory generated from an initial prompt. Unlike standard Markov chains, we characterize a winner-takes-all phenomenon where the sampling process degenerates into generating a limited subset of tokens due to the non-mixing nature of the attention layer. We argue that this phenomenon explains the tendency of modern LLMs to generate repetitive text and makes consistent estimation from a single trajectory intricate and problem-dependent; we provide a preliminary characterization of this setting.
Think before you speak: Training language models with pause tokens
Sachin Goyal
Ziwei Ji
Aditya Menon
Vaishnavh Nagarajan
International Conference on Learning Representations (ICLR) (2024)
The present-day language model generates its response by producing a series of tokens in immediate succession: the (K+1)-th token is an outcome of manipulating exactly K hidden values in each layer, corresponding to each of the K previous tokens. Is it possible to somehow allow the model to manipulate more hidden values before committing to an answer? If yes, would this help? We explore these questions by training models with learnable pause tokens. Besides feeding the usual prefix to the model, our idea is to feed the model an additional sequence of pause tokens. On these tokens, the model's output is ignored all the way until the last pause token, where we begin extracting the answer. We explore this idea of "delayed answering" in a 1B model, where we consider both pre-training and/or fine-tuning with pause tokens. We find that while merely fine-tuning a standard model is not very helpful, pause-pretrained models show promise on some downstream tasks such as GSM (reasoning) and SQuAD, CommonSenseQA and LAMBADA (question-answering tasks). We also conduct various ablations to explore the effect of the number of pause tokens. While our work is a preliminary exploration of delayed computations for language models, focusing on a 1B model, we hope it inspires future work that can make this idea practically feasible without pre-training and for models trained with other pretraining objectives and other sizes.
Language Model Cascades: Token-Level Uncertainty And Beyond
Neha Gupta
Aditya Menon
International Conference on Learning Representations (ICLR) (2024)
Recent advances in language model (LM) design have yielded a series of models with remarkably improved quality on complex NLP tasks, but significantly increased inference cost. A simple strategy to achieve more favourable cost-quality tradeoffs is cascading: here, a small model is invoked for most “easy” instances, while a large model is invoked for a few “hard” instances. Typically, “easy” instances are those where the small model has high confidence in its prediction. While the principles underpinning effective cascading are well-studied for classification problems, a similar understanding is lacking for generative tasks. The extension of the simple “Chow” rule, which defers based on the probability of predicting an answer, is not straightforward for generative tasks where the number of output tokens is variable. Moreover, LMs are known to suffer from length bias, where longer answers are penalized more than shorter answers, which complicates things further. In this work, we initiate a systematic study of deferral rules for cascades for language models. For example, how does one best summarise model confidence across a variable number of output tokens? We show experimentally that there is no single straightforward extension of probability-based uncertainty for LMs which works well across all tasks. Via experiments on a range of benchmarks with FLAN-T5 models, we find that incorporating token-level uncertainty can significantly improve the cost-quality tradeoff of cascades. We further show that incorporating embeddings from the smaller model and intermediate layer embeddings from the larger model can further boost performance.
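A confidence-based deferral rule of the kind this abstract discusses can be sketched as follows. This is an illustrative sketch only: the function names, the threshold value, and the three aggregation choices (`"sum"`, `"mean"`, `"min"`) are assumptions made here, not the rules evaluated in the paper. The sketch shows why length matters: summing token log-probabilities penalizes longer answers, while averaging does not.

```python
import math


def sequence_confidence(token_logprobs, how="mean"):
    """Summarise per-token log-probabilities into one confidence score.

    "sum"  : log-probability of the whole sequence (shrinks with length).
    "mean" : length-normalised log-probability.
    "min"  : log-probability of the least confident token.
    """
    if not token_logprobs:
        raise ValueError("token_logprobs must be non-empty")
    if how == "sum":
        return sum(token_logprobs)
    if how == "mean":
        return sum(token_logprobs) / len(token_logprobs)
    if how == "min":
        return min(token_logprobs)
    raise ValueError(f"unknown aggregation: {how}")


def should_defer(token_logprobs, threshold, how="mean"):
    """Defer to the large model when the small model's confidence is low."""
    return sequence_confidence(token_logprobs, how) < threshold


# Hypothetical small-model outputs: a short, unsure answer and a long, confident one.
short_answer = [math.log(0.6)] * 2
long_answer = [math.log(0.9)] * 10
```

Under `"sum"`, `long_answer` scores lower than `short_answer` even though each of its tokens is more confident; under `"mean"` the ordering flips, so a threshold of, say, `-0.3` would defer only the short answer.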
Mechanics of Next Token Prediction with Transformers
Yingcong Li
Yixiao Huang
Muhammed Emrullah Ildiz
Samet Oymak
International Conference on Artificial Intelligence and Statistics (AISTATS) (2024)
Transformer-based language models are trained on large datasets to predict the next token given an input sequence. Despite this seemingly simple training objective, they have revolutionized natural language processing within a short timeframe. Underlying this success is the self-attention mechanism. In this work, we ask: What does 1-layer self-attention learn from next-token prediction? We show that when trained with gradient descent, self-attention implements a simple automaton that induces a token hierarchy determined by the training data. Concretely, from the (sequence, label) pairs of the training data, we construct directed next-token graphs (NTGs) of the dataset that capture (input token, label) relations. We find that the implicit bias of self-attention is captured by the strongly connected components (SCCs), which partition the NTGs into cyclic and acyclic subgraphs: the acyclic subgraph results in an SVM direction that enforces the priority order among SCCs, while the cyclic subgraph yields a correction term that allocates the nonzero softmax probabilities among tokens within the same SCC. We empirically and theoretically demonstrate that the superposition of these components can accurately predict the implicit bias of gradient descent in next-token prediction. We believe these results shed light on self-attention's ability to process sequential data and pave the path towards demystifying more complex transformer architectures.
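The graph construction described in this abstract can be sketched in a few lines. This is a simplified illustration, not the paper's exact definition: the function names are hypothetical, a single graph is built for the whole dataset, and the edge direction (label to input token) is an assumption. The SCC partition itself does not depend on that choice, since reversing every edge leaves the strongly connected components unchanged.

```python
from collections import defaultdict


def build_ntg(examples):
    """Build a directed graph from (sequence, label) pairs.

    An edge label -> token is added for every input token that differs
    from the label. The edge direction is an assumption of this sketch.
    """
    graph = defaultdict(set)
    for sequence, label in examples:
        graph[label]  # make sure the label appears as a node
        for token in sequence:
            graph[token]  # make sure the token appears as a node
            if token != label:
                graph[label].add(token)
    return dict(graph)


def strongly_connected_components(graph):
    """Tarjan's algorithm (recursive form, suitable for small graphs)."""
    index, lowlink = {}, {}
    stack, on_stack = [], set()
    components = []
    counter = [0]

    def visit(node):
        index[node] = lowlink[node] = counter[0]
        counter[0] += 1
        stack.append(node)
        on_stack.add(node)
        for nxt in graph[node]:
            if nxt not in index:
                visit(nxt)
                lowlink[node] = min(lowlink[node], lowlink[nxt])
            elif nxt in on_stack:
                lowlink[node] = min(lowlink[node], index[nxt])
        if lowlink[node] == index[node]:
            # node is the root of a component: pop its members off the stack.
            component = set()
            while True:
                top = stack.pop()
                on_stack.discard(top)
                component.add(top)
                if top == node:
                    break
            components.append(frozenset(component))

    for node in graph:
        if node not in index:
            visit(node)
    return components


# Toy dataset: "a" and "b" predict each other (a cycle); "c" only points to "a".
toy_examples = [(["a"], "b"), (["b"], "a"), (["a", "c"], "c")]
toy_sccs = strongly_connected_components(build_ntg(toy_examples))
```

In the toy dataset, the tokens "a" and "b" form one cyclic component, while "c" sits in a component of its own that is connected to the others only acyclically.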
A Statistical Framework for Data-dependent Retrieval-Augmented Models
International Conference on Machine Learning (ICML) (2024)
Modern ML systems increasingly augment input instances with additional relevant information to enhance final prediction. Despite growing interest in such retrieval-augmented models, their fundamental properties and training are not well understood. We propose a statistical framework to study such models with two components: 1) a retriever to identify the relevant information out of a large corpus via a data-dependent metric; and 2) a predictor that consumes the input instances along with the retrieved information to make the final predictions. We present a principled method for end-to-end training of both components and draw connections with various training approaches in the literature. Furthermore, we establish excess risk bounds for retrieval-augmented models while delineating the contributions of both retriever and predictor towards the model performance. We validate the utility of our proposed training methods along with the key takeaways from our statistical analysis on an open-domain question answering task where retrieval augmentation is important.
USTAD: Unified Single-model Training Achieving Diverse Scores for Information Retrieval
Veeru Sadhanala
Sadeep Jayasumana
Aditya Menon
Rob Fergus
International Conference on Machine Learning (ICML) (2024)
Modern information retrieval (IR) systems consist of multiple stages like retrieval and ranking. Transformers are employed across these different IR stages, achieving state-of-the-art performance, but each model is trained separately, leading to complex pipelines and increased cost for maintaining multiple models. The apparent need for separate models is due to different input/output semantics at different stages. In this paper, we challenge this tradition of using separate models, since transformers are very expressive models, and ask the question: would changing just the score function suffice? We present a new unified approach, USTAD, to train a single network that can provide powerful ranking scores as a cross-encoder (CE) as well as factorized embeddings for large-scale retrieval as a dual-encoder (DE). Empirically, we find a single USTAD model to be competitive with separate ranking CE and retrieval DE models. Furthermore, USTAD enables new distillation techniques, significantly improving CE to DE distillations. Also, using a USTAD teacher, we can deploy novel asymmetric architectures for student models which realize better embedding alignment without increasing online inference cost. On standard benchmarks like MSMARCO, we show that our approach successfully distills from both dual-encoder (DE) and cross-encoder (CE) teacher models to 1/10th size asymmetric students that can retain 95-97% of the teacher performance.
Towards Understanding the Role of Attention in Prompt-tuning
Samet Oymak
Mahdi Soltanolkotabi
Christos Thrampoulidis
International Conference on Machine Learning (ICML) (2023)
Prompt-tuning is an emerging strategy to adapt large language models (LLMs) to downstream tasks by learning a (soft-)prompt parameter from data. Despite its success in LLMs, there is limited theoretical understanding of the power of prompt-tuning and the role of the attention mechanism in prompting. In this work, we explore prompt-tuning for one-layer attention architectures and study contextual mixture-models where each input token belongs to a context-relevant or -irrelevant set. We isolate the role of prompt-tuning through a self-contained prompt-attention model. Our contributions are as follows: (1) We show that softmax-prompt-attention is provably more expressive than softmax-self-attention and linear-prompt-attention under our contextual data model. (2) We analyze the initial trajectory of gradient descent and show that it learns the prompt and prediction head with near-optimal sample complexity, and demonstrate how the prompt can provably attend to sparse context-relevant tokens. (3) Assuming a known prompt but an unknown prediction head, we characterize the exact finite sample performance of prompt-attention, which reveals the fundamental performance limits and the precise benefit of the context information. We also provide experiments that verify our theoretical insights on real datasets and demonstrate how prompt-tuning enables the model to attend to context-relevant information.
Serving Graph Compression for Graph Neural Networks
Cho-Jui Hsieh
International Conference on Learning Representations (ICLR) (2023)
Serving a GNN model in online applications is challenging: one has to propagate the information from training nodes to testing nodes to achieve the best performance, while storing the whole training set (including the training graph and node features) at inference time is prohibitive for most real-world applications. We tackle this serving space compression problem in this paper, where the goal is to compress the storage requirement for GNN serving. Given a model to be served, the proposed method constructs a small set of virtual representative nodes to replace the original training nodes, so that users just need to replace the original training set with this virtual representative set to reduce the space requirement for serving, without needing to change the actual GNN model or the forward pass.
We carefully analyze the error in the forward pass and derive simple ways to construct the node features and graph of virtual representative nodes to minimize the approximation error. Experimental results demonstrate that the proposed method can significantly reduce the serving space requirement for GNN inference.