- Dehao Chen
- Dmitry (Dima) Lepikhin
- HyoukJoong Lee
- Maxim Krikun
- Noam Shazeer
- Orhan Firat
- Yanping Huang
- Yuanzhong Xu
- Zhifeng Chen
Abstract
Neural network scaling has been critical for improving model quality in many real-world machine learning applications with vast amounts of training data and compute. Although this trend of scaling is affirmed to be a sure-fire approach for better model quality, there are challenges on the path such as the computation cost, ease of programming, and efficient implementation on parallel devices. GShard is a module composed of a set of lightweight annotation APIs and an extension to the XLA compiler. It provides an elegant way to express a wide range of parallel computation patterns with minimal changes to the existing model code. GShard enabled us to scale up a multilingual machine translation Transformer model with Sparsely-Gated Mixture-of-Experts beyond 600 billion parameters using automatic sharding. We demonstrate that such a giant model can be easily trained on 2048 TPU v3 accelerators in 4 days to achieve far superior quality for translation from 100 languages to English compared to the prior art.
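To give a concrete flavor of the annotation-based style the abstract describes, here is a minimal sketch using JAX's public sharding API as an analogue; it is not GShard's own API, and the mesh axis name and tensor shapes are illustrative assumptions. The idea is the same: the model is written as single-device code, per-tensor sharding annotations are attached, and the XLA SPMD partitioner inserts the communication needed to run it across devices.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D logical device mesh over the available accelerators.
mesh = Mesh(np.array(jax.devices()), ("data",))

# Annotate how a tensor should be partitioned across the mesh:
# shard the leading (batch) dimension, replicate the rest.
x = jnp.ones((8, 1024))
sharded_x = jax.device_put(x, NamedSharding(mesh, P("data", None)))

@jax.jit
def layer(x):
    # The computation is written as if it ran on one device; the
    # compiler propagates the input sharding and partitions the work.
    return jnp.tanh(x @ jnp.ones((1024, 1024)))

y = layer(sharded_x)
```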