Scalable Gradients for Stochastic Differential Equations
Xuechen Li, Ting-Kam Leonard Wong, Ricky T. Q. Chen, David Duvenaud
Code
- github.com/google-research/torchsde (official, referenced in paper; PyTorch) ★ 1,708
- github.com/xwinxu/bayesian-sde (JAX) ★ 173
- github.com/xwinxu/bayesde (JAX) ★ 173
- github.com/JFagin/latent_SDE (PyTorch) ★ 2
Abstract
The adjoint sensitivity method scalably computes gradients of solutions to ordinary differential equations. We generalize this method to stochastic differential equations, allowing time-efficient and constant-memory computation of gradients with high-order adaptive solvers. Specifically, we derive a stochastic differential equation whose solution is the gradient, a memory-efficient algorithm for caching noise, and conditions under which numerical solutions converge. In addition, we combine our method with gradient-based stochastic variational inference for latent stochastic differential equations. We use our method to fit stochastic dynamics defined by neural networks, achieving competitive performance on a 50-dimensional motion capture dataset.
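The memory-efficient noise caching mentioned in the abstract can be illustrated with a "virtual Brownian tree": Brownian-bridge samples are recomputed on demand from deterministic seeds rather than stored, so the backward pass can re-query exactly the noise seen on the forward pass in constant memory. The sketch below is a minimal, stdlib-only illustration; the class name, seed-splitting scheme, and tolerance are assumptions for exposition, not the paper's (or torchsde's) actual implementation.

```python
import math
import random

def bridge(t, t0, w0, t1, w1, seed):
    """Sample W(t) given W(t0)=w0 and W(t1)=w1 from a Brownian bridge.
    The draw is a deterministic function of the seed, so revisiting the
    same interval always reproduces the same value -- no caching needed."""
    mean = w0 + (w1 - w0) * (t - t0) / (t1 - t0)
    std = math.sqrt((t1 - t) * (t - t0) / (t1 - t0))
    return mean + std * random.Random(seed).gauss(0.0, 1.0)

class VirtualBrownianTree:
    """Query a Brownian path W at arbitrary times in [t0, t1] with O(1)
    memory: values are recomputed on demand from seeds instead of stored."""

    def __init__(self, t0, t1, seed, tol=1e-3):
        self.t0, self.t1, self.seed, self.tol = t0, t1, seed, tol
        # Endpoint value W(t1); W(t0) is fixed to 0.
        self.w1 = math.sqrt(t1 - t0) * random.Random(seed).gauss(0.0, 1.0)

    def __call__(self, t):
        t0, w0, t1, w1 = self.t0, 0.0, self.t1, self.w1
        seed = self.seed + 1  # root node seed, distinct from the endpoint draw
        # Bisect, conditioning on midpoints, until t is bracketed within tol.
        while t1 - t0 > self.tol:
            tm = 0.5 * (t0 + t1)
            wm = bridge(tm, t0, w0, t1, w1, seed)
            if t < tm:
                t1, w1, seed = tm, wm, 2 * seed + 1  # descend to left child
            else:
                t0, w0, seed = tm, wm, 2 * seed + 2  # descend to right child
        # Interpolate inside the final tiny interval (error O(tol)).
        return w0 + (w1 - w0) * (t - t0) / (t1 - t0 + 1e-12)
```

Because every draw is a pure function of its seed, querying the same time twice returns the identical value; this reproducibility is what lets the reverse-time adjoint solve reuse the forward pass's noise without storing it.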
Benchmark Results
| Dataset | Model | Metric | Claimed | Verified | Status |
|---|---|---|---|---|---|
| CMU Mocap-2 | Latent SDE | Test Error | 4.03 | — | Unverified |
| CMU Mocap-2 | Latent ODE | Test Error | 5.98 | — | Unverified |