MPAX: Mathematical Programming in JAX
2024-12-12
Haihao Lu, Zedong Peng, Jinwen Yang
- Code: github.com/mit-lu-lab/mpax (official implementation, linked in the paper; JAX; ★ 148)
Abstract
This paper presents MPAX (Mathematical Programming in JAX), a versatile and efficient toolbox for integrating linear programming (LP) into machine learning workflows. MPAX implements two state-of-the-art first-order methods for solving LPs in JAX: restarted average primal-dual hybrid gradient and reflected restarted Halpern primal-dual hybrid gradient. Building on JAX provides native support for hardware acceleration, along with features such as batch solving, automatic differentiation, and device parallelism. Extensive numerical experiments demonstrate the advantages of MPAX over existing solvers. The solver is available at https://github.com/MIT-Lu-Lab/MPAX.
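The following sketch illustrates why writing an LP solver in JAX makes batch solving straightforward. It applies vanilla PDHG to a standard-form LP (minimize c·x subject to Ax = b, x ≥ 0) and then batches it with `jax.vmap`. This is not MPAX's API or its algorithm. It omits the restarts, averaging, Halpern reflection, preconditioning, and stopping criteria of the methods named above. The function name `pdhg_lp`, the fixed iteration count, and the step-size rule are hypothetical choices made for illustration.

```python
import jax
import jax.numpy as jnp


def pdhg_lp(c, A, b, num_iters=2000):
    """Vanilla PDHG for: minimize c @ x  s.t.  A @ x == b,  x >= 0.

    Hypothetical illustrative sketch: no restarts, averaging, Halpern
    reflection, preconditioning, or termination test (unlike MPAX's methods).
    Returns the final primal iterate x and dual iterate y.
    """
    # Step sizes chosen so that tau * sigma * ||A||_2^2 = 0.81 < 1.
    norm_A = jnp.linalg.norm(A, ord=2)  # largest singular value of A
    tau = sigma = 0.9 / norm_A

    def step(_, state):
        x, y = state
        # Primal: projected gradient step on the Lagrangian c@x - y@(A@x - b),
        # projecting onto the nonnegative orthant.
        x_new = jnp.maximum(x - tau * (c - A.T @ y), 0.0)
        # Dual: ascent step evaluated at the extrapolated point 2*x_new - x.
        y_new = y + sigma * (b - A @ (2.0 * x_new - x))
        return x_new, y_new

    x0 = jnp.zeros(A.shape[1])
    y0 = jnp.zeros(A.shape[0])
    return jax.lax.fori_loop(0, num_iters, step, (x0, y0))


# Batch solving: vmap over a batch of cost vectors while sharing A and b,
# then jit-compile the batched solver.
batched_solve = jax.jit(jax.vmap(pdhg_lp, in_axes=(0, None, None)))

# Two LPs with the single constraint x1 + x2 == 1, x >= 0.
A = jnp.array([[1.0, 1.0]])
b = jnp.array([1.0])
C = jnp.array([[1.0, 2.0],
               [2.0, 1.0]])
xs, ys = batched_solve(C, A, b)
print(xs)  # ≈ [[1, 0], [0, 1]]
print(ys)  # ≈ [[1], [1]]
```

Every step is an ordinary JAX array operation, so the same function runs on CPU, GPU, or TPU without changes. `jax.grad` can also differentiate through the unrolled loop. Unrolled differentiation is only one possible approach, and this sketch does not show how MPAX itself computes derivatives.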