Understanding Self-Attention through Prompt-Conditioned Markov Chains
Abstract
Modern language models rely on the transformer architecture and its self-attention mechanism to perform language understanding and text generation. In this work, we study learning a 1-layer self-attention model from a set of prompts and associated output data sampled according to ground-truth weights. As our main contribution, we establish a precise mapping between a self-attention model and a Markov chain through a convex problem formulation: inputting a prompt to the model samples the output token according to a prompt-conditioned Markov chain, which weights the transitions of a base chain. Additionally, incorporating positional encoding results in position-dependent scaling of the chain transitions. Building on this formalism, we develop identifiability/coverage conditions on the data distribution that guarantee consistent estimation, and we establish sample complexity guarantees for IID sampled data. Finally, we study the challenging problem of learning from a single dependent trajectory generated from an initial prompt. In contrast to standard Markov chains, the sampling process exhibits a winner-takes-all phenomenon, which we characterize: it degenerates into generating a limited subset of tokens due to the non-mixing nature of the attention layer. We argue that this phenomenon explains the tendency of modern LLMs to generate repetitive text and makes consistent estimation from a single trajectory intricate and problem-dependent; we provide a preliminary characterization of when such estimation is possible.
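As a minimal sketch of the mapping described above, the Python snippet below computes next-token probabilities on a toy vocabulary in two ways. It rests on assumptions that the abstract does not spell out: that the model's output distribution is the attention-weighted average of one-hot token embeddings with the last token as the query, and that the prompt-conditioned chain reweights the base transition probabilities by how often each token occurs in the prompt. The function names and the score matrix `scores` are hypothetical and chosen only for illustration.

```python
import math
from collections import Counter


def attention_next_token_probs(prompt, scores):
    """Next-token distribution from a toy 1-layer attention model.

    prompt: list of token ids; scores[i][k]: hypothetical attention score
    between query token i and key token k. The last token is the query.
    Assumes one-hot token embeddings, so the output is a distribution
    over the vocabulary.
    """
    last = prompt[-1]
    # Softmax over prompt positions (not over the vocabulary).
    weights = [math.exp(scores[last][tok]) for tok in prompt]
    total = sum(weights)
    probs = [0.0] * len(scores)
    for tok, w in zip(prompt, weights):
        probs[tok] += w / total
    return probs


def prompt_conditioned_chain(prompt, scores):
    """Base-chain transitions reweighted by token counts in the prompt."""
    last = prompt[-1]
    # Base chain: softmax of the last token's score row over the vocabulary.
    row = [math.exp(s) for s in scores[last]]
    z = sum(row)
    base = [r / z for r in row]
    # Assumed weighting: multiply each transition by the token's count.
    counts = Counter(prompt)
    weighted = [counts.get(k, 0) * base[k] for k in range(len(scores))]
    norm = sum(weighted)
    return [w / norm for w in weighted]


# Toy example: vocabulary of 4 tokens, token 3 never appears in the prompt.
toy_scores = [
    [0.0, 1.0, -0.5, 2.0],
    [0.3, -1.0, 0.8, 1.5],
    [1.2, 0.1, 0.0, -0.7],
    [0.5, 0.5, 0.5, 0.5],
]
toy_prompt = [0, 2, 2, 1]
print(attention_next_token_probs(toy_prompt, toy_scores))
print(prompt_conditioned_chain(toy_prompt, toy_scores))
```

Under these assumptions the two distributions coincide, and any token absent from the prompt receives zero probability, which gives some intuition for how a single trajectory can collapse onto a small subset of tokens.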