Implement `jax.scipy.stats.bootstrap`
JeppeKlitgaard opened this issue
I think it would be super neat to have a JAX version of the `scipy.stats.bootstrap` function: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.bootstrap.html
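For context, `scipy.stats.bootstrap` supports several interval methods, the simplest being the percentile interval. A minimal JAX sketch of that idea for a 1-D sample might look like this. `percentile_bootstrap_ci` is a hypothetical name, not an existing JAX function, and the sketch assumes `statistic` maps a 1-D array to a scalar.

```python
import jax
import jax.numpy as jnp


def percentile_bootstrap_ci(key, data, statistic, n_resamples=1000, confidence_level=0.95):
    """Hypothetical helper: percentile bootstrap CI for a 1-D sample."""
    n = data.shape[0]
    # Resample with replacement: each row holds n indices drawn uniformly from [0, n).
    idx = jax.random.randint(key, (n_resamples, n), 0, n)
    resamples = data[idx]  # shape (n_resamples, n)
    # Evaluate the statistic on every resample at once.
    theta_b = jax.vmap(statistic)(resamples)  # shape (n_resamples,)
    alpha = (1 - confidence_level) / 2
    q = jnp.array([100 * alpha, 100 * (1 - alpha)])
    low, high = jnp.percentile(theta_b, q)
    return low, high, theta_b


data = 3.0 + jax.random.normal(jax.random.PRNGKey(1), (100,))
low, high, theta_b = percentile_bootstrap_ci(jax.random.PRNGKey(0), data, jnp.mean)
```

Because `n_resamples` determines an array shape, it would need to be a static argument if this function were wrapped in `jax.jit`.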
Hi @JeppeKlitgaard
I would love to work on this ticket but I am a beginner contributor. Would you mind giving me some guidance to get started?
After looking through the related source code in SciPy, I plan to break the ticket down into 2 steps. Let me know if this makes sense:
1. Implement the helper functions `_bootstrap_resample`, `_bca_interval` and `_percentile_along_axis` in JAX (`scipy.stats.bootstrap` uses them internally).
2. Together with (1), translate the main logic of `scipy.stats.bootstrap` into JAX. I'm not sure yet whether it can be translated straightforwardly, given the constraints of JAX's static compilation.
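One place where a JAX port may differ from NumPy is `_percentile_along_axis`: with BCa, each slice of the bootstrap distribution can get its own percentile level, while `jnp.percentile` with an array `q` evaluates every level on every slice. One possible approach, sketched below, flattens the batch dimensions and uses `jax.vmap` to pair each slice with its own level. `percentile_along_axis` is a hypothetical name, and the sketch assumes the resample axis is last and `alpha` holds fractions in [0, 1].

```python
import jax
import jax.numpy as jnp


def percentile_along_axis(theta_hat_b, alpha):
    """Hypothetical helper: a different percentile for each slice.

    theta_hat_b: (..., n_resamples) bootstrap replicates.
    alpha: fractions in [0, 1], broadcastable to theta_hat_b.shape[:-1].
    """
    batch_shape = theta_hat_b.shape[:-1]
    flat_theta = theta_hat_b.reshape(-1, theta_hat_b.shape[-1])
    flat_alpha = jnp.broadcast_to(alpha, batch_shape).reshape(-1)
    # Pair each slice with its own level instead of taking an outer product.
    per_slice = jax.vmap(lambda t, a: jnp.percentile(t, 100 * a))
    return per_slice(flat_theta, flat_alpha).reshape(batch_shape)


theta = jnp.stack([jnp.arange(101.0), 2 * jnp.arange(101.0)])
# 10th percentile of row 0, median of row 1.
out = percentile_along_axis(theta, jnp.array([0.1, 0.5]))
```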
Additionally, I have a few questions:
- `scipy.stats.bootstrap` returns a `BootstrapResult` object. Should I follow this practice in our JAX implementation?
- According to the `scipy.stats.bootstrap` docs, the size of the return value depends on the shape of the input data. Can JAX work this way?
cc @jakevdp as well, in case you could kindly shed some light on how to get started.
Hi, that sounds like a good plan for implementation. Regarding `BootstrapResult`: I think it's just a simple dataclass or named tuple, so you could create a similar container for the results in a JAX implementation. I'd lean toward `NamedTuple` because you get pytree flattening for free.
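Following the `NamedTuple` suggestion, a result container might look like this. The field names loosely mirror SciPy's `BootstrapResult` attributes but are assumptions here, and `summarize` is a hypothetical function used only to show that the container passes through `jax.jit` and flattens into array leaves.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ConfidenceInterval(NamedTuple):
    low: jax.Array
    high: jax.Array


class BootstrapResult(NamedTuple):
    confidence_interval: ConfidenceInterval
    bootstrap_distribution: jax.Array
    standard_error: jax.Array


@jax.jit
def summarize(theta_b):
    # theta_b: (n_resamples,) bootstrap replicates of some statistic.
    low, high = jnp.percentile(theta_b, jnp.array([2.5, 97.5]))
    return BootstrapResult(
        confidence_interval=ConfidenceInterval(low, high),
        bootstrap_distribution=theta_b,
        standard_error=jnp.std(theta_b, ddof=1),
    )


res = summarize(jnp.arange(10.0))
print(type(res).__name__, len(jax.tree_util.tree_leaves(res)))  # BootstrapResult 4
```

Since NamedTuples are pytrees, no registration is needed for the result to be returned from `jit`, `vmap` and similar transformations.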
Regarding the return shape: it looks like it only depends on static properties of the inputs, so it should be fine for JAX.
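To illustrate the point about static shapes: in the hypothetical sketch below, the output shape is fixed by the input's shape and by `n_resamples`, which is marked static, so each new input shape or `n_resamples` value simply triggers a new trace. For simplicity it reuses the same resampling indices for every slice along the leading (batch) dimensions.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("n_resamples",))
def bootstrap_mean_distribution(key, data, n_resamples):
    """Hypothetical helper: bootstrap replicates of the mean along the last axis."""
    n = data.shape[-1]  # a static Python int at trace time
    # Same (n_resamples, n) index matrix for every batch slice (a simplification).
    idx = jax.random.randint(key, (n_resamples, n), 0, n)
    # data[..., idx] has shape (..., n_resamples, n); average over the last axis.
    return data[..., idx].mean(axis=-1)


x = jnp.ones((3, 50))
dist = bootstrap_mean_distribution(jax.random.PRNGKey(0), x, n_resamples=200)
print(dist.shape)  # (3, 200)
```

The result's shape depends only on the shape of `x`, never on its values, which is what makes this kind of return compatible with JAX's tracing.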
It would be super neat if @riven314 implemented this. Otherwise, I have a WIP that I will probably get around to finishing eventually.
@JeppeKlitgaard I am working on this. If you have some WIP code for this ticket, it would also be great if you could share it with me!
@riven314 I'm afraid I never got to a point that would be of much use to you. I dislike the way the original SciPy code is written to such an extent that I wouldn't want to mirror that implementation; I'd rather see whether it is possible to get the same API using a different approach behind the scenes.
Hi @jakevdp @JeppeKlitgaard,
I have made a draft in my fork, but I found that I couldn't open a PR against this repo, so I'm posting my draft here: main...riven314:impl-jax-bootstrap
Would you mind giving my draft an early review to see if it is heading in the right direction (e.g. approach, coding style)?
The remaining work for me is input validation and unit tests. I also left a few questions in the code.
Hi @jakevdp @JeppeKlitgaard,
I have now created a PR, though it's still WIP (pending input validation and unit tests).
Could you give it an early review to see if my direction looks good?
**Updated PR: #10871**
Any update on this?