Manzil Zaheer
Authored Publications
A Statistical Framework for Data-dependent Retrieval-Augmented Models
International Conference on Machine Learning (ICML) (2024)
Modern ML systems increasingly augment input instances with additional relevant information to enhance final prediction. Despite growing interest in such retrieval-augmented models, their fundamental properties and training are not well understood. We propose a statistical framework to study such models with two components: 1) a retriever to identify the relevant information out of a large corpus via a data-dependent metric; and 2) a predictor that consumes the input instances along with the retrieved information to make the final predictions. We present a principled method for end-to-end training of both components and draw connections with various training approaches in the literature. Furthermore, we establish excess risk bounds for retrieval-augmented models while delineating the contributions of both the retriever and the predictor towards the model performance. We validate the utility of our proposed training methods, along with the key takeaways from our statistical analysis, on an open-domain question answering task where retrieval augmentation is important.
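To make the two-component setup concrete, here is a minimal sketch, assuming a bilinear (dot-product-style) retrieval metric, a linear predictor, and softmax weighting over retrieved items as one way to couple the two components during end-to-end training; the featurizer, dimensions, and function names are illustrative, not the paper's implementation.

```python
# Minimal sketch of a retrieval-augmented model: a retriever scores corpus
# items with a data-dependent metric, and a predictor consumes the query
# together with the retrieved items. All names and sizes are illustrative.
import numpy as np

rng = np.random.default_rng(0)
d, corpus_size, k = 16, 100, 5

corpus = rng.normal(size=(corpus_size, d))        # embedded corpus items
W_retriever = rng.normal(size=(d, d)) * 0.1       # learnable retrieval metric
W_predictor = rng.normal(size=(2 * d,)) * 0.1     # learnable predictor weights

def retrieve(query, top_k=k):
    """Score every corpus item with a learned bilinear metric and keep top-k."""
    scores = corpus @ (W_retriever @ query)
    top = np.argsort(-scores)[:top_k]
    # Softmax over the retained scores: differentiating through these weights
    # is one way to train the retriever end-to-end with the predictor.
    weights = np.exp(scores[top] - scores[top].max())
    return top, weights / weights.sum()

def predict(query):
    """Marginalize the predictor's output over the retrieved items."""
    idx, weights = retrieve(query)
    outputs = np.array([W_predictor @ np.concatenate([query, corpus[i]]) for i in idx])
    return float(weights @ outputs)

print(predict(rng.normal(size=d)))
```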
USTAD: Unified Single-model Training Achieving Diverse Scores for Information Retrieval
Veeru Sadhanala
Sadeep Jayasumana
Aditya Menon
Rob Fergus
International Conference on Machine Learning (ICML) (2024)
Modern information retrieval (IR) systems consist of multiple stages, such as retrieval and ranking. Transformers are employed across these different IR stages, achieving state-of-the-art performance, but each model is trained separately, leading to complex pipelines and increased cost for maintaining multiple models. The apparent need for separate models is due to different input/output semantics at different stages. In this paper, we challenge this tradition of using separate models, as transformers are very expressive, and ask: would changing just the score function suffice? We present a new unified approach, USTAD, to train a single network that can provide powerful ranking scores as a cross-encoder (CE) as well as factorized embeddings for large-scale retrieval as a dual-encoder (DE). Empirically, we find a single USTAD model to be competitive with separate ranking CE and retrieval DE models. Furthermore, USTAD enables new distillation techniques, significantly improving CE-to-DE distillation. Also, using a USTAD teacher, we can deploy novel asymmetric architectures for student models, which realize better embedding alignment without increasing online inference cost. On standard benchmarks like MSMARCO, we show that our approach successfully distills from both dual-encoder (DE) and cross-encoder (CE) teacher models to 1/10th-size asymmetric students that retain 95-97% of the teacher's performance.
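The following toy sketch illustrates the "one network, two score functions" idea: the same encoder backs a cross-encoder score over the concatenated pair and a dual-encoder dot-product score over factorized embeddings. The hashed-trigram encoder, the [SEP] convention, and the scalar head are illustrative assumptions, not USTAD's actual transformer.

```python
# One toy encoder, two scoring modes: CE scores the joint (query, doc) input,
# DE scores factorized embeddings with a dot product.
import numpy as np

DIM = 64
rng = np.random.default_rng(0)
w_ce = rng.normal(size=DIM) / np.sqrt(DIM)   # scalar head for the CE mode

def encode(text: str) -> np.ndarray:
    """Toy text encoder (bag of hashed character trigrams), a stand-in for a transformer."""
    v = np.zeros(DIM)
    for i in range(len(text) - 2):
        v[hash(text[i:i + 3]) % DIM] += 1.0
    n = np.linalg.norm(v)
    return v / n if n > 0 else v

def ce_score(query: str, doc: str) -> float:
    """Cross-encoder mode: encode the concatenated pair, project to a scalar score."""
    return float(w_ce @ encode(query + " [SEP] " + doc))

def de_score(query: str, doc: str) -> float:
    """Dual-encoder mode: factorized embeddings, scored by a dot product."""
    return float(encode(query) @ encode(doc))

q, d1, d2 = "capital of france", "paris is the capital of france", "gradient descent"
print(ce_score(q, d1), ce_score(q, d2))
print(de_score(q, d1), de_score(q, d2))
```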
Teacher Guided Training: An Efficient Framework for Knowledge Transfer
Chong You
Himanshu Jain
Rob Fergus
International Conference on Learning Representations (ICLR) (2023)
The remarkable performance gains realized by large pretrained models, e.g., GPT-3, hinge on the massive amounts of data they are exposed to during training. Analogously, distilling such large models into compact models for efficient deployment also necessitates a large amount of (labeled or unlabeled) training data. In this paper, we devise the teacher-guided training (TGT) framework for training a high-quality compact model that leverages the knowledge acquired by pre-trained generative models while obviating the need to go through a large volume of data. TGT exploits the fact that the teacher has acquired a good representation of the underlying data domain, which typically corresponds to a much lower dimensional manifold than the ambient space. Furthermore, we can use the teacher to explore the instance space more efficiently through sampling or gradient-based methods; this makes TGT especially attractive for limited-data or long-tail settings. We formally capture this benefit of the proposed data-domain exploration in our generalization bounds. Among our empirical evaluations, we find that TGT can improve accuracy on ImageNet-LT by 10% compared to a natural baseline and match accuracy on sentiment analysis on Amazon reviews without the need for pretraining.
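A hedged sketch of the core recipe follows: draw inputs from the teacher's generative model of the data domain, label them with the teacher's predictions, and fit the student to those soft labels. The Gaussian generator, logistic teacher, and plain-gradient student are toy stand-ins for the pretrained generative model and the compact model in the paper.

```python
# Toy teacher-guided training: train the student on teacher-generated inputs
# and the teacher's soft labels instead of a large real dataset.
import numpy as np

rng = np.random.default_rng(0)
d = 10
teacher_w = rng.normal(size=d)                  # pretend this was learned on lots of data

def teacher_generate(n):
    """Teacher's generative model of the input domain (toy: isotropic Gaussian)."""
    return rng.normal(size=(n, d))

def teacher_predict(X):
    """Teacher's soft labels."""
    return 1.0 / (1.0 + np.exp(-X @ teacher_w))

student_w = np.zeros(d)
lr, steps, batch = 0.5, 2000, 64
for _ in range(steps):
    X = teacher_generate(batch)
    t = teacher_predict(X)                      # distillation targets
    p = 1.0 / (1.0 + np.exp(-X @ student_w))
    student_w -= lr * X.T @ (p - t) / batch     # gradient of cross-entropy vs. soft labels

print("teacher/student weight correlation:", np.corrcoef(student_w, teacher_w)[0, 1])
```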
Generalization Properties of Retrieval-based Models
International Conference on Machine Learning (ICML) (2023)
Many modern high-performing machine learning models such as GPT-3 primarily rely on scaling up models, e.g., transformer networks. Simultaneously, a parallel line of work aims to improve the model performance by augmenting an input instance with other (labeled) instances during inference.
Examples of such augmentations include task-specific prompts and similar examples retrieved from the training data by a nonparametric component. Remarkably, retrieval-based methods have enjoyed success on a wide range of problems, ranging from standard natural language processing and vision tasks to protein folding, as demonstrated by many recent efforts, including WebGPT and AlphaFold. Despite a growing literature showcasing the promise of these models, the theoretical underpinning for such models remains underexplored. In this paper, we present a formal treatment of retrieval-based models to characterize their generalization ability. In particular, we focus on two classes of retrieval-based classification approaches: First, we analyze a local learning framework that employs an explicit local empirical risk minimization based on retrieved examples for each input instance. Interestingly, we show that breaking down the underlying learning task into local sub-tasks enables the model to employ a low complexity parametric component to ensure good overall accuracy. The second class of retrieval-based approaches we explore learns a global model using kernel methods to directly map an input instance and retrieved examples to a prediction, without explicitly solving a local learning task.
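As an illustration of the first (local learning) class, the sketch below retrieves the nearest labeled examples for each test input and solves a small local empirical risk minimization; the Euclidean retrieval metric and the ridge-regularized linear local model are illustrative choices, not the paper's exact setting.

```python
# Local learning sketch: per test point, retrieve neighbors and fit a
# low-complexity local model on just those retrieved examples.
import numpy as np

def local_predict(x, X_train, y_train, k=25, reg=1e-2):
    # Retrieve the k nearest training examples under the Euclidean metric.
    idx = np.argsort(np.linalg.norm(X_train - x, axis=1))[:k]
    Xk, yk = X_train[idx], y_train[idx]
    # Local ERM: ridge-regularized least squares on the retrieved neighborhood.
    A = Xk.T @ Xk + reg * np.eye(Xk.shape[1])
    w = np.linalg.solve(A, Xk.T @ yk)
    return float(np.sign(x @ w))

rng = np.random.default_rng(0)
X = rng.normal(size=(2000, 5))
y = np.sign(np.sin(2 * X[:, 0]) + 0.3 * X[:, 1])   # globally nonlinear labels
x_test = rng.normal(size=5)
print(local_predict(x_test, X, y))
```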
A Context Integrated Transformer-based Neural Network for Auction Design
Zhijian Duan
Jingwu Tang
Yutong Yin
Xiang Yan
Xiaotie Deng
The Thirty-ninth International Conference on Machine Learning (ICML'22) (2022)
One of the central problems in auction design is to develop an incentive compatible mechanism that maximizes the expected revenue. While theoretical approaches have encountered bottlenecks for multi-item auctions, there has recently been much progress in finding optimal auctions through deep learning.
However, such works either focus on a fixed set of bidders and items or restrict the auction to be symmetric. In this work, we overcome this limitation by factoring public contextual information about bidders and items into the deep learning framework. We propose CITransNet, a context-integrated transformer-based neural network for contextual auction design, which maintains permutation-equivariance over bids while being able to handle asymmetric contextual information in auctions. We show by extensive experiments that CITransNet can recover the known optimal analytical solutions, obtain novel mechanisms for complex multi-item auctions, and generalize to settings different from the training set.
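As a small illustration of permutation-equivariance with contextual inputs (not CITransNet itself), the sketch below implements a single exchangeable layer over a bid matrix that consumes per-bidder and per-item context vectors, and checks that permuting bidders and items permutes the output accordingly.

```python
# A single exchangeable layer over the bid matrix: equivariant to permutations
# of bidders (rows) and items (columns) while using public context vectors.
import numpy as np

rng = np.random.default_rng(0)
n_bidders, n_items, c = 4, 3, 2
B = rng.uniform(size=(n_bidders, n_items))            # bids
Cb = rng.normal(size=(n_bidders, c))                   # public bidder contexts
Ci = rng.normal(size=(n_items, c))                     # public item contexts
theta = rng.normal(size=4)
wb, wi = rng.normal(size=c), rng.normal(size=c)

def layer(B, Cb, Ci):
    return (theta[0] * B
            + theta[1] * B.mean(axis=1, keepdims=True)  # per-bidder pooling
            + theta[2] * B.mean(axis=0, keepdims=True)  # per-item pooling
            + theta[3] * B.mean()                       # global pooling
            + (Cb @ wb)[:, None]                        # bidder context term
            + (Ci @ wi)[None, :])                       # item context term

# Equivariance check: permuting bidders/items permutes the output the same way.
pb, pi = rng.permutation(n_bidders), rng.permutation(n_items)
lhs = layer(B, Cb, Ci)[pb][:, pi]
rhs = layer(B[pb][:, pi], Cb[pb], Ci[pi])
print(np.allclose(lhs, rhs))   # True
```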
Compositional Generalization and Decomposition in Neural Program Synthesis
Joey Hong
Deep Learning for Code (DL4C) Workshop at ICLR'22 (2022)
When writing programs, people have the ability to tackle a new complex task by decomposing it into smaller and more familiar subtasks. While it is difficult to measure whether neural program synthesis methods have similar capabilities, what we can measure is whether they compositionally generalize, that is, whether a model that has been trained on the simpler subtasks is subsequently able to solve more complex tasks. In this paper, we focus on measuring the ability of learned program synthesizers to compositionally generalize. We first characterize several different axes along which program synthesis methods should generalize, e.g., length generalization, or the ability to combine known subroutines in new ways that do not occur in the training data. Based on this characterization, we introduce a benchmark suite of tasks to assess these abilities, built on two popular existing datasets, SCAN and RobustFill. Finally, we make a first attempt to improve the compositional generalization ability of Transformer models along these axes through novel attention mechanisms that draw inspiration from a human-like decomposition strategy. Empirically, we find that our modified Transformer models generally perform better than natural baselines, but the tasks remain challenging.
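As one concrete example of such an axis, the sketch below builds a length-generalization split in which training pairs use short programs and evaluation pairs require strictly longer ones; the (specification, program) format and the toy SCAN-like examples are illustrative stand-ins for the benchmark's actual data.

```python
# Length-generalization split: train on short programs, evaluate on longer ones.
def length_split(pairs, max_train_len):
    """Split (spec, program) pairs so the model must generalize to longer programs."""
    train = [(s, p) for s, p in pairs if len(p.split()) <= max_train_len]
    test = [(s, p) for s, p in pairs if len(p.split()) > max_train_len]
    return train, test

pairs = [
    ("jump twice", "JUMP JUMP"),
    ("walk", "WALK"),
    ("jump twice and walk", "JUMP JUMP WALK"),
    ("walk thrice after jump", "JUMP WALK WALK WALK"),
]
train, test = length_split(pairs, max_train_len=2)
print(len(train), len(test))
```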
Thompson Sampling with a Mixture Prior
Joey Hong
Branislav Kveton
Mohammad Ghavamzadeh
Proceedings of The 25th International Conference on Artificial Intelligence and Statistics (AI-Stats-22) (2022), pp. 7565-7586
We consider posterior sampling in online decision-making problems where the uncertain environment is sampled from a mixture distribution. We incorporate this structure in a natural way by initializing a Thompson sampling algorithm with a mixture prior. We provide a novel, general outline for analyzing the regret of Thompson sampling with a mixture prior. We also use this to derive Bayes regret bounds in both linear bandit and tabular MDP settings. The regret bounds depend on the confidence widths of each component of the mixture prior, and converge to solely identifying the correct component when the confidence widths are small. Finally, we demonstrate the empirical effectiveness of our proposed algorithm on synthetic and real-world bandit problems involving multi-task image classification.
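A hedged sketch of the algorithmic idea for a K-armed Gaussian bandit follows: maintain a mixture posterior per arm, sample a component from its posterior weights, then sample the arm mean from that component's conjugate posterior. The specific prior components, noise level, and horizon are illustrative choices rather than the paper's experimental setup.

```python
# Thompson sampling with a mixture-of-Gaussians prior on each arm's mean.
import numpy as np

rng = np.random.default_rng(0)
K, T, noise_sd = 5, 2000, 1.0
prior_means = np.array([-1.0, 1.0])        # mixture components over each arm mean
prior_sds = np.array([0.3, 0.3])
prior_logw = np.log(np.array([0.5, 0.5]))
M = len(prior_means)

# Environment drawn from the same mixture prior.
comps = rng.integers(M, size=K)
true_means = rng.normal(prior_means[comps], prior_sds[comps])

# Per-arm, per-component posterior state.
mu = np.tile(prior_means, (K, 1))
var = np.tile(prior_sds ** 2, (K, 1))
logw = np.tile(prior_logw, (K, 1))

regret = 0.0
for _ in range(T):
    # Thompson sample: component ~ posterior weights, then mean ~ that component.
    samples = np.empty(K)
    for a in range(K):
        w = np.exp(logw[a] - logw[a].max())
        w /= w.sum()
        j = rng.choice(M, p=w)
        samples[a] = rng.normal(mu[a, j], np.sqrt(var[a, j]))
    a = int(np.argmax(samples))
    y = rng.normal(true_means[a], noise_sd)
    regret += true_means.max() - true_means[a]

    # Update component weights with the predictive likelihood, then do the
    # conjugate Gaussian update of each component's posterior for arm a.
    pred_var = var[a] + noise_sd ** 2
    logw[a] += -0.5 * (np.log(2 * np.pi * pred_var) + (y - mu[a]) ** 2 / pred_var)
    new_var = 1.0 / (1.0 / var[a] + 1.0 / noise_sd ** 2)
    mu[a] = new_var * (mu[a] / var[a] + y / noise_sd ** 2)
    var[a] = new_var

print("total regret:", round(regret, 2))
```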
A Fourier Approach to Mixture Learning
Mingda Qiao
Guru Prashanth Guruganesh
Conference on Neural Information Processing Systems (2022)
We revisit the problem of learning mixtures of spherical Gaussians. Given samples from the mixture $\frac{1}{k}\sum_{j=1}^{k}\mathcal{N}(\mu_j, I_d)$, the goal is to estimate the means $\mu_1, \mu_2, \ldots, \mu_k \in \mathbb{R}^d$ up to a small error. The hardness of this learning problem can be measured by the separation $\Delta$, defined as the minimum distance between all pairs of means. Regev and Vijayaraghavan (2017) showed that with $\Delta = \Omega(\sqrt{\log k})$ separation, the means can be learned using $\mathrm{poly}(k, d)$ samples, whereas super-polynomially many samples are required if $\Delta = o(\sqrt{\log k})$ and $d = \Omega(\log k)$. This leaves open the low-dimensional regime where $d = o(\log k)$.
In this work, we give an algorithm that efficiently learns the means in $d = O(\log k/\log\log k)$ dimensions under separation $d/\sqrt{\log k}$ (modulo doubly logarithmic factors). This separation is strictly smaller than $\sqrt{\log k}$, and is also shown to be necessary. Along with the results of Regev and Vijayaraghavan (2017), our work almost pins down the critical separation threshold at which efficient parameter learning becomes possible for spherical Gaussian mixtures. This was previously open even in one dimension. More generally, our algorithm runs in time $\mathrm{poly}(k)\cdot f(d, \Delta, \varepsilon)$, and is thus fixed-parameter tractable in the parameters $d$, $\Delta$, and $\varepsilon$.
Our approach is based on estimating the Fourier transform of the mixture at carefully chosen frequencies, and both the algorithm and its analysis are simple and elementary. Our positive results can be easily extended to learning mixtures of non-Gaussian distributions, under a mild condition on the Fourier spectrum of the distribution.
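A small one-dimensional illustration of this starting point: the empirical characteristic function of the mixture, divided by the Gaussian factor $e^{-\xi^2/2}$, estimates the Fourier transform of the uniform measure over the means. The frequencies and sample size below are arbitrary illustrative choices.

```python
# Estimate (1/k) * sum_j exp(i * xi * mu_j) from samples of a 1-D spherical
# Gaussian mixture via the empirical characteristic function.
import numpy as np

rng = np.random.default_rng(0)
means = np.array([-2.0, 0.5, 3.0])                 # unknown mu_1, ..., mu_k
k, n = len(means), 200_000
samples = rng.normal(means[rng.integers(k, size=n)], 1.0)

def mean_transform_estimate(xi):
    """Fourier transform of the uniform measure on the means, from mixture samples."""
    emp_cf = np.mean(np.exp(1j * xi * samples))    # empirical characteristic function
    return emp_cf * np.exp(xi ** 2 / 2.0)          # undo the Gaussian envelope e^{-xi^2/2}

xi = 0.8
est = mean_transform_estimate(xi)
truth = np.mean(np.exp(1j * xi * means))
print(abs(est - truth))                            # small for moderate frequencies
```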
DAG-structured Clustering by Nearest-Neighbors
Nicholas Monath
Andrew McCallum
International Conference on Artificial Intelligence and Statistics (2021)
Hierarchical clusterings compactly encode multiple granularities of clusters within a tree structure. Hierarchies, by definition, fail to capture different flat partitions that are not subsumed in one another. In this paper, we advocate for an alternative structure for representing multiple alternative clusterings: a directed acyclic graph (DAG). By allowing nodes to have multiple parents, the DAG structure is not only more flexible than a tree but also allows points to be members of multiple clusters. We describe a large-scale, map-reduce-based algorithm to infer these DAGs. Our algorithm works by simply merging nearest-neighbor substructures to form a DAG structure. The algorithm is supported by theoretical guarantees showing its representational advantage over tree-based algorithms. Further, we provide comprehensive empirical experiments on large-scale clustering benchmarks and entity resolution datasets. Our results show that our method is as scalable as and more accurate than state-of-the-art tree-based techniques.
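The sketch below conveys the flavor of the construction (a single-machine illustration, not the paper's map-reduce algorithm): repeatedly merge each current node with its nearest neighbor, so that a node selected by several others acquires multiple parents, which is exactly what distinguishes the DAG from a tree.

```python
# Build a DAG bottom-up by merging nearest-neighbor substructures round by round.
import numpy as np

def nn_dag(points, rounds=3):
    nodes = {i: frozenset([i]) for i in range(len(points))}   # leaf clusters
    parents = {i: [] for i in nodes}                           # child id -> parent ids
    frontier = list(nodes)
    next_id = len(points)
    for _ in range(rounds):
        if len(frontier) < 2:
            break
        centroids = {v: points[list(nodes[v])].mean(axis=0) for v in frontier}
        new_level = {}
        for v in frontier:
            # Nearest neighbor of v among the other frontier nodes.
            u = min((w for w in frontier if w != v),
                    key=lambda w: np.linalg.norm(centroids[v] - centroids[w]))
            merged = nodes[v] | nodes[u]
            if merged not in new_level:                        # deduplicate identical merges
                new_level[merged] = next_id
                nodes[next_id] = merged
                parents[next_id] = []
                next_id += 1
            for child in (v, u):                               # a node may gain several parents
                if new_level[merged] not in parents[child]:
                    parents[child].append(new_level[merged])
        frontier = list(new_level.values())
    return nodes, parents

pts = np.array([[0.0, 0.0], [0.1, 0.0], [1.0, 1.0], [1.1, 1.0], [0.5, 0.5]])
nodes, parents = nn_dag(pts)
print({k: sorted(v) for k, v in nodes.items()})
print(parents)
```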
A Field Guide to Federated Optimization
Jianyu Wang
Gauri Joshi
Maruan Al-Shedivat
Galen Andrew
A. Salman Avestimehr
Katharine Daly
Deepesh Data
Suhas Diggavi
Hubert Eichner
Advait Gadhikar
Antonious M. Girgis
Filip Hanzely
Chaoyang He
Samuel Horvath
Martin Jaggi
Tara Javidi
Satyen Chandrakant Kale
Sai Praneeth Karimireddy
Jakub Konečný
Sanmi Koyejo
Tian Li
Peter Richtarik
Karan Singhal
Virginia Smith
Mahdi Soltanolkotabi
Weikang Song
Sebastian Stich
Ameet Talwalkar
Hongyi Wang
Blake Woodworth
Honglin Yuan
Mi Zhang
Tong Zhang
Chunxiang (Jake) Zheng
Chen Zhu
arXiv (2021)
Federated learning and analytics are a distributed approach for collaboratively learning models (or statistics) from decentralized data, motivated by and designed for privacy protection. The distributed learning process can be formulated as solving federated optimization problems, which emphasize communication efficiency, data heterogeneity, compatibility with privacy and system requirements, and other constraints that are not primary considerations in other problem settings. This paper provides recommendations and guidelines on formulating, designing, evaluating and analyzing federated optimization algorithms through concrete examples and practical implementation, with a focus on conducting effective simulations to infer real-world performance. The goal of this work is not to survey the current literature, but to inspire researchers and practitioners to design federated learning algorithms that can be used in various practical applications.
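As a minimal example of the kind of simulation the guide discusses, here is a sketch of federated averaging on a toy least-squares problem with heterogeneous clients; the client counts, local step counts, and learning rates are arbitrary illustrative choices rather than recommendations from the paper.

```python
# Toy federated averaging: broadcast the model, run a few local SGD steps per
# client, and aggregate updates weighted by client dataset size.
import numpy as np

rng = np.random.default_rng(0)
d, num_clients = 5, 10
w_true = rng.normal(size=d)

# Heterogeneous client datasets: different sizes and feature distributions.
clients = []
for c in range(num_clients):
    n = rng.integers(20, 200)
    X = rng.normal(loc=0.1 * c, size=(n, d))
    y = X @ w_true + 0.1 * rng.normal(size=n)
    clients.append((X, y))

def local_update(w, X, y, steps=5, lr=0.05):
    """A few local gradient steps on one client's least-squares objective."""
    for _ in range(steps):
        grad = X.T @ (X @ w - y) / len(y)
        w = w - lr * grad
    return w

w = np.zeros(d)
for rnd in range(100):
    # One round: broadcast, local training, size-weighted averaging at the server.
    updates = [(local_update(w, X, y), len(y)) for X, y in clients]
    total = sum(n for _, n in updates)
    w = sum(n * wu for wu, n in updates) / total

print("distance to w_true:", float(np.linalg.norm(w - w_true)))
```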