Large-Scale Differentially Private BERT
Abstract
In this work, we study the large-scale pretraining of BERT-Large~\citep{devlin2018bert} with differentially private SGD (DP-SGD). We show that, combined with a careful implementation, scaling up the batch size to millions of examples (i.e., mega-batches) improves the utility of the DP-SGD step for BERT; we further improve training efficiency by using an increasing batch size schedule. Our implementation builds on the recent work of \citet{subramani20}, who demonstrated that the overhead of a DP-SGD step is minimized through effective use of JAX~\citep{jax2018github, frostig2018compiling} primitives in conjunction with the XLA compiler~\citep{xladocs}. Our implementation achieves a masked language model accuracy of 60.5\% at a batch size of 2M for $\eps = 5$, which is a reasonable privacy setting. To put this number in perspective, non-private BERT models achieve an accuracy of $\sim$70\%.
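
For readers unfamiliar with the mechanics, the following is a minimal sketch of a generic DP-SGD step written with JAX primitives: per-example gradients via \texttt{jax.vmap}, clipping of each example's gradient to a fixed $\ell_2$ norm, and Gaussian noise added to the clipped sum. It is not the implementation used in this work. The toy linear model and the names \texttt{loss\_fn}, \texttt{dp\_sgd\_step}, \texttt{clip\_norm}, and \texttt{noise\_multiplier} are assumptions chosen for illustration, and privacy accounting for $\eps$ is omitted.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    """Squared error of a toy linear model on a single example (illustrative only)."""
    pred = jnp.dot(x, params["w"]) + params["b"]
    return 0.5 * (pred - y) ** 2


def dp_sgd_step(params, xs, ys, key, clip_norm=1.0, noise_multiplier=1.0, lr=0.1):
    """One generic DP-SGD update; hypothetical helper, not the paper's code."""
    # Per-example gradients: map over the batch axis of the data, share params.
    per_example_grads = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))(
        params, xs, ys
    )
    leaves, treedef = jax.tree_util.tree_flatten(per_example_grads)

    # Global L2 norm of each example's gradient across all parameter leaves.
    sq_norms = sum(jnp.sum(g.reshape(g.shape[0], -1) ** 2, axis=1) for g in leaves)
    norms = jnp.sqrt(sq_norms)
    # Scale factor that shrinks any gradient with norm above clip_norm.
    scale = jnp.minimum(1.0, clip_norm / (norms + 1e-12))

    batch_size = xs.shape[0]
    keys = jax.random.split(key, len(leaves))
    noisy = []
    for g, k in zip(leaves, keys):
        # Broadcast the per-example scale over the trailing parameter axes.
        clipped_sum = jnp.sum(g * scale.reshape((-1,) + (1,) * (g.ndim - 1)), axis=0)
        # Gaussian noise with standard deviation noise_multiplier * clip_norm.
        noise = noise_multiplier * clip_norm * jax.random.normal(k, clipped_sum.shape)
        noisy.append((clipped_sum + noise) / batch_size)
    noisy_grads = jax.tree_util.tree_unflatten(treedef, noisy)

    # Plain SGD update using the noisy average gradient.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, noisy_grads)
```

In this sketch the noise standard deviation is independent of the batch size while the result is divided by it, so the noise in the averaged gradient shrinks as the batch grows. This is one informal way to see why very large batches can help the utility of each step. The function can be wrapped in \texttt{jax.jit} so that XLA compiles the whole step.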