KeAWang / jax-t1d-simulator

Try out the speedups from using JAX:

python simglucose/speed_test.py
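To show what a per-step timing comparison like the one reported here involves, here is a minimal sketch under stated assumptions. It is not the contents of speed_test.py. It uses a hypothetical two-state toy ODE (not the simglucose patient model) and times a jitted JAX Euler step against the same step in NumPy. The warm-up call keeps JIT compilation out of the measurement, and block_until_ready() accounts for JAX's asynchronous dispatch.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np


# Hypothetical toy right-hand side (two-compartment linear decay).
# This is NOT the simglucose patient model; it only stands in for
# "some ODE right-hand side" so the timing harness can run.
def rhs_np(x, k):
    return np.array([-k[0] * x[0], k[0] * x[0] - k[1] * x[1]])


def rhs_jnp(x, k):
    return jnp.stack([-k[0] * x[0], k[0] * x[0] - k[1] * x[1]])


def euler_step_np(x, k, dt):
    # One explicit Euler step in plain NumPy.
    return x + dt * rhs_np(x, k)


@jax.jit
def euler_step_jax(x, k, dt):
    # Same step, traced and compiled by XLA on the first call.
    return x + dt * rhs_jnp(x, k)


def time_per_step(step, x, k, dt, n=1000, block=False):
    # Average wall-clock seconds per call of `step`.
    out = step(x, k, dt)  # warm-up: excludes JIT compilation from the timing
    if block:
        out.block_until_ready()
    start = time.perf_counter()
    for _ in range(n):
        out = step(x, k, dt)
    if block:
        # JAX dispatches asynchronously; wait for the queued work to finish.
        out.block_until_ready()
    return (time.perf_counter() - start) / n


x0 = np.array([1.0, 0.0])
k = np.array([0.1, 0.05])
t_jax = time_per_step(euler_step_jax, x0, k, 0.1, block=True)
t_np = time_per_step(euler_step_np, x0, k, 0.1)
print(f"JAX Time elapsed per ODE step: {t_jax}")
print(f"Numpy Time elapsed per ODE step: {t_np}")
print(f"Speedup factor: {t_np / t_jax:.2f}")
```

For a model this small, dispatch overhead can dominate, so the toy ratio need not resemble the numbers reported for the real simulator.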

On CPU I get:

JAX Time elapsed per ODE step: 1.2874603271484375e-05
Numpy Time elapsed per ODE step: 0.025090458393096923
Speedup factor: 1948.83
JAX Time elapsed per vmapped ODE step (100 patients): 0.0011016869544982911
Vmap speedup over sequential looping:  1.17

On a GTX 1080 Ti GPU, I get:

JAX Time elapsed per ODE step: 4.86445426940918e-05
Numpy Time elapsed per ODE step: 0.04463413953781128
Speedup factor: 917.56
JAX Time elapsed per vmapped ODE step (100 patients): 0.0013061881065368653
Vmap speedup over sequential looping:  3.72
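The two ratios in these logs can be reproduced from the raw timings. The formulas below are inferred from the numbers rather than taken from speed_test.py: "Speedup factor" matches NumPy time divided by JAX time, and "Vmap speedup" matches the time of 100 sequential single-patient JAX steps divided by one vmapped 100-patient step.

```python
# Raw per-step timings in seconds, copied from the logs above.
timings = {
    "cpu": {"jax": 1.2874603271484375e-05, "numpy": 0.025090458393096923,
            "vmap_100": 0.0011016869544982911},
    "gpu": {"jax": 4.86445426940918e-05, "numpy": 0.04463413953781128,
            "vmap_100": 0.0013061881065368653},
}


def ratios(t, n_patients=100):
    # NumPy step time divided by JAX step time.
    speedup = t["numpy"] / t["jax"]
    # n sequential single-patient JAX steps vs one vmapped n-patient step.
    vmap_speedup = n_patients * t["jax"] / t["vmap_100"]
    return speedup, vmap_speedup


for device, t in timings.items():
    speedup, vmap_speedup = ratios(t)
    print(f"{device}: speedup {speedup:.2f}, vmap speedup {vmap_speedup:.2f}")
# cpu: speedup 1948.83, vmap speedup 1.17
# gpu: speedup 917.56, vmap speedup 3.72
```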

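Thus, as expected, the GPU benefits more from vmapping over multiple patients: batching 100 patients is 3.72x faster than looping over them on the GPU, versus 1.17x on the CPU. In absolute terms, though, the 100-patient vmapped step is about as fast on this CPU (1.10 ms) as on the 1080 Ti (1.31 ms).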
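The vmapped timings batch one ODE step over 100 patients. As a minimal sketch of that pattern, assuming a hypothetical toy two-state model with per-patient rate constants `ks` (not simglucose's actual dynamics or patient parameters), jax.vmap turns a single-patient step into a batched one and jax.jit compiles the result:

```python
import jax
import jax.numpy as jnp


# Hypothetical per-patient Euler step for a toy two-state model; the state
# layout and rate constants `k` are illustrative only.
def step(x, k, dt):
    dx = jnp.stack([-k[0] * x[0], k[0] * x[0] - k[1] * x[1]])
    return x + dt * dx


# Batch over the leading axis of x and k; dt is shared (in_axes=None).
batched_step = jax.jit(jax.vmap(step, in_axes=(0, 0, None)))

n_patients = 100
key_x, key_k = jax.random.split(jax.random.PRNGKey(0))
xs = jax.random.uniform(key_x, (n_patients, 2))
ks = jax.random.uniform(key_k, (n_patients, 2), minval=0.01, maxval=0.1)

# One call advances all 100 patients by one step.
xs_next = batched_step(xs, ks, 0.1)

# Equivalent sequential loop over patients, for comparison.
xs_loop = jnp.stack([step(xs[i], ks[i], 0.1) for i in range(n_patients)])
```

A single batched call gives the device all patients' work at once instead of 100 small dispatches, which is consistent with the larger vmap gain reported for the GPU.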
