PiKE: Adaptive Data Mixing for Large-Scale Multi-Task Learning Under Low Gradient Conflicts
Abstract
Modern foundation models are trained on diverse datasets to enhance generalization across tasks and domains. A central challenge in this process is determining how to effectively mix and sample data from multiple sources. This naturally leads to a multi-task learning (MTL) perspective. While prior work in MTL has emphasized mitigating gradient conflicts, we observe that large-scale pre-training scenarios, such as multilingual or multi-domain training, often exhibit little to no gradient conflict. Motivated by this observation, we propose PiKE (Positive gradient interaction-based K task weights Estimator), an adaptive data mixing algorithm that dynamically adjusts sampling weights during training. PiKE exploits non-conflicting gradient interactions to minimize a near-tight upper bound on the average loss decrease at each step, while incurring negligible computational overhead. We provide theoretical convergence guarantees and show that PiKE outperforms static and non-adaptive mixing baselines. Furthermore, we extend PiKE to promote balanced learning across tasks. Extensive experiments on large-scale language model pre-training confirm that PiKE achieves faster convergence and improved downstream performance compared to existing approaches.
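As a rough illustration of the adaptive data-mixing idea summarized above, the sketch below maintains per-task sampling weights and reweights them each step from cheap per-task gradient statistics. It is a minimal sketch for exposition only: the update rule in `update_mixing_weights` and the `task.sample` / `model.train_step` interfaces are hypothetical placeholders, not the actual PiKE estimator.

```python
import numpy as np

def update_mixing_weights(grad_norms, weights, step_size=0.1):
    """Illustrative reweighting: shift sampling weight toward tasks whose
    gradients currently promise a larger loss decrease, then renormalize.
    (Hypothetical rule for exposition; not the exact PiKE update.)"""
    scores = weights * np.exp(step_size * grad_norms)
    return scores / scores.sum()

def train_with_adaptive_mixing(tasks, model, num_steps, batch_size):
    """Sketch of an adaptive data-mixing loop over K tasks.
    `tasks` and `model` are assumed interfaces, not part of the paper."""
    k = len(tasks)
    weights = np.full(k, 1.0 / k)            # start from a uniform mixture
    for _ in range(num_steps):
        # Compose the batch according to the current mixing weights.
        counts = np.random.multinomial(batch_size, weights)
        batch = [task.sample(n) for task, n in zip(tasks, counts)]
        # Assumed to apply one optimizer step and return per-task gradients.
        per_task_grads = model.train_step(batch)
        # Cheap per-task statistics (here, gradient norms) drive the reweighting.
        grad_norms = np.array([np.linalg.norm(g) for g in per_task_grads])
        weights = update_mixing_weights(grad_norms, weights)
    return model
```

The key design point the sketch captures is that the mixture is adjusted continuously during training from statistics already available at each step, rather than being fixed in advance.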