Pre-training generalist agents using offline reinforcement learning
February 23, 2023
Posted by Aviral Kumar, Student Researcher, and Sergey Levine, Research Scientist, Google Research
Reinforcement learning (RL) algorithms can learn skills to solve decision-making tasks like playing games, enabling robots to pick up objects, or even optimizing microchip designs. However, running RL algorithms in the real world requires expensive active data collection. Pre-training on diverse datasets has proven to enable data-efficient fine-tuning for individual downstream tasks in natural language processing (NLP) and vision problems. In the same way that BERT or GPT-3 models provide general-purpose initialization for NLP, large RL–pre-trained models could provide general-purpose initialization for decision-making. So, we ask the question: Can we enable similar pre-training to accelerate RL methods and create a general-purpose “backbone” for efficient RL across various tasks?
In “Offline Q-learning on Diverse Multi-Task Data Both Scales and Generalizes”, to be published at ICLR 2023, we discuss how we scaled offline RL, which can be used to train value functions on previously collected static datasets, to provide such a general pre-training method. We demonstrate that Scaled Q-Learning using a diverse dataset is sufficient to learn representations that facilitate rapid transfer to novel tasks and fast online learning on new variations of a task, improving significantly over existing representation learning approaches and even Transformer-based methods that use much larger models.
Scaled Q-learning: Multi-task pre-training with conservative Q-learning
To provide a general-purpose pre-training approach, offline RL needs to be scalable, allowing us to pre-train on data across different tasks and utilize expressive neural network models to acquire powerful pre-trained backbones that can be specialized to individual downstream tasks. We based our offline RL pre-training method on conservative Q-learning (CQL), a simple offline RL method that combines standard Q-learning updates with an additional regularizer that minimizes the value of unseen actions. With discrete actions, the CQL regularizer is equivalent to a standard cross-entropy loss, which is a simple, one-line modification to standard deep Q-learning. A few crucial design decisions made scaling this method possible:
- Neural network size: We found that multi-game Q-learning required large neural network architectures. While prior methods often used relatively shallow convolutional networks, we found that models as large as a ResNet 101 led to significant improvements over smaller models.
- Neural network architecture: To learn pre-trained backbones that are useful for new games, our final architecture uses a shared neural network backbone, with separate 1-layer heads outputting Q-values of each game. This design avoids interference between the games during pre-training, while still providing enough data sharing to learn a single shared representation. Our shared vision backbone also utilized a learned position embedding (akin to Transformer models) to keep track of spatial information in the game.
- Representational regularization: Recent work has observed that Q-learning tends to suffer from representational collapse issues, where even large neural networks can fail to learn effective representations. To counteract this issue, we leverage our prior work to normalize the last layer features of the shared part of the Q-network. Additionally, we utilized a categorical distributional RL loss for Q-learning, which is known to provide richer representations that improve downstream task performance.
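To make the pieces above concrete, here is a minimal, hedged sketch of the CQL objective for discrete actions, with a normalized shared feature vector and a per-game linear head. It is not the implementation from the paper: the function names, the weight `alpha`, the `discount` value, and the toy numbers are all assumptions made for illustration. The sketch also swaps in a plain squared TD error in place of the categorical distributional loss mentioned above, and it omits target networks.

```python
import math


def l2_normalize(features):
    """Scale the shared backbone's last-layer features to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in features))
    return [x / (norm + 1e-8) for x in features]


def q_values(features, head_weights, head_bias):
    """Apply one game's 1-layer (linear) head to the normalized features."""
    phi = l2_normalize(features)
    return [
        sum(w * x for w, x in zip(row, phi)) + b
        for row, b in zip(head_weights, head_bias)
    ]


def log_sum_exp(values):
    """Numerically stable log(sum(exp(v)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def cql_regularizer(q, action):
    """logsumexp over all actions minus Q of the dataset action.

    This equals the cross-entropy loss with Q-values used as logits and the
    dataset action used as the label.
    """
    return log_sum_exp(q) - q[action]


def td_error(q, action, reward, next_q, discount=0.99, done=False):
    """Squared one-step Q-learning error (assumption: no target network)."""
    target = reward + (0.0 if done else discount * max(next_q))
    return (q[action] - target) ** 2


def cql_loss(q, action, reward, next_q, alpha=1.0, discount=0.99, done=False):
    """Standard TD loss plus the weighted conservative regularizer."""
    return td_error(q, action, reward, next_q, discount, done) + alpha * cql_regularizer(q, action)


# Hypothetical example: one head per game on top of a shared 2-d feature vector.
heads = {
    "Pong": ([[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0]),
    "Breakout": ([[0.5, 0.5], [1.0, -1.0]], [0.1, 0.0]),
}
features, next_features = [3.0, 4.0], [1.0, 0.0]
weights, bias = heads["Pong"]
loss = cql_loss(
    q_values(features, weights, bias),
    action=1,
    reward=1.0,
    next_q=q_values(next_features, weights, bias),
)
print(loss)
```

Because each game has its own head, only the shared features receive gradients from every game, which is the intuition behind reduced interference during pre-training.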
The multi-task Atari benchmark
We evaluate our approach for scalable offline RL on a suite of Atari games, where the goal is to train a single RL agent to play a collection of games using heterogeneous data from low-quality (i.e., suboptimal) players, and then use the resulting network backbone to quickly learn new variations of the pre-training games or completely new games. Training a single policy that can play many different Atari games is difficult enough even with standard online deep RL methods, as each game requires a different strategy and different representations. In the offline setting, some prior works, such as multi-game decision transformers, proposed to dispense with RL entirely, and instead utilize conditional imitation learning in an attempt to scale with large neural network architectures, such as transformers. However, in this work, we show that this kind of multi-game pre-training can be done effectively via RL by employing CQL in combination with the few careful design decisions described above.
Scalability on training games
We evaluate the Scaled Q-Learning method’s performance and scalability using two data compositions: (1) near-optimal data, consisting of all the training data appearing in replay buffers of previous RL runs, and (2) low-quality data, consisting of data from the first 20% of the trials in the replay buffer (i.e., only data from highly suboptimal policies). In our results below, we compare Scaled Q-Learning with an 80-million parameter model to multi-game decision transformers (DT) with either 40-million or 80-million parameter models, and a behavioral cloning (imitation learning) baseline (BC). We observe that Scaled Q-Learning is the only approach that improves over the offline data, attaining about 80% of human-normalized performance.
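As a rough illustration of the two ideas in this paragraph, the sketch below shows one way to carve out a low-quality subset from a replay buffer and how a human-normalized score is conventionally computed for Atari. The function names and the assumption that the buffer is a chronologically ordered list of trials are ours, not details from the paper, and the numbers are placeholders rather than reported results.

```python
def first_fraction(trials, fraction=0.2):
    """Keep only the earliest `fraction` of a chronologically ordered buffer.

    Early trials come from barely trained policies, so this subset
    approximates the low-quality data composition.
    """
    cutoff = int(len(trials) * fraction)
    return trials[:cutoff]


def human_normalized_score(agent_score, random_score, human_score):
    """Rescale a raw game score so that random play is 0.0 and human play is 1.0."""
    return (agent_score - random_score) / (human_score - random_score)


# Placeholder buffer of ten trials, oldest first (hypothetical data).
replay_buffer = [f"trial_{i}" for i in range(10)]
low_quality = first_fraction(replay_buffer)

# Placeholder scores for a single game (hypothetical numbers).
score = human_normalized_score(agent_score=50.0, random_score=0.0, human_score=100.0)
print(low_quality, score)
```

Normalizing per game before aggregating keeps games with large raw scores from dominating the comparison.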
Further, as shown below, Scaled Q-Learning not only improves in terms of performance, but it also enjoys favorable scaling properties: just as the performance of pre-trained language and vision models improves as network sizes get bigger, following what is typically referred to as “power-law scaling”, the performance of Scaled Q-Learning improves in a similar way as model size grows. While this may be unsurprising, this kind of scaling has been elusive in RL, with performance often deteriorating with larger model sizes. This suggests that Scaled Q-Learning in combination with the above design choices better unlocks the ability of offline RL to utilize large models.
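For readers unfamiliar with the term, a power law relates two quantities as y ≈ c · x^k, which becomes a straight line on log-log axes. The sketch below fits such a line by ordinary least squares. It is only an illustration of the concept: the function name is hypothetical, the data points are synthetic, and the post does not specify how the scaling trend in the paper was fitted.

```python
import math


def fit_power_law(sizes, scores):
    """Least-squares fit of scores ~ c * sizes**k in log-log space.

    Returns (c, k). Both inputs must contain strictly positive numbers.
    """
    xs = [math.log(s) for s in sizes]
    ys = [math.log(p) for p in scores]
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    # Slope of the regression line in log-log space is the exponent k.
    k = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys)) / sum(
        (x - mean_x) ** 2 for x in xs
    )
    # The intercept gives log(c).
    c = math.exp(mean_y - k * mean_x)
    return c, k


# Synthetic points generated from y = 2 * x**0.5 (not results from the paper).
model_sizes = [1.0, 10.0, 100.0]
performance = [2.0 * s ** 0.5 for s in model_sizes]
c, k = fit_power_law(model_sizes, performance)
print(c, k)
```

A positive exponent `k` corresponds to performance that keeps improving with model size, while a negative one would correspond to the deterioration described above.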
Fine-tuning to new games and variations
To evaluate fine-tuning from this offline initialization, we consider two settings: (1) fine-tuning to a new, entirely unseen game with a small amount of offline data from that game, corresponding to 2M transitions of gameplay, and (2) fine-tuning to a new variant of the games with online interaction. The fine-tuning from offline gameplay data is illustrated below. Note that this condition is generally more favorable to imitation-style methods, such as Decision Transformer and behavioral cloning, since the offline data for the new games is of relatively high quality. Nonetheless, we see that in most cases Scaled Q-Learning improves over alternative approaches (80% on average), as well as dedicated representation learning methods, such as MAE or CPC, which only use the offline data to learn visual representations rather than value functions.
In the online setting, we see even larger improvements from pre-training with Scaled Q-learning. In this case, representation learning methods like MAE yield minimal improvement during online RL, whereas Scaled Q-Learning can successfully integrate prior knowledge about the pre-training games to significantly improve the final score after 20k online interaction steps.
These results demonstrate that pre-training generalist value function backbones with multi-task offline RL can significantly boost performance of RL on downstream tasks, in both offline and online modes. Note that these fine-tuning tasks are quite difficult: the various Atari games, and even variants of the same game, differ significantly in appearance and dynamics. For example, the target blocks in Breakout disappear in the variation of the game shown below, making control difficult. However, the success of Scaled Q-Learning, particularly as compared to visual representation learning techniques, such as MAE and CPC, suggests that the model is in fact learning some representation of the game dynamics, rather than merely providing better visual features.
Conclusion and takeaways
We presented Scaled Q-Learning, a pre-training method for scaled offline RL that builds on the CQL algorithm, and demonstrated how it enables efficient offline RL for multi-task training. This work made initial progress towards enabling more practical real-world training of RL agents as an alternative to costly and complex simulation-based pipelines or large-scale experiments. Perhaps in the long run, similar work will lead to generally capable pre-trained RL agents that develop broadly applicable exploration and interaction skills from large-scale offline pre-training. Validating these results on a broader range of more realistic tasks, in domains such as robotics (see some initial results) and NLP, is an important direction for future research. Offline RL pre-training has a lot of potential, and we expect that we will see many advances in this area in future work.
Acknowledgements
This work was done by Aviral Kumar, Rishabh Agarwal, Xinyang Geng, George Tucker, and Sergey Levine. Special thanks to Sherry Yang, Ofir Nachum, and Kuang-Huei Lee for help with the multi-game decision transformer codebase for evaluation and the multi-game Atari benchmark, and Tom Small for illustrations and animation.