Effect Handlers and Choice-Based Learning in JAX

Shangyin Tan
Ningning Xie
NeurIPS Machine Learning for Systems Workshop (2023)

Abstract

Choice-based learning is a programming paradigm for expressing learning systems in terms of choices and losses. We explore a practical implementation of choice-based learning in JAX by combining two techniques in a novel way: algebraic effects and the selection monad. We describe the design and implementation of our library, explore its usefulness for real-world applications such as hyperparameter tuning and deep reinforcement learning, and compare it with existing approaches.
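To make "choices and losses" concrete, here is a minimal pure-Python sketch of the selection monad. A selection function has type (A → loss) → A: given a loss over outcomes, it picks an outcome. This is not the authors' JAX library and does not use algebraic effects. The names `pure`, `bind`, and `choose` are hypothetical, and `choose` simply enumerates a finite list of options. The toy loss in the usage example is also made up for illustration.

```python
from typing import Callable, Sequence, TypeVar

A = TypeVar("A")
B = TypeVar("B")

# A selection function maps a loss over outcomes to a chosen outcome.
Loss = Callable[[A], float]


def pure(x: A) -> Callable[[Loss], A]:
    """Make no choice: always return x, whatever the loss."""
    return lambda loss: x


def bind(sel: Callable[[Loss], A],
         f: Callable[[A], Callable[[Loss], B]]) -> Callable[[Loss], B]:
    """Sequence a choice with the rest of the program (standard selection-monad bind)."""
    def run(loss: Loss) -> B:
        # Score each candidate a by running the rest of the program f(a)
        # and measuring the loss of the final outcome it would produce.
        a = sel(lambda x: loss(f(x)(loss)))
        return f(a)(loss)
    return run


def choose(options: Sequence[A]) -> Callable[[Loss], A]:
    """Hypothetical choice primitive: pick the option with the smallest loss."""
    return lambda loss: min(options, key=loss)


# Usage: a toy "hyperparameter search" over a learning rate and a step count.
program = bind(choose([0.1, 0.5, 1.0]), lambda lr:
          bind(choose([1, 2, 3, 4]), lambda steps:
          pure((lr, steps))))

# Made-up loss: prefer configurations where lr * steps is close to 3.
best = program(lambda cfg: (cfg[0] * cfg[1] - 3.0) ** 2)
print(best)  # (1.0, 3)
```

Each choice is scored by looking ahead through the rest of the program, so the program is written as a sequence of ordinary choices while the loss alone decides which branch is taken. Exhaustive enumeration like this is only practical for small discrete spaces. A library aimed at real workloads would need a different strategy for interpreting `choose`.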