DrJAX: Scalable and Differentiable MapReduce Primitives in JAX

2024-03-11

Keith Rush, Zachary Charles, Zachary Garrett, Sean Augenstein, Nicole Mitchell

Abstract

We present DrJAX, a JAX-based library designed to support large-scale distributed and parallel machine learning algorithms that use MapReduce-style operations. DrJAX leverages JAX's sharding mechanisms to enable native targeting of TPUs and state-of-the-art JAX runtimes, including Pathways. DrJAX embeds building blocks for MapReduce computations as primitives in JAX. This enables three key benefits. First, DrJAX computations can be translated directly to XLA HLO, enabling flexible integration with a wide array of ML training platforms. Second, DrJAX computations are fully differentiable. Last, DrJAX computations can be interpreted out to existing batch-processing compute systems, including traditional MapReduce systems like Apache Beam and cross-device compute systems like those powering federated learning applications. We show that DrJAX provides an easily programmable, performant, and scalable framework for parallelized algorithm development. DrJAX is available at https://github.com/google-research/google-research/tree/master/drjax.
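To make the MapReduce-style building blocks concrete, here is a minimal pure-Python sketch of the three operations the abstract alludes to: broadcasting a server value to shards, mapping a function over shards, and reducing shard results back to a single value. The function names (`broadcast`, `map_fn`, `reduce_mean`) and their semantics are illustrative assumptions for exposition, not DrJAX's actual API; DrJAX embeds analogous operations as JAX primitives so they lower to XLA HLO and remain differentiable.

```python
# Conceptual sketch of MapReduce-style building blocks, in the spirit of
# those DrJAX embeds as JAX primitives. Names and signatures here are
# illustrative assumptions, not DrJAX's actual API.

def broadcast(value, num_shards):
    """Replicate a server-side value to every shard (e.g., every client)."""
    return [value] * num_shards

def map_fn(fn, shards):
    """Apply a function independently to each shard's local value."""
    return [fn(x) for x in shards]

def reduce_mean(shards):
    """Aggregate per-shard results back into a single server-side value."""
    return sum(shards) / len(shards)

# A toy round: broadcast a scalar "model", compute a local update on each
# client's data, then average the updates on the server.
model = 2.0
client_data = [1.0, 3.0, 5.0]
local_models = broadcast(model, len(client_data))
updates = map_fn(lambda pair: pair[0] * pair[1], list(zip(local_models, client_data)))
new_model = reduce_mean(updates)
print(new_model)  # 2.0 * mean([1.0, 3.0, 5.0]) = 6.0
```

In DrJAX, per the abstract, the analogous operations are JAX primitives rather than Python list manipulations, which is what lets the same program be compiled to XLA for TPUs, differentiated end to end, or interpreted out to systems like Apache Beam.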
