Mechanics of Next Token Prediction with Transformers
Abstract
Transformer-based language models are trained on large datasets to predict the next token given an input sequence. Despite this seemingly simple training objective, they have revolutionized natural language processing within a short timeframe. Underlying this success is the self-attention mechanism. In this work, we ask: What does 1-layer self-attention learn from next-token prediction? We show that, when trained with gradient descent, self-attention implements a simple automaton whose token hierarchy is induced by the training data. Concretely, from the (sequence, label) pairs of the training data, we construct directed next-token graphs (NTGs) of the dataset that capture (input token, label) relations. We find that the implicit bias of self-attention is captured by the strongly connected components (SCCs), which partition the NTGs into cyclic and acyclic subgraphs. The acyclic subgraph yields an SVM direction that enforces the priority order among SCCs, while the cyclic subgraph yields a correction term that allocates the nonzero softmax probabilities among tokens within the same SCC. We demonstrate, both empirically and theoretically, that the superposition of these two components accurately predicts the implicit bias of gradient descent in next-token prediction. We believe these results shed light on self-attention's ability to process sequential data and pave the way toward demystifying more complex transformer architectures.
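To make the graph construction concrete, the following is a minimal sketch of building a next-token graph from (sequence, label) pairs and splitting it into SCCs. It rests on assumptions not fixed by the abstract: a single graph is built for the whole dataset, and an edge runs from the label to every token of its input sequence. The function names `build_ntg`, `strongly_connected_components`, and `condensation_edges` and the toy dataset are hypothetical and chosen only for illustration.

```python
from collections import defaultdict


def build_ntg(pairs):
    """Build a directed graph from (sequence, label) pairs.

    Assumption (not stated in the abstract): add an edge label -> token
    for every token in the input sequence. Self-loops are skipped because
    they do not change the SCC structure.
    """
    graph = defaultdict(set)
    for sequence, label in pairs:
        graph[label]  # touch the key so the label is registered as a node
        for token in sequence:
            graph[token]  # register the input token as a node
            if token != label:
                graph[label].add(token)
    return dict(graph)


def strongly_connected_components(graph):
    """Tarjan's algorithm (recursive); returns a list of frozensets."""
    index, lowlink = {}, {}
    stack, on_stack = [], set()
    components = []
    counter = 0

    def visit(node):
        nonlocal counter
        index[node] = lowlink[node] = counter
        counter += 1
        stack.append(node)
        on_stack.add(node)
        for succ in graph[node]:
            if succ not in index:
                visit(succ)
                lowlink[node] = min(lowlink[node], lowlink[succ])
            elif succ in on_stack:
                lowlink[node] = min(lowlink[node], index[succ])
        if lowlink[node] == index[node]:
            # node is the root of an SCC: pop the stack down to it
            component = set()
            while True:
                top = stack.pop()
                on_stack.discard(top)
                component.add(top)
                if top == node:
                    break
            components.append(frozenset(component))

    for node in graph:
        if node not in index:
            visit(node)
    return components


def condensation_edges(graph, components):
    """Edges between distinct SCCs; these form the acyclic part."""
    comp_of = {node: comp for comp in components for node in comp}
    return {
        (comp_of[u], comp_of[v])
        for u in graph
        for v in graph[u]
        if comp_of[u] != comp_of[v]
    }


# Toy dataset of (input sequence, next-token label) pairs.
pairs = [(("a", "b"), "c"), (("c",), "a"), (("a",), "d")]
graph = build_ntg(pairs)
components = strongly_connected_components(graph)
dag_edges = condensation_edges(graph, components)
```

In this toy dataset, tokens `a` and `c` point at each other and therefore share one SCC (the cyclic part), while `b` and `d` are singleton SCCs. The edges between distinct SCCs form a directed acyclic graph, which, under the assumed edge convention, is where a priority order among SCCs would be read off.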