Implementing a custom lifted transformation: jax.experimental.ode.odeint
jatentaki opened this issue
Description of the model to be implemented
My request regards not a model but rather an example of how to lift new transformations. There are design docs on lifted transformations and they explain the rationale behind lifting, but fall short of explaining how to go about implementing one "in the wild" (or perhaps I fall short of understanding them). My specific scenario requires `jax.experimental.ode.odeint` in `Module.__call__`, and I have not managed to understand the implementation of lifted `vmap` and others and translate it to `odeint`. For starters, I don't fully understand which transformation makes this operation illegal: is it about the custom vjp? I see that there exists a lifted `custom_vjp`, but I don't know how it is related to the problem I am facing: should I re-wrap `_odeint` and redefine the derivatives myself?
I am suggesting the case of `odeint` as the basis for an expansion of the documentation, but I would be entirely satisfied to receive "one-off" help (which I can later try to turn into a tutorial to contribute back).
A minimal example
Here I created a colab with a minimal example of what I would like to get working.
A rationale for the example
This project implements a continuous normalizing flow (CNF) in flax. It defines the integrand as an `nn.Module` and then sidesteps the issue of not being able to integrate within a `Module` by integrating "manually" in `loss_fn` and `solve_dynamics` (thus effectively duplicating code). Myself, I am looking to use a CNF as a part of a bigger model (a conditional CNF, which "owns" both the CNF and the model for generating the conditional embeddings), and the approach of calling `odeint` manually each time I want to call my CCNF quickly becomes very burdensome.
Note
There are more issues with the CNF code than just this. For example, the author was apparently not aware of the `Module.apply(method=)` keyword, hence the complexity of the `CNF` and `Neg_CNF` classes. I am probably going to make a PR with cleanup in that regard, but I consider that issue completely orthogonal to my example request here.
@jatentaki I would be interested in helping to work on this. Looks useful/interesting!
@pharringtonp19 great, what kind of input are you expecting from me at this point?
@jatentaki You mentioned that you might make a tutorial on this. I think it would be helpful to have more examples/tutorials involving odes. So if you need help with this, I would be happy to lend a hand
I see, what I meant is that if I get help in an "unpolished" way (as in some code with comments, a collection of pointers, etc.) I can then transform it into something more presentable and doc-quality. That said, right now I'm completely at a loss as to how to approach this problem.
For starters, I would like to ask:
- Why is the exception being thrown in the first place? Is it because of `odeint` being wrapped in `custom_vjp`? `odeint` is not a transformation in the sense that `vmap` or `jit` are: it doesn't provide its own jax interpreter for inner expressions, if I understand correctly. I think it's a bit of a different beast than the examples available.
- At what level of abstraction should I work to solve this? Should I be reading `linen.transforms` or `core.lift`?
- Can you provide the solution in some kind of pseudocode? Which of the already implemented transformations would look most similar to what I am trying to achieve here?
Hey @jatentaki - so sorry for the delay:
There's a lot of outstanding work on 1) getting lifted versions of all the JAX AD machinery for Modules and/or 2) making some simpler tools that let end users "lift" custom functional recipes without having to dive into all the esoteric bookkeeping of our internal lifting machinery. Our apologies that this wasn't already finished; it's one of my top priorities.
For your particular use case, it seems that you're just trying to integrate a Module for a given, fixed set of parameters?
@levskaya thanks for your answer :) I'm not sure what you mean by "a fixed set" of parameters. It's fixed in terms of pytree shape, nothing dynamic, but their value will be subject to updating via gradient descent, if that's your question. To the best of my understanding, the minimal example I provided should generalize to whatever actual use cases I may have.
@jatentaki - sorry for the delay!
@jheek - you may be interested in this too.
I think I have a reasonable draft of a "lifted" version of odeint here in a fork of your colab notebook:
https://colab.research.google.com/drive/1WEg4xb2xbFU4KuiGknp52JztvIpWLDis?usp=sharing
I tried documenting some of the ideas to make things -slightly- less opaque and demonstrated how to use an "nn.odeint" on your toy example.
Let me know if this looks useful or if you have other questions, we could probably get a version of this checked-in soon if it looks like the right thing.
Sorry that it took so long to get to this, things have been unusually busy!
Hello again, thanks for labeling this as high priority and sorry for the delay - this is my side project and it was busy for me as well :( I just tested the code and it worked with my use case, although I had to do some adaptation. My remarks are as follows:
- In initialization mode it directly returns the output of (a single evaluation of) `fn`, whereas at runtime it prepends a leading dimension equal to `t.shape[0]`. This doesn't matter in our toy example since we don't do anything more with the output (integration is the final operation of this model), but in general this will break code downstream of the `nn_odeint`. A simple solution is to just expand the output like `return y[None].repeat(t.shape[0], axis=0), repack_fn(scope)`; I don't see any potential downside to this.
- I unfortunately underestimated the complexity of my actual use case and had to adapt the code to make it work. I am integrating a conditional (curried) function of signature `dx_dt(x, t, ctx)`, where I want the two leading arguments to vary across invocations inside `odeint` while `ctx` is frozen. In my original code I was just calling `odeint` on `deriv = lambda x, t: dx_dt(x, t, ctx)`, where `ctx` was captured by the closure. This pattern doesn't work with `nn_odeint`, since it assumes it is decorating either a class or a method in a class definition. Unfortunately, misusing it on a lambda was passing this test, resulting in flax trying to treat the closure as a method definition, leading to unhelpful stack traces. Only by trying various modifications of the code, hit-or-miss style, did I manage to get these checks to fail and see this crucial comment, which put me on the right path. I am not sure if it's possible to better distinguish those cases of misuse?
- The problem of providing the integrand with `ctx` persisted. I ended up adapting your code to simply unpack one extra value, `x, t, ctx = args`, and manually pass it around inside `core_odeint.inner`. This is, however, very use-case-specific. Is it possible to make `nn_odeint` more general and work correctly in situations like `y = nn_odeint(partial(self.submodule.some_method, ctx=ctx))(x, t_range)`?
- The fact that I was able to understand the issues and adapt your code to my use case within ~2h speaks a lot about its clarity, thanks a lot for your effort!
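Two of the remarks above can be illustrated with plain `jax.experimental.ode.odeint`, independent of the lifted `nn_odeint` from the notebook (which is not reproduced here). The sketch below is only an illustration: the body of `dx_dt` is a hypothetical stand-in for the conditional dynamics. It shows the closure pattern, the alternative of passing `ctx` through `odeint`'s extra positional arguments, and the shape mismatch between a single evaluation and the integrated output that the proposed `repeat` fix addresses.

```python
import jax.numpy as jnp
from jax.experimental.ode import odeint


def dx_dt(x, t, ctx):
    # Hypothetical conditional dynamics: relax x towards the context vector.
    return ctx - x


x0 = jnp.zeros((3,))
ctx = jnp.array([1.0, 2.0, 3.0])
t_range = jnp.linspace(0.0, 1.0, 4)

# Pattern 1: freeze ctx in a closure, as in the original non-lifted code.
ys_closure = odeint(lambda x, t: dx_dt(x, t, ctx), x0, t_range)

# Pattern 2: pass ctx as an extra positional argument; odeint forwards
# it unchanged to the integrand on every evaluation.
ys_args = odeint(dx_dt, x0, t_range, ctx)

# The integrated output carries a leading time axis of length t_range.shape[0],
# whereas a single evaluation of the integrand does not.
y_single = dx_dt(x0, t_range[0], ctx)

# The fix suggested above: broadcast the single evaluation along a new leading
# axis so that initialization-time and run-time shapes agree.
y_init = jnp.repeat(y_single[None], t_range.shape[0], axis=0)
```

Passing `ctx` as an argument rather than closing over it keeps the integrand a plain function of its inputs, which is roughly what the manual `x, t, ctx = args` unpacking achieves inside the lifted version.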
@levskaya Is there a simpler way to build an "ode-layer" in flax?