ddrous / nodepint

Parallel-in-time integration of Neural ODEs with reduced basis approximation

NodePinT

Parallel-in-time integration of Neural ODEs with reduced basis approximation.

Features

  • Neural ordinary differential equations for a wide family of problems
  • Time-parallelisation (or layer parallelisation if compared to ResNets)
  • Reduced order modelling as a by-product of the training process
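
To make the idea of time-parallelisation concrete, here is a minimal sketch using Parareal, a classic parallel-in-time scheme. It is a stand-in chosen for illustration: this README does not state which scheme NodePinT uses, and none of the names below (`vector_field`, `coarse`, `fine`, `parareal`, `serial_fine`) come from the package. The "network" is a toy one-neuron vector field with made-up weights. Each Euler step `y + h * f(y)` has the form of a residual layer, which is why splitting the time axis can also be read as splitting layers.

```python
import math
from concurrent.futures import ThreadPoolExecutor


def vector_field(y, w=-1.5, b=0.2):
    # Toy one-neuron "network": dy/dt = tanh(w * y + b). Weights are illustrative.
    return math.tanh(w * y + b)


def euler(y, t_span, n_steps):
    # Explicit Euler over a window of length t_span; each step looks like a residual layer.
    h = t_span / n_steps
    for _ in range(n_steps):
        y = y + h * vector_field(y)
    return y


def coarse(y, dt):
    # Cheap propagator: a single Euler step per time window.
    return euler(y, dt, 1)


def fine(y, dt):
    # Accurate (expensive) propagator: 100 Euler steps per time window.
    return euler(y, dt, 100)


def parareal(y0, t_end, n_windows, n_iters):
    dt = t_end / n_windows
    # Initial guess at the window boundaries: one serial coarse sweep.
    u = [y0]
    for n in range(n_windows):
        u.append(coarse(u[n], dt))
    with ThreadPoolExecutor() as pool:
        for _ in range(n_iters):
            # The fine solves are independent across windows, so they can run
            # concurrently. Threads only illustrate that independence here; a real
            # speed-up needs processes or an accelerator.
            f_old = list(pool.map(lambda y: fine(y, dt), u[:-1]))
            g_old = [coarse(y, dt) for y in u[:-1]]
            # Serial but cheap correction sweep using the coarse propagator.
            new = [y0]
            for n in range(n_windows):
                new.append(coarse(new[n], dt) + f_old[n] - g_old[n])
            u = new
    return u


def serial_fine(y0, t_end, n_windows):
    # Reference: the fine propagator applied window after window.
    dt = t_end / n_windows
    u = [y0]
    for n in range(n_windows):
        u.append(fine(u[n], dt))
    return u


u_par = parareal(1.0, 2.0, n_windows=8, n_iters=8)
u_ref = serial_fine(1.0, 2.0, n_windows=8)
print(max(abs(a - b) for a, b in zip(u_par, u_ref)))
```

With as many iterations as windows, Parareal reproduces the serial fine solution up to rounding; the scheme pays off when far fewer iterations are enough.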

Benchmarks

Dataset TrainingTime (Epochs) TrainingAcc TestAcc PeakMemory NetworkSize Language Hyperparams/Script
MNIST 4 mins 30 secs (420) 98.172% 97.117% 10.657 GB - Julia Link
16 mins 53 secs (420) 99.39% 2.06 GiB Jax Link
20 mins 33 secs (5 epochs !) 89.48% 5.89 GiB NodePinT Link
SPH 3.5 hours 92% 88% 6 GB 4 layers NodePinT hyperparameters4.json

Getting started

pip install nodepint

Flowcharts

[Figures: NodePinT general logic (flowchart); Encode-Process-Decode logic]

ToDos

  • Massive parallelisation by combining time with data
  • Stochastic ODEs and diffusion models time parallelisation
  • Parallelism across the projection basis. Since the PinT problem is projected onto a lower-dimensional space, this allows for a second level of parallelism:
    • If we always project onto a randomly sampled 1D space, we can also solve the ODEs in parallel and average the NN weights without having to re(jit)compile
    • If we project onto increasingly larger randomly sampled spaces, we can still do it, but we need to re(jit)compile the PinT processes
    • If we construct our space in a deterministic way (via sensitivity analysis wrt the latest added vector), then we can't parallelise, since the construction is sequential.
  • For optimal control (OC) while training a neural ODE, we propose several combinations:
    • DP for training, and DAL for OC
    • DP both for training and OC (a bit like PINN)
    • DAL for training, and DP for OC
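
The projection ideas in the list above can be illustrated with a small pure-Python sketch. It is a toy under stated assumptions, not NodePinT code: `target` stands in for whatever quantity the PinT problem projects, the basis vectors are Gaussian samples, and every helper name (`dot`, `orthonormalise`, `project`, `projection_error`) is hypothetical. It shows only that the projection error cannot grow as randomly sampled vectors are added to the basis. In JAX, a change in the basis shape between calls is what would cause a jitted function to be retraced and recompiled, whereas a fixed 1D basis keeps the shapes constant.

```python
import random


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def orthonormalise(vectors):
    # Modified Gram-Schmidt; near-dependent vectors are dropped.
    basis = []
    for v in vectors:
        w = list(v)
        for q in basis:
            c = dot(w, q)
            w = [wi - c * qi for wi, qi in zip(w, q)]
        norm = dot(w, w) ** 0.5
        if norm > 1e-12:
            basis.append([wi / norm for wi in w])
    return basis


def project(x, basis):
    # Orthogonal projection of x onto the span of an orthonormal basis.
    out = [0.0] * len(x)
    for q in basis:
        c = dot(x, q)
        out = [o + c * qi for o, qi in zip(out, q)]
    return out


def projection_error(x, basis):
    residual = [a - b for a, b in zip(x, project(x, basis))]
    return dot(residual, residual) ** 0.5


rng = random.Random(0)  # fixed seed so the run is repeatable
dim = 6                 # illustrative full-space dimension
target = [rng.gauss(0, 1) for _ in range(dim)]
samples = [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(dim)]

# Error of projecting onto the first k sampled directions, for k = 1..dim.
errors = [projection_error(target, orthonormalise(samples[:k]))
          for k in range(1, dim + 1)]
print(errors)
```

The first entry corresponds to the 1D case; the last uses as many directions as the full space has dimensions, so the error drops to rounding level.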

Dependencies

  • JAX
  • Equinox
  • Datasets (the "jax" extra)
  • Julia bindings to Python for SPH (all papers should benchmark NODEs on physical data once we release the SPHPC dataset)

Languages

Python 98.8%, Shell 1.2%