google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/

Implement `jax.scipy.stats.bootstrap`

JeppeKlitgaard opened this issue

I think it would be super neat to have a JAX version of the scipy.stats.bootstrap function: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.bootstrap.html

Hi @JeppeKlitgaard,
I would love to work on this ticket, but I am a beginner contributor. Would you mind giving me some guidance to get started?

After looking through the related source code in scipy, I plan to break the ticket down into 2 steps. Let me know if you think it makes sense:

  1. implement the helper functions _bootstrap_resample, _bca_interval and _percentile_along_axis in JAX (scipy.stats.bootstrap uses them internally)
  2. together with (1), translate the main logic of scipy.stats.bootstrap into JAX (I'm not sure yet whether it can be translated straightforwardly, given the constraints of JAX's static compilation)
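A minimal sketch of steps (1) and (2), assuming 1-D data and the simple percentile method (scipy's default method is BCa, which this sketch does not implement). The names `bootstrap_resample` and `percentile_interval` are hypothetical, not scipy's private helpers or an existing JAX API. Unlike scipy, which draws from a NumPy random state, JAX needs an explicit PRNG key.

```python
import jax
import jax.numpy as jnp


def bootstrap_resample(key, data, n_resamples):
    """Hypothetical helper: resample 1-D `data` with replacement.

    Returns shape (n_resamples, n). `n_resamples` must be a Python int so the
    output shape is known at trace time.
    """
    n = data.shape[0]
    # One row of integer indices in [0, n) per resample.
    idx = jax.random.randint(key, (n_resamples, n), 0, n)
    return data[idx]


def percentile_interval(key, data, statistic, n_resamples=999, confidence_level=0.95):
    """Hypothetical helper: percentile bootstrap confidence interval."""
    resamples = bootstrap_resample(key, data, n_resamples)
    # Evaluate the statistic on every resample in one vectorized call.
    dist = jax.vmap(statistic)(resamples)
    alpha = (1.0 - confidence_level) / 2.0
    low = jnp.percentile(dist, 100 * alpha)
    high = jnp.percentile(dist, 100 * (1 - alpha))
    return low, high


key = jax.random.PRNGKey(0)
data = jnp.arange(1.0, 101.0)  # 1.0, 2.0, ..., 100.0 (sample mean 50.5)
low, high = percentile_interval(key, data, jnp.mean)
```

Using `jax.vmap` over the resample axis is one way to apply an arbitrary user `statistic`; a statistic that already accepts an `axis` argument could instead be called once with `axis=-1`.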

Additionally, I have a few questions:

  1. scipy.stats.bootstrap returns a BootstrapResult object. Should I follow this practice in our JAX implementation?
  2. According to the scipy.stats.bootstrap docs, the return value has a variable size depending on the shape of the input data. Can JAX work this way?

cc @jakevdp as well, in case you could kindly shed some light to get me started.

Hi, that sounds like a good plan for implementation. Regarding BootstrapResult: I think it's just a simple dataclass or named tuple, so you could create a similar container for the results in a JAX implementation. I'd lean toward NamedTuple because you get pytree flattening for free.
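A minimal sketch of such a container, assuming field names that mirror two of the attributes scipy documents for `BootstrapResult` (`confidence_interval` with `low`/`high`, and `standard_error`). It is illustrative, not an existing JAX API. Because `NamedTuple` subclasses are pytrees, the result can be flattened, mapped over, and returned from a jitted function without registering anything.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ConfidenceInterval(NamedTuple):
    low: jax.Array
    high: jax.Array


class BootstrapResult(NamedTuple):
    confidence_interval: ConfidenceInterval
    standard_error: jax.Array


result = BootstrapResult(
    confidence_interval=ConfidenceInterval(low=jnp.array(1.0), high=jnp.array(2.0)),
    standard_error=jnp.array(0.25),
)

# NamedTuples are pytrees: the nested result flattens to its three array leaves...
leaves = jax.tree_util.tree_leaves(result)
# ...and tree utilities can map over it like any other container.
doubled = jax.tree_util.tree_map(lambda x: 2 * x, result)


@jax.jit
def make_result(x):
    # A jitted function can return the container directly; JAX rebuilds it
    # from the flattened leaves on the way out.
    return BootstrapResult(ConfidenceInterval(x.min(), x.max()), x.std())


out = make_result(jnp.arange(3.0))
```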

Regarding the return shape: it looks like it only depends on static properties of the inputs, so it should be fine for JAX.
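To illustrate why static shapes are enough, here is a sketch with a hypothetical `bootstrap_distribution` function that hardcodes the mean as the statistic for brevity. Its output shape depends only on `data.shape` and the static arguments `n_resamples` and `axis`, so `jax.jit` can compile it; calling it with a different `n_resamples` simply triggers a recompile.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnames=("n_resamples", "axis"))
def bootstrap_distribution(key, data, n_resamples, axis=-1):
    """Hypothetical helper: bootstrap distribution of the mean along `axis`.

    Output shape is data.shape with `axis` removed, plus a trailing
    (n_resamples,) axis. It depends only on static information, never on
    the values in `data`.
    """
    # Move the sample axis to the end so the code below handles any `axis`.
    data = jnp.moveaxis(data, axis, -1)
    n = data.shape[-1]  # a concrete Python int even under jit
    idx = jax.random.randint(key, (n_resamples, n), 0, n)
    # Advanced indexing on the last axis: batch + (n_resamples, n).
    resampled = data[..., idx]
    return resampled.mean(axis=-1)  # batch + (n_resamples,)


key = jax.random.PRNGKey(0)
out = bootstrap_distribution(key, jnp.ones((3, 50)), n_resamples=1000)
print(out.shape)  # (3, 1000)
```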

It would be super neat if @riven314 implemented this. Otherwise I have a WIP that I will probably get around to finishing eventually.

@JeppeKlitgaard
I am working on this.
If you have some WIP code for this ticket, it would also be great if you could share it with me!

@riven314 I never got to a point that would be of much use to you, I'm afraid. I dislike the way the original scipy code is written to such an extent that I wouldn't want to mirror that implementation; I'd rather see if it is possible to get the same API using a different approach behind the scenes.

Hi @jakevdp @JeppeKlitgaard,
I have made a draft in my fork, but I found that I couldn't open a PR against this repo,
so I'm posting my draft here: main...riven314:impl-jax-bootstrap

Would you mind doing a preliminary review of my draft to see if it is heading in the right direction (e.g. approach, coding style)?
My pending work is input validation and unit tests. I also left a few questions in the code.

Hi @jakevdp @JeppeKlitgaard,
I have now created a PR, though it's still WIP (input validation and unit tests pending).
Could you give it a preliminary review to check whether my direction looks good?

Updated PR: #10871

Any update on this?