Breaking the Softmax Bottleneck via Learnable Monotonic Pointwise Non-linearities
Abstract
The softmax function on top of a final linear layer is the de facto method for producing probability distributions in neural networks. In many applications, such as language modeling and text generation, the model must produce distributions over large output vocabularies. Recently, this linear-softmax layer has been shown to have limited representational capacity due to its connection with the rank bottleneck in matrix factorization. However, little is known about the limitations of linear-softmax for quantities of practical interest such as cross entropy or mode estimation, a direction that we explore both theoretically and empirically in this paper. As an efficient and effective way to alleviate this issue, we propose to learn parametric monotonic functions on top of the logits. Theoretically, we show that such monotonic functions are likely to increase the rank of a matrix up to its full rank. Empirically, our method improves over the traditional linear-softmax layer in both synthetic and real language model experiments, with negligible time or memory overhead, while being comparable to the more computationally expensive mixture of softmax distributions.
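The rank argument can be illustrated with a small numerical sketch. A logit matrix formed as the product of two thin factors has rank at most the hidden dimension, whereas applying a monotonic function pointwise typically raises that rank. The function `monotonic_fn`, its sum-of-tanh form, and the parameter values below are illustrative assumptions, not the parametrization used in the paper, and the parameters are fixed here rather than learned.

```python
import numpy as np


def monotonic_fn(x, a, b, c):
    """Hypothetical non-decreasing pointwise function (an assumption for
    illustration): a positively weighted sum of increasing tanh units."""
    x = np.asarray(x, dtype=float)[..., None]
    # abs() keeps the weights and slopes non-negative, so every term,
    # and therefore the sum, is non-decreasing in x.
    return np.sum(np.abs(a) * np.tanh(np.abs(b) * x + c), axis=-1)


def low_rank_logits(n_contexts=50, vocab_size=50, hidden_dim=5, seed=0):
    """Logits H @ W with rank at most hidden_dim (the rank bottleneck)."""
    rng = np.random.default_rng(seed)
    H = rng.standard_normal((n_contexts, hidden_dim))  # context vectors
    W = rng.standard_normal((hidden_dim, vocab_size))  # output embeddings
    return H @ W


# Fixed illustrative parameters; in a real model these would be learned.
a = np.array([1.0, 0.5, 0.25])
b = np.array([0.5, 1.0, 2.0])
c = np.array([-1.0, 0.0, 1.0])

logits = low_rank_logits()
transformed = monotonic_fn(logits, a, b, c)

rank_before = np.linalg.matrix_rank(logits)
rank_after = np.linalg.matrix_rank(transformed)
print("rank of logits:", rank_before)
print("rank after pointwise monotonic function:", rank_after)
```

Because the function is non-decreasing, it preserves the ordering of the logits within each row, so the top-ranked token is unchanged whenever the transformed values are distinct; only the shape of the resulting distribution after the softmax is affected.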