Rax: Composable Learning-to-Rank using JAX

Proceedings of the 28th ACM SIGKDD Conference on Knowledge Discovery and Data Mining (2022), 3051–3060

Abstract

Rax is a library for composable Learning-to-Rank (LTR) written entirely in JAX. The goal of Rax is to facilitate easy prototyping of LTR systems by leveraging the flexibility and simplicity of JAX. Rax provides a diverse set of popular ranking metrics and losses that integrate well with the rest of the JAX ecosystem. Furthermore, Rax implements a system of ranking-specific function transformations that allows fine-grained customization of ranking losses and metrics. Most notably, Rax provides approx_t12n: a function transformation (t12n) that can transform any of our ranking metrics into an approximate and differentiable form that can be optimized. This provides a systematic way to directly optimize neural ranking models for ranking metrics that are not easily optimizable in other libraries. We empirically demonstrate the effectiveness of Rax by benchmarking neural models implemented using Flax and trained using Rax on two popular LTR benchmarks: WEB30K and Istella. We also show that integrating ranking losses with T5, a large language model, can improve overall ranking performance on the MS MARCO passage ranking task. We are sharing the Rax library with the open-source community as part of the larger JAX ecosystem at https://github.com/google/rax.
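To make the idea of a metric-to-differentiable-surrogate transformation concrete, the following is a minimal sketch in plain Python using only the standard library. It does not use Rax or JAX and is not the library's implementation. The names `hard_ranks`, `soft_ranks`, `ndcg_metric`, and `approx_transform` are hypothetical, and the sigmoid-based rank approximation is one common smoothing technique assumed here for illustration, since the abstract does not specify how approx_t12n works internally.

```python
import math


def hard_ranks(scores):
    """1-based rank of each item under descending score (ties share a rank)."""
    return [1 + sum(1 for s_j in scores if s_j > s_i) for s_i in scores]


def soft_ranks(scores, temperature=1.0):
    """Smooth stand-in for hard_ranks (an assumption, not Rax's code).

    rank_i is approximated by 1 + sum_{j != i} sigmoid((s_j - s_i) / T).
    """
    def sigmoid(x):
        # tanh form of the logistic function; avoids overflow for large |x|.
        return 0.5 * (1.0 + math.tanh(0.5 * x))

    return [
        1.0 + sum(
            sigmoid((s_j - s_i) / temperature)
            for j, s_j in enumerate(scores) if j != i
        )
        for i, s_i in enumerate(scores)
    ]


def ndcg_metric(scores, labels, rank_fn=hard_ranks):
    """NDCG with exponential gain; the rank computation is pluggable."""
    ranks = rank_fn(scores)
    dcg = sum((2 ** l - 1) / math.log2(1 + r) for l, r in zip(labels, ranks))
    ideal = sum(
        (2 ** l - 1) / math.log2(1 + r)
        for r, l in enumerate(sorted(labels, reverse=True), start=1)
    )
    return dcg / ideal if ideal > 0 else 0.0


def approx_transform(metric_fn, temperature=1.0):
    """Hypothetical transformation: swap hard ranks for smooth ranks."""
    def approx_metric(scores, labels):
        return metric_fn(
            scores, labels, rank_fn=lambda s: soft_ranks(s, temperature)
        )
    return approx_metric


# Usage: the transformed metric varies smoothly with the scores.
approx_ndcg = approx_transform(ndcg_metric, temperature=1.0)
scores, labels = [2.0, 0.5, 1.0], [0, 2, 1]
print(ndcg_metric(scores, labels), approx_ndcg(scores, labels))
```

In this sketch the hard metric is piecewise constant in the scores, so it gives no useful gradient signal, while the transformed metric changes smoothly as scores move. Lowering `temperature` brings the surrogate closer to the original metric at the cost of sharper, less informative gradients. In a JAX setting, the negated surrogate would typically serve as a loss passed to an automatic differentiation routine.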