Scaling Deep Learning Training with MPMD Pipeline Parallelism

2024-12-18

Anxhelo Xhebraj, Sean Lee, Hanfeng Chen, Vinod Grover

Abstract

We present JaxPP, a system for efficiently scaling the training of large deep learning models with flexible pipeline parallelism. We introduce a seamless programming model that allows implementing user-defined pipeline schedules for gradient accumulation. JaxPP automatically distributes tasks, corresponding to pipeline stages, over a cluster of nodes and automatically infers the communication among them. We implement an MPMD runtime for asynchronous execution of SPMD tasks. The pipeline parallelism implementation of JaxPP improves hardware utilization by up to 1.11× with respect to the best-performing SPMD configuration.
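To make the notion of a pipeline schedule for gradient accumulation concrete, the following is a minimal sketch in plain Python. It is not JaxPP's API: the names `gpipe_schedule` and `utilization` are hypothetical, and the sketch assumes a GPipe-style schedule (all forward passes, then all backward passes) in which every forward and backward task takes one unit of time. It only illustrates how a schedule assigns per-microbatch tasks to stages and why idle slots ("bubbles") limit utilization.

```python
from typing import List, Optional, Tuple

# A task is ("fwd" or "bwd", microbatch index). Hypothetical representation,
# not taken from JaxPP.
Task = Tuple[str, int]


def gpipe_schedule(num_stages: int, num_microbatches: int) -> List[List[Optional[Task]]]:
    """Build a grid where grid[t][s] is the task stage s runs at time step t,
    or None if the stage is idle. Assumes unit-time forward and backward tasks."""
    phase_len = num_microbatches + num_stages - 1
    grid: List[List[Optional[Task]]] = [
        [None] * num_stages for _ in range(2 * phase_len)
    ]
    for m in range(num_microbatches):
        for s in range(num_stages):
            # Forward: microbatch m reaches stage s after s hops.
            grid[m + s][s] = ("fwd", m)
            # Backward: starts once all forwards are done and flows from
            # the last stage back to the first.
            grid[phase_len + m + (num_stages - 1 - s)][s] = ("bwd", m)
    return grid


def utilization(grid: List[List[Optional[Task]]]) -> float:
    """Fraction of (time step, stage) slots that are doing work."""
    busy = sum(task is not None for row in grid for task in row)
    return busy / (len(grid) * len(grid[0]))


if __name__ == "__main__":
    schedule = gpipe_schedule(num_stages=4, num_microbatches=8)
    for t, row in enumerate(schedule):
        print(t, row)
    print(f"utilization = {utilization(schedule):.3f}")
```

Under these assumptions the utilization works out to M / (M + S - 1) for M microbatches and S stages, so accumulating gradients over more microbatches shrinks the relative cost of the bubble. Other schedules (for example, interleaving forward and backward tasks) rearrange the same tasks in the grid to reduce idle time or peak memory.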
