Xin Wang
Xin Wang is in the Algorithms team at Google Research. He completed his PhD in Mathematics at the Georgia Institute of Technology before joining Google. His research interests are efficient computing, memory mechanisms for machine learning, and optimization.
Authored Publications
Due to the size and complexity of modern large language models (LLMs), it has proven challenging to uncover the underlying mechanisms that models use to solve reasoning problems. For instance, is their reasoning for a specific problem localized to certain parts of the network? Do they break down the reasoning problem into modular components that are then executed as sequential steps as we go deeper in the model? To better understand the reasoning capability of LLMs, we study a minimal propositional logic problem that requires combining multiple facts to arrive at a solution. By studying this problem on Mistral and Gemma models, up to 27B parameters, we illuminate the core components the models use to solve such logic problems. From a mechanistic interpretability point of view, we use causal mediation analysis to uncover the pathways and components of the LLMs' reasoning processes. We then offer fine-grained insights into the functions of attention heads in different layers. We not only find a sparse circuit that computes the answer, but also decompose it into sub-circuits with four distinct and modular uses. Finally, we reveal that three distinct models (Mistral-7B, Gemma-2-9B, and Gemma-2-27B) contain analogous but not identical mechanisms.
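The head-level analysis described above can be made concrete with a small ablation experiment. The sketch below uses a hypothetical, hand-rolled attention layer (not the paper's causal mediation setup or models) in which individual heads can be zeroed out to see how strongly each one affects the output:

```python
# Hypothetical single-layer multi-head attention, written out explicitly so that
# individual heads can be ablated; an illustrative stand-in, not the paper's code.
import torch

torch.manual_seed(0)
T, d_model, n_heads = 6, 32, 4
d_head = d_model // n_heads

x = torch.randn(T, d_model)                              # toy token representations
Wq, Wk, Wv, Wo = (torch.randn(d_model, d_model) * 0.1 for _ in range(4))

def attention(x, ablate_head=None):
    q = (x @ Wq).view(T, n_heads, d_head)
    k = (x @ Wk).view(T, n_heads, d_head)
    v = (x @ Wv).view(T, n_heads, d_head)
    scores = torch.einsum("qhd,khd->hqk", q, k) / d_head ** 0.5
    mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(mask, float("-inf"))     # causal mask
    weights = scores.softmax(dim=-1)
    heads = torch.einsum("hqk,khd->qhd", weights, v)
    if ablate_head is not None:
        heads[:, ablate_head, :] = 0.0                   # knock out one head
    return heads.reshape(T, d_model) @ Wo

base = attention(x)
for h in range(n_heads):
    delta = (attention(x, ablate_head=h) - base).norm()
    print(f"head {h}: output change {delta:.3f}")        # large change -> head matters
```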
Large language models (LLMs) have demonstrated remarkable performance in tasks that require reasoning abilities. Motivated by recent works showing evidence that LLMs can plan and reason on abstract reasoning problems in context, we conduct a set of controlled experiments on a synthetic propositional logic problem to provide a mechanistic understanding of how such abilities arise. For a decoder-only transformer trained solely on our synthetic dataset, we identify the specific mechanisms by which a three-layer Transformer solves the reasoning task. In particular, we identify certain "planning" and reasoning circuits that require cooperation between the attention blocks to implement the desired reasoning algorithm in its entirety. We also find that deeper models with a greater number of attention heads exhibit stronger performance on more complex variants of our logic problem.
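To make the setup concrete, here is a minimal sketch of a synthetic propositional logic task in the spirit of the abstract: facts and implication rules are sampled, and the label indicates whether a query proposition is derivable by forward chaining. The exact problem format and vocabulary used in the paper may differ.

```python
# Minimal synthetic propositional-logic task generator (illustrative assumption,
# not the paper's exact data format).
import random

def sample_instance(n_vars=6, n_facts=2, n_rules=4, seed=None):
    rng = random.Random(seed)
    variables = [f"p{i}" for i in range(n_vars)]
    facts = set(rng.sample(variables, n_facts))
    rules = [(rng.choice(variables), rng.choice(variables)) for _ in range(n_rules)]

    # Forward chaining: repeatedly apply rules (a -> b) until nothing new derives.
    derived = set(facts)
    changed = True
    while changed:
        changed = False
        for a, b in rules:
            if a in derived and b not in derived:
                derived.add(b)
                changed = True

    query = rng.choice(variables)
    prompt = (" ".join(sorted(facts)) + " . "
              + " ".join(f"{a}->{b}" for a, b in rules) + " . " + query + " ?")
    return prompt, query in derived

for seed in range(3):
    prompt, label = sample_instance(seed=seed)
    print(prompt, "=>", label)
```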
Causal language modeling using the Transformer architecture has yielded remarkable capabilities in Large Language Models (LLMs) over the last few years. However, the extent of the search and reasoning abilities of LLMs remains a topic of ongoing debate. In this work, we study whether causal language modeling with Transformers can learn a complex task such as solving Sudoku puzzles. To solve a Sudoku puzzle, the model is first required to search over all empty cells of the puzzle to decide on a cell to fill and then apply an appropriate strategy to fill the decided cell. Sometimes, the application of a strategy only results in thinning down the possible values in a cell rather than concluding the exact value of the cell. In such cases, multiple strategies are applied one after the other to fill a single cell. We observe that Transformer models trained on this synthetic task can indeed learn to solve Sudokus when trained on a logical sequence of steps taken by a solver. We find that training Transformers with the logical sequence of steps is necessary; without such training, they fail to learn Sudoku. In addition, we study the internal representations of the trained Transformer and find that, through linear probing, we can decode high-level strategy information from them, pointing to the presence of a strong reasoning engine implicit in the Transformer weights.
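For intuition about the linear-probing step, the sketch below fits a linear classifier on synthetic, randomly generated stand-ins for hidden states to test whether a "strategy" label is linearly decodable; the data, dimensions, and labels are illustrative assumptions, not the paper's trained model.

```python
# Linear-probing sketch on synthetic "hidden states" (placeholders, not real
# Transformer activations): high probe accuracy means the label is linearly decodable.
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
n, d, n_strategies = 2000, 64, 3

# Pretend hidden states: each "strategy" shifts the representation along a fixed
# direction, so a linear probe should be able to recover it.
directions = rng.normal(size=(n_strategies, d))
labels = rng.integers(0, n_strategies, size=n)
hidden = rng.normal(size=(n, d)) + 2.0 * directions[labels]

X_tr, X_te, y_tr, y_te = train_test_split(hidden, labels, random_state=0)
probe = LogisticRegression(max_iter=1000).fit(X_tr, y_tr)
print("probe accuracy:", probe.score(X_te, y_te))
```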
How Transformers Solve Propositional Logic Problems: A Mechanistic Analysis
Guanzhe Hong
Nishanth Dikkala
Enming Luo
The 4th Workshop on Mathematical Reasoning and AI @ NeurIPS 2024
Large language models (LLMs) have shown amazing performance on tasks that require planning and reasoning. Motivated by this, we investigate the internal mechanisms that underpin a network's ability to perform complex logical reasoning. We first construct a synthetic propositional logic problem that serves as a concrete test-bed for network training and evaluation. Crucially, this problem demands nontrivial planning to solve, but we can train a small transformer to achieve perfect accuracy. Building on our set-up, we then pursue an understanding of precisely how a three-layer transformer, trained from scratch, solves this problem. We are able to identify certain "planning" and "reasoning" circuits in the network that necessitate cooperation between the attention blocks to implement the desired logic. To expand our findings, we then study a larger model, Mistral 7B. Using activation patching, we characterize internal components that are critical in solving our logic problem. Overall, our work systematically uncovers novel aspects of small and large transformers, and continues the study of how they plan and reason.
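The activation-patching methodology mentioned above can be illustrated in a few lines. The toy encoder, inputs, and readout below are hypothetical stand-ins used only to show the mechanics: cache a component's activation on one input, then splice it into a run on another and see how much of the original behavior is restored.

```python
# Minimal activation-patching sketch on a toy model (not the paper's setup).
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, n_layers = 32, 3
layer = nn.TransformerEncoderLayer(d_model, nhead=4, batch_first=True)
model = nn.TransformerEncoder(layer, num_layers=n_layers)
head = nn.Linear(d_model, 2)             # toy readout over the last position

clean = torch.randn(1, 8, d_model)       # stand-in for the "clean" prompt
corrupt = torch.randn(1, 8, d_model)     # stand-in for the "corrupted" prompt

cache = {}
target = model.layers[1]                 # component whose causal role we probe

def save_hook(module, inp, out):
    cache["act"] = out.detach()

def patch_hook(module, inp, out):
    return cache["act"]                  # overwrite with the cached clean activation

with torch.no_grad():
    h = target.register_forward_hook(save_hook)
    clean_logits = head(model(clean)[:, -1])
    h.remove()

    h = target.register_forward_hook(patch_hook)
    patched_logits = head(model(corrupt)[:, -1])
    h.remove()

    corrupt_logits = head(model(corrupt)[:, -1])

# If patching moves the corrupted-run logits toward the clean-run logits, the
# patched component mediates the behavior of interest.
print(clean_logits, corrupt_logits, patched_logits)
```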
Deep and wide neural networks successfully fit very complex functions today, but dense models are starting to be prohibitively expensive for inference. To mitigate this, one promising direction is models that activate only a sparse subgraph of the network. The subgraph is chosen by a data-dependent routing function, enforcing a fixed mapping of inputs to subnetworks (e.g., the Mixture of Experts (MoE) paradigm in Switch Transformers). However, prior work is largely empirical, and while existing routing functions work well in practice, they do not lead to theoretical guarantees on approximation ability. We aim to provide a theoretical explanation for the power of sparse networks. As our first contribution, we present a formal model of data-dependent sparse networks that captures salient aspects of popular architectures. We then introduce a routing function based on locality sensitive hashing (LSH) that enables us to reason about how well sparse networks approximate target functions. After representing LSH-based sparse networks with our model, we prove that sparse networks can match the approximation power of dense networks on Lipschitz functions. Applying LSH on the input vectors means that the experts interpolate the target function in different subregions of the input space. To support our theory, we define various datasets based on Lipschitz target functions, and we show that sparse networks give a favorable trade-off between the number of active units and approximation quality.
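As a concrete illustration of LSH-based routing (a sketch with assumed toy dimensions and random-hyperplane hashing, not the paper's formal construction), each input is hashed to a bucket and only that bucket's expert is evaluated:

```python
# LSH-routed sparse network sketch: random hyperplanes hash each input to a
# bucket, and each bucket owns a small expert; only that expert is activated.
import torch
import torch.nn as nn

torch.manual_seed(0)
d_in, d_hidden, n_bits = 16, 32, 3            # 2**n_bits buckets/experts

hyperplanes = torch.randn(n_bits, d_in)       # random-hyperplane (SimHash) LSH
experts = nn.ModuleList(
    nn.Sequential(nn.Linear(d_in, d_hidden), nn.ReLU(), nn.Linear(d_hidden, 1))
    for _ in range(2 ** n_bits)
)

def bucket(x):
    # Sign pattern of the projections, read as a binary number.
    bits = (x @ hyperplanes.T > 0).long()
    powers = 2 ** torch.arange(n_bits)
    return int((bits * powers).sum())

def sparse_forward(x):
    return experts[bucket(x)](x)              # only one expert is active per input

x = torch.randn(d_in)
print("routed to expert", bucket(x), "output", sparse_forward(x).item())
```

Because nearby inputs tend to share a hash bucket, each expert ends up responsible for one subregion of the input space, which is the interpolation picture described in the abstract.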
One network fits all? Modular versus monolithic task formulations in neural networks
Abhimanyu Das
Atish Agarwala
Brendan Juba
Vatsal Sharan
ICLR 2021
Can deep learning solve multiple, very different tasks simultaneously? We investigate how the representations of the underlying tasks affect the ability of a single neural network to learn them jointly. We present theoretical and empirical findings that a single neural network is capable of simultaneously learning multiple tasks from a combined data set, for a variety of methods for representing tasks: for example, when the distinct tasks are represented by well-separated clusters or decision trees over some task-code attributes. More strongly, we present a novel analysis showing that families of simple programming-like constructs for the task codings are learnable by two-layer neural networks with standard training. We study more generally how the complexity of learning such combined tasks grows with the complexity of the task codes; we find that learning many tasks can be provably hard, even though the individual tasks are easy to learn. We provide empirical support for the usefulness of the learning bounds by training networks on clusters, decision trees, and SQL-style aggregation.
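To make the "combined data set with task codes" idea concrete, the sketch below prepends a one-hot task code to each example drawn from two synthetic stand-in tasks and fits a single network on their union; the tasks, dimensions, and classifier are illustrative assumptions, not the paper's constructions or benchmarks.

```python
# Joint training on a combined data set with task codes (synthetic stand-in tasks).
import numpy as np
from sklearn.neural_network import MLPClassifier

rng = np.random.default_rng(0)
n_per_task, d = 1000, 10

def task_a(x):   # task 0: XOR of the signs of the first two features
    return ((x[:, 0] > 0) ^ (x[:, 1] > 0)).astype(int)

def task_b(x):   # task 1: threshold on a fixed linear function
    return (x @ np.arange(d) > 0).astype(int)

X, y = [], []
for code, task in enumerate([task_a, task_b]):
    x = rng.normal(size=(n_per_task, d))
    onehot = np.zeros((n_per_task, 2))
    onehot[:, code] = 1.0
    X.append(np.hstack([onehot, x]))          # task code prepended to the input
    y.append(task(x))
X, y = np.vstack(X), np.concatenate(y)

net = MLPClassifier(hidden_layer_sizes=(64, 64), max_iter=500, random_state=0)
net.fit(X, y)
print("joint training accuracy:", net.score(X, y))
```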