Block Transformer: Global-To-Local Language Modeling for Fast Inference
Abstract
Transformer-based autoregressive language models apply self-attention to model global dependencies. Recent models scale this up to context windows of up to a million tokens, paving the way for retrieval-augmented generation, long-form reasoning, autonomous agents, and more. However, self-attention carries quadratic compute and memory costs, which hinder efficient scaling of the context length. To tackle this, we question a de facto design choice in transformers: does self-attention need to be performed across the global context throughout the entire network? To this end, we propose the Block Transformer architecture, which first models global dependencies between blocks of tokens and then decodes individual tokens within each block, isolating self-attention to the local context. The coarse treatment of global self-attention in the lower layers can reduce its quadratic costs, while the locality of the upper layers can minimize memory usage and achieve high arithmetic intensity during inference. Together, these alleviate the principal bottlenecks of inference and greatly improve wall-clock throughput in both the prefill and decode phases. Experiments show that our model is Pareto-efficient in the trade-off between throughput and language modeling performance relative to vanilla transformers, achieving 10--20$\times$ gains in throughput across model sizes spanning several orders of magnitude. The Block Transformer demonstrates the advantages of performing language modeling locally, enhancing performance without the overhead of global communication: a potential path to highly parallel and distributed language modeling.
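To make the global-to-local split concrete, below is a minimal PyTorch sketch of the structure described in the abstract, under illustrative assumptions: the names (`BlockTransformerSketch`, `block_len`, `n_global_layers`, and so on) are hypothetical, block embeddings are formed here by concatenating and projecting token embeddings, and stock `nn.TransformerEncoderLayer` modules with causal masks stand in for the actual block and token decoders. The point is the shape of the computation, not the paper's reference implementation: quadratic self-attention is applied only over one embedding per block, while per-token self-attention is confined to each block of `block_len` tokens.

```python
# Hedged sketch of a global-to-local (block-then-token) decoder in PyTorch.
# All module and parameter names are illustrative assumptions.
import torch
import torch.nn as nn


def causal_mask(n: int) -> torch.Tensor:
    # Additive mask: position i may only attend to positions <= i.
    return torch.triu(torch.full((n, n), float("-inf")), diagonal=1)


class BlockTransformerSketch(nn.Module):
    def __init__(self, vocab_size=32000, d_model=256, block_len=4,
                 n_global_layers=2, n_local_layers=2, n_heads=4):
        super().__init__()
        self.block_len = block_len
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        # Embedder: concatenate the token embeddings of a block and project
        # them into a single block embedding.
        self.block_embedder = nn.Linear(block_len * d_model, d_model)
        # Global ("block") decoder: causal self-attention over block embeddings
        # only, so its quadratic cost scales with seq_len / block_len.
        self.block_decoder = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model, n_heads, 4 * d_model,
                                       batch_first=True, norm_first=True)
            for _ in range(n_global_layers)
        ])
        # Local ("token") decoder: causal self-attention restricted to the
        # block_len tokens of each block, conditioned on the global context.
        self.token_decoder = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model, n_heads, 4 * d_model,
                                       batch_first=True, norm_first=True)
            for _ in range(n_local_layers)
        ])
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # tokens: (batch, seq_len) with seq_len divisible by block_len.
        B, T = tokens.shape
        L = self.block_len
        n_blocks = T // L
        x = self.tok_emb(tokens)                              # (B, T, D)

        # --- Global path: one embedding per block, causal over blocks. ---
        blocks = x.view(B, n_blocks, L * x.size(-1))
        ctx = self.block_embedder(blocks)                     # (B, n_blocks, D)
        block_mask = causal_mask(n_blocks).to(ctx.device)
        for layer in self.block_decoder:
            ctx = layer(ctx, src_mask=block_mask)
        # Shift so that block i is decoded from the context of blocks < i.
        ctx = torch.cat([torch.zeros_like(ctx[:, :1]), ctx[:, :-1]], dim=1)

        # --- Local path: decode tokens within each block independently. ---
        local = x.view(B * n_blocks, L, -1) + ctx.reshape(B * n_blocks, 1, -1)
        tok_mask = causal_mask(L).to(local.device)
        for layer in self.token_decoder:
            local = layer(local, src_mask=tok_mask)

        return self.lm_head(local).view(B, T, -1)             # (B, T, vocab)


if __name__ == "__main__":
    model = BlockTransformerSketch()
    tokens = torch.randint(0, 32000, (2, 16))  # 2 sequences of 4 blocks each
    print(model(tokens).shape)                 # torch.Size([2, 16, 32000])
```

In this sketch, the local decoder never attends beyond `block_len` positions, so its per-sequence attention state during decoding stays small and fixed in size; this is one concrete way to read the abstract's claim that locality in the upper layers minimizes memory usage and raises arithmetic intensity at inference time.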