Dheeraj Mysore Nagaraj

I am a Research Scientist at Google AI in Bangalore, India, on the Machine Learning and Optimization (MLO) team. I work on various topics in theoretical machine learning, applied probability, and statistics. My current work focuses on designing learning algorithms for data with temporal dependence, random graphs, and stochastic optimization. I recently completed my PhD at the Lab for Information and Decision Systems (LIDS) at MIT, advised by Prof. Guy Bresler. Prior to that, I was an undergraduate student at IIT Madras.
Authored Publications
In this work, we consider the problem of collaborative multi-user reinforcement learning. In this setting there are multiple users with the same state-action space and transition probabilities but with different rewards. Under the assumption that the reward matrix of the N users has a low-rank structure (a standard and practically successful assumption in the offline collaborative filtering setting), the question is whether we can design algorithms with significantly lower sample complexity than those that learn the MDP individually for each user. Our main contribution is an algorithm which explores rewards collaboratively across the N user-specific MDPs and can learn rewards efficiently in two key settings: tabular MDPs and linear MDPs. When N is large and the rank is constant, the sample complexity per MDP depends only logarithmically on the size of the state space, which represents an exponential reduction (in the state-space size) compared to standard "non-collaborative" algorithms.
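To illustrate the low-rank idea at the heart of this approach, here is a minimal sketch (not the paper's algorithm; all names and numbers are hypothetical): if noisy per-user reward estimates are stacked into an N-by-|S||A| matrix, a truncated SVD denoises all users jointly, which is where the collaborative gain comes from.

```python
import numpy as np

def low_rank_denoise(R_hat, rank):
    """Project a noisy user-by-(state, action) reward estimate matrix
    onto its best rank-`rank` approximation via truncated SVD."""
    U, s, Vt = np.linalg.svd(R_hat, full_matrices=False)
    return (U[:, :rank] * s[:rank]) @ Vt[:rank, :]

# Hypothetical example: N users, ground-truth rewards of rank 2,
# observed with Gaussian noise. Joint (collaborative) denoising
# beats using each user's noisy row on its own.
rng = np.random.default_rng(0)
N, SA, r = 100, 50, 2
R_true = rng.normal(size=(N, r)) @ rng.normal(size=(r, SA))
R_obs = R_true + 0.5 * rng.normal(size=(N, SA))
R_denoised = low_rank_denoise(R_obs, rank=r)
print(np.linalg.norm(R_obs - R_true) / np.linalg.norm(R_true))
print(np.linalg.norm(R_denoised - R_true) / np.linalg.norm(R_true))
```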
We consider regression where the noise distribution depends on the covariates (i.e., heteroscedastic noise), which captures popular settings such as linear regression with multiplicative noise occurring due to covariate uncertainty. In particular, we consider linear regression where the noise variance is an unknown rank-1 quadratic function of the covariates. While ordinary least squares achieves an error rate of $d/n$, this ignores the fact that the magnitude of the noise can be very small for certain values of the covariates, which can aid faster learning. Our algorithm alternates between a parameter estimation step and a noise-distribution-model learning step, each using the other's output to iteratively obtain better estimates of the parameter and of the noise model. It achieves an error rate of $1/n + d^2/n^2$, which we show is minimax optimal up to logarithmic factors. A sub-routine of the algorithm performs phase estimation with multiplicative noise, which may be of independent interest.
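As a rough illustration of this kind of alternating scheme (a generic sketch, not the paper's algorithm; the variance-model fit below is a deliberately crude stand-in), one can alternate a weighted least-squares fit of the parameter with a fit of a rank-1 quadratic variance model, reweighting by the predicted inverse variance:

```python
import numpy as np

def alternating_heteroscedastic_ls(X, y, iters=5, eps=1e-6):
    """Illustrative alternating scheme: alternate weighted least squares
    for the parameter with a fit of a rank-1 quadratic variance model
    sigma^2(x) = (x @ a)^2 (assumed form, matching the abstract)."""
    beta = np.linalg.lstsq(X, y, rcond=None)[0]        # OLS warm start
    for _ in range(iters):
        resid = y - X @ beta
        # Crude fit of |x @ a| to |resid|, so (x @ a)^2 tracks resid^2
        a = np.linalg.lstsq(X, np.abs(resid), rcond=None)[0]
        var = (X @ a) ** 2 + eps                       # predicted noise variance
        w = 1.0 / var                                  # inverse-variance weights
        Xw = X * w[:, None]
        beta = np.linalg.solve(X.T @ Xw, Xw.T @ y)     # weighted least squares
    return beta
```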
We study the finite-time behaviour of the popular temporal difference (TD) learning algorithm when combined with tail-averaging. We derive finite-time bounds on the parameter error of the tail-averaged TD iterate under a step-size choice that does not require information about the eigenvalues of the matrix underlying the projected TD fixed point. Our analysis shows that tail-averaged TD converges at the optimal $O(1/t)$ rate, both in expectation and with high probability. In addition, our bounds exhibit a sharper rate of decay for the initial error (bias), which is an improvement over averaging all iterates. We also propose and analyze a variant of TD that incorporates regularization. From our finite-time analysis, we conclude that the regularized version of TD is useful for problems with ill-conditioned features.
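A minimal sketch of TD(0) with tail averaging, assuming hypothetical callables `sample_transition` (one step of the Markov chain) and a feature map `phi`: only the second half of the iterates is averaged, which is what yields the sharper bias decay mentioned above.

```python
import numpy as np

def tail_averaged_td0(sample_transition, phi, theta0, s0, alpha, gamma, T):
    """TD(0) with linear function approximation and tail averaging
    (illustrative sketch). `alpha` is a universal step size that needs
    no problem-dependent eigenvalue information."""
    theta = theta0.copy()
    tail_sum = np.zeros_like(theta)
    tail_count = 0
    s = s0
    for t in range(T):
        r, s_next = sample_transition(s)
        td_error = r + gamma * phi(s_next) @ theta - phi(s) @ theta
        theta = theta + alpha * td_error * phi(s)
        if t >= T // 2:              # average only the tail of the iterates
            tail_sum += theta
            tail_count += 1
        s = s_next
    return tail_sum / tail_count
```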
We consider stochastic approximations of sampling algorithms, such as the Unadjusted Langevin Algorithm (ULA) and Interacting Particle Dynamics (IPD), using random batches. The noise added by the random batches is near-Gaussian due to the central limit theorem (CLT), while the driving Brownian motion is exactly Gaussian. Using this structure, we show that the error produced by the stochastic approximation can be hidden inside the diffusion process driving the algorithm in order to obtain convergence guarantees. This method also leads to a new algorithm: the covariance-corrected random batch method, which corrects for the additional noise from the random batches to give faster convergence. To summarize our contributions: (1) We prove the first non-exploding KL convergence bounds for SGLD under significantly fewer assumptions and with better dimension dependence (improved from $d^4$ to $d^{1.5}$). We also introduce covariance-corrected SGLD and demonstrate that it enjoys even faster convergence. (2) For IPD, we analyze covariance-corrected random batch methods. Under fewer assumptions, we remove the exponential dependence on the horizon observed in prior work on random batch methods.
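For reference, a sketch of plain SGLD with random batches, the baseline algorithm being analyzed (illustrative; the paper's covariance-corrected variant additionally compensates for the near-Gaussian covariance of the minibatch gradient noise):

```python
import numpy as np

def sgld(grad_logp_single, data, x0, step, n_steps, batch_size, rng):
    """Stochastic Gradient Langevin Dynamics (standard form, for
    illustration). `grad_logp_single(x, datum)` is a hypothetical
    per-datum gradient of the log-density."""
    x = x0.copy()
    n = len(data)
    for _ in range(n_steps):
        idx = rng.choice(n, size=batch_size, replace=False)   # random batch
        # Unbiased stochastic estimate of the full gradient of log p(x)
        g = (n / batch_size) * sum(grad_logp_single(x, data[i]) for i in idx)
        noise = rng.normal(size=x.shape)                      # Brownian increment
        x = x + step * g + np.sqrt(2.0 * step) * noise        # Euler-Maruyama step
    return x
```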
Stein Variational Gradient Descent (SVGD) is a popular nonparametric variational inference algorithm which simulates an interacting particle system to approximate a target distribution. While SVGD has demonstrated promising empirical performance across various domains, and its population (i.e., infinite-particle) limit dynamics is well studied, the behavior of SVGD in the finite-particle regime is much less understood. In this work, we design two computationally efficient variants of SVGD, namely VP-SVGD and RB-SVGD, with provably fast finite-particle convergence rates. By introducing the notion of virtual particles, we develop novel stochastic approximations of the population-limit SVGD dynamics in the space of probability measures which are exactly implementable using only a finite number of particles. Our algorithms can be viewed as specific random-batch approximations of SVGD that are computationally more efficient than ordinary SVGD. We establish that the $n$ particles output by VP-SVGD and RB-SVGD, run for $T$ steps, are i.i.d. samples from a distribution whose Kernel Stein Discrepancy to the target is at most $O(T^{-1/6})$ under standard assumptions. Our results hold under a mild growth condition on the potential function, which is significantly weaker than the isoperimetric assumptions (e.g., the Poincaré inequality) or information-transport conditions (e.g., Talagrand's inequality $\mathsf{T}_1$) generally considered in prior work. As a corollary, we consider the convergence of the empirical measure of the particles output by VP-SVGD and RB-SVGD to the target distribution and demonstrate a double-exponential improvement over the best known finite-particle analysis of SVGD.
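For context, here is one step of standard SVGD with an RBF kernel, the update that VP-SVGD and RB-SVGD modify via random batching (the bandwidth and all names below are illustrative assumptions):

```python
import numpy as np

def svgd_step(particles, grad_logp, step, bandwidth):
    """One step of standard SVGD with an RBF kernel (illustrative).
    particles: (n, d) array; grad_logp: score function of the target."""
    n = particles.shape[0]
    diffs = particles[:, None, :] - particles[None, :, :]   # pairwise x_i - x_j
    sqd = np.sum(diffs ** 2, axis=-1)
    K = np.exp(-sqd / (2.0 * bandwidth ** 2))                # RBF kernel matrix
    grads = np.stack([grad_logp(x) for x in particles])      # score at each particle
    # phi(x_i) = (1/n) sum_j [ K_ij * grad_logp(x_j) + grad_{x_j} K_ij ]
    grad_K = diffs * K[:, :, None] / bandwidth ** 2          # grad wrt x_j of K(x_i, x_j)
    phi = (K @ grads + grad_K.sum(axis=1)) / n
    return particles + step * phi
```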
    Metastable Mixing of Markov Chains: Efficiently Sampling Low Temperature Exponential Random Graphs
    Guy Bresler
    Eshaan Nichani
Annals of Applied Probability, 2023 (to appear)
In this paper we consider the problem of sampling from the low-temperature exponential random graph model (ERGM). The usual approach is via Markov chain Monte Carlo, but strong lower bounds have been established for the ERGM showing that any local Markov chain suffers from an exponentially large mixing time due to metastable states. We instead consider metastable mixing, a notion of approximate mixing within a collection of metastable states. In the case of the ERGM, we show that Glauber dynamics with the right $G(n,p)$ initialization has a metastable mixing time of $O(n^2 \log n)$ to within total variation distance $\exp(-\Omega(n))$.
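A sketch of the sampler being analyzed (illustrative; the `hamiltonian` callable and inverse temperature `beta` are placeholders for the ERGM's sufficient-statistic energy): Glauber dynamics resamples one vertex pair at a time from its conditional distribution, with the chain initialized from a $G(n,p)$ draw as in the paper.

```python
import numpy as np

def gnp_init(n, p, rng):
    """G(n, p) initialization: each edge present independently w.p. p."""
    U = rng.random((n, n))
    A = (np.triu(U, 1) < p).astype(int)
    return A + A.T

def glauber_step(A, hamiltonian, beta, rng):
    """One Glauber update for an ERGM with density proportional to
    exp(beta * H(G)) (sign convention assumed for illustration)."""
    n = A.shape[0]
    i, j = rng.choice(n, size=2, replace=False)      # random vertex pair
    A1, A0 = A.copy(), A.copy()
    A1[i, j] = A1[j, i] = 1                          # graph with edge (i, j)
    A0[i, j] = A0[j, i] = 0                          # graph without it
    # Conditional probability of the edge given the rest of the graph
    p_edge = 1.0 / (1.0 + np.exp(-beta * (hamiltonian(A1) - hamiltonian(A0))))
    A[i, j] = A[j, i] = int(rng.random() < p_edge)
    return A
```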
We study the problem of planning restless multi-armed bandits (RMABs) with multiple actions. This is a popular model for multi-agent systems with applications such as multi-channel communication, monitoring and machine maintenance, and healthcare. Whittle index policies, which are based on Lagrangian relaxations, are widely used in these settings due to their simplicity and near-optimality under certain conditions. In this work, we first show that Whittle index policies can fail in simple and practically relevant RMAB settings, even when the RMABs are indexable. We further discuss why Whittle index policies can provably fail in these settings despite indexability, and how even asymptotic optimality does not translate well to practically relevant planning horizons. We then propose an alternate planning algorithm based on the mean-field method, which borrows ideas from existing research with some improvements. This algorithm can provably and efficiently obtain near-optimal policies when the number of arms, $N$, is large, without the stringent structural assumptions required by Whittle index policies. Our approach is hyperparameter-free, and we provide an improved non-asymptotic analysis with (a) better dependence on problem-dependent parameters, (b) high-probability upper bounds showing that the reward of the policy is reliable, and (c) matching lower bounds for this algorithm, demonstrating the tightness of our bounds. Our extensive experimental analysis shows that the mean-field approach matches or outperforms other baselines.
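To make the mean-field view concrete, here is a sketch of propagating the infinite-arm (population) state distribution of an RMAB forward under a policy and accumulating its expected per-arm reward (illustrative of the mean-field limit only, not the paper's planning algorithm; all names are hypothetical):

```python
import numpy as np

def mean_field_rollout(mu0, P, R, policy, horizon):
    """Mean-field rollout for an RMAB with multiple actions.
    mu0: (S,) initial state distribution over a single arm's states.
    P:   (A, S, S) per-action transition matrices.
    R:   (A, S) per-action, per-state rewards.
    policy(t, mu): returns alpha of shape (A, S), where alpha[a, s] is
    the fraction of arms in state s assigned action a (columns sum to 1)."""
    mu = mu0.copy()
    total_reward = 0.0
    for t in range(horizon):
        alpha = policy(t, mu)
        total_reward += np.sum(alpha * (R * mu))      # expected per-arm reward
        # Mass in state s taking action a flows through P[a]
        mu = sum(P[a].T @ (alpha[a] * mu) for a in range(P.shape[0]))
    return total_reward, mu
```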