google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io


Implementing a custom lifted transformation: jax.experimental.ode.odeint

jatentaki opened this issue

Description of the model to be implemented

My request concerns not a model but rather an example of how to lift new transformations. There are design docs on lifted transformations and they explain the rationale behind lifting, but fall short of explaining how to go about implementing one "in the wild" (or perhaps I fall short of understanding them). My specific scenario requires jax.experimental.ode.odeint in Module.__call__, and I have not managed to understand and translate the implementation of lifted vmap and others to odeint. For starters, I don't fully understand which transformation makes this operation illegal: is it about the custom vjp? I see that there exists a lifted custom_vjp, but I don't know how it relates to the problem I am facing: should I re-wrap _odeint and redefine the derivatives myself?

I am suggesting that the odeint case be the basis for an expansion of the documentation, but I would be entirely satisfied with "one-off" help (which I can later try to turn into a tutorial and contribute back).

A minimal example

Here I created a colab with a minimal example of what I would like to get working.
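
For readers without access to the notebook, here is a rough sketch of the kind of code I would like to get working (the actual colab may differ; module names and layer sizes are made up for illustration):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from jax.experimental.ode import odeint


class Dynamics(nn.Module):
    """dx/dt as a tiny MLP."""

    @nn.compact
    def __call__(self, x, t):
        inp = jnp.concatenate([x, jnp.atleast_1d(t)], axis=-1)
        return nn.Dense(x.shape[-1])(jnp.tanh(nn.Dense(32)(inp)))


class IntegratedModel(nn.Module):
    """Tries to integrate the Dynamics sub-Module inside __call__."""

    @nn.compact
    def __call__(self, x0, t_range):
        dynamics = Dynamics()
        # This is the call that raises the exception discussed in this issue:
        # odeint is not a flax-lifted transform, so the sub-Module call inside
        # it is not visible to flax's variable/scope machinery.
        return odeint(lambda x, t: dynamics(x, t), x0, t_range)


x0 = jnp.zeros((4,))
t_range = jnp.linspace(0.0, 1.0, 5)
variables = IntegratedModel().init(jax.random.PRNGKey(0), x0, t_range)  # fails here
```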

A rationale for the example

This project implements a continuous normalizing flow (CNF) in flax. It defines the integrand as an nn.Module and then sidesteps the inability to integrate within a Module by integrating "manually" in loss_fn and solve_dynamics (thus effectively duplicating code). I am looking to use a CNF as part of a bigger model (a conditional CNF, which "owns" both the CNF and the model that generates the conditional embeddings), and the approach of calling odeint manually each time I want to call my conditional CNF quickly becomes very burdensome.

Note

There are more issues with the CNF code than just this; for example, the author was apparently not aware of the Module.apply(method=) keyword, hence the complexity of the CNF and Neg_CNF classes. I will probably make a PR with cleanup in that regard, but I consider that issue completely orthogonal to my example request here.
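
For concreteness, here is a minimal sketch of the apply(method=...) pattern I mean, so a single module can expose both the forward and the negated dynamics without a separate Neg_CNF class (layer sizes and method names below are purely illustrative):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class CNF(nn.Module):
    """Toy integrand exposing forward and negated dynamics."""
    hidden: int = 32
    features: int = 2

    def setup(self):
        self.h = nn.Dense(self.hidden)
        self.out = nn.Dense(self.features)

    def forward(self, x, t):
        inp = jnp.concatenate([x, jnp.atleast_1d(t)], axis=-1)
        return self.out(jnp.tanh(self.h(inp)))

    def negative(self, x, t):
        # What a separate Neg_CNF class would do, reusing the same parameters.
        return -self.forward(x, t)


cnf = CNF()
x, t = jnp.ones((2,)), 0.5
variables = cnf.init(jax.random.PRNGKey(0), x, t, method=CNF.forward)
dxdt = cnf.apply(variables, x, t, method=CNF.forward)
neg_dxdt = cnf.apply(variables, x, t, method=CNF.negative)
```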

@jatentaki I would be interested in helping to work on this. Looks useful/interesting!

@pharringtonp19 great, what kind of input are you expecting from me at this point?

@jatentaki You mentioned that you might make a tutorial on this. I think it would be helpful to have more examples/tutorials involving ODEs, so if you need help with this, I would be happy to lend a hand.

I see; what I meant is that if I get help in an "unpolished" form (some code with comments, a collection of pointers, etc.), I can then transform it into something more presentable and doc-quality. That said, right now I'm completely at a loss as to how to approach this problem.

For starters, I would like to ask:

  1. Why is the exception being thrown in the first place? Is it because odeint is wrapped in custom_vjp? odeint is not a transformation in the sense that vmap or jit are: it doesn't provide its own JAX interpreter for inner expressions, if I understand correctly. I think it's a bit of a different beast than the examples available.
  2. At what level of abstraction should I work to solve this? Should I be reading linen.transforms or core.lift?
  3. Can you provide the solution in some kind of pseudocode? Which of the already implemented transformations would look most similar to what I am trying to achieve here? (For reference, a sketch of how an existing lifted transform is used from the caller's side follows this list.)
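
Regarding question 3, this is roughly how one of the already-lifted transforms, nn.vmap, looks from the caller's side; the module and axis choices below are illustrative only:

```python
import flax.linen as nn


class Parent(nn.Module):
    @nn.compact
    def __call__(self, x_batch):
        # Lifted vmap: map a Dense layer over the leading axis of x_batch,
        # telling flax how to treat the 'params' collection and its RNGs.
        BatchDense = nn.vmap(
            nn.Dense,
            in_axes=0, out_axes=0,
            variable_axes={'params': None},  # share one set of parameters
            split_rngs={'params': False},    # one init RNG for all rows
        )
        return BatchDense(features=3)(x_batch)
```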

Hey @jatentaki - so sorry for the delay:

There's a lot of outstanding work on 1) getting lifted versions of all the JAX AD machinery for Modules and/or 2) making some simpler tools that allow end-users to "lift" custom functional recipes without having to dive into all the esoteric bookkeeping of our internal lifting machinery. So our apologies that this wasn't already a finished thing; it's one of my top priorities.

For your particular use case, it seems that you're just trying to integrate a Module for a given, fixed set of parameters?

@levskaya thanks for your answer :) I'm not sure what you mean by "a fixed set" of parameters. It's fixed in terms of pytree shape, nothing dynamic, but the values will be updated via gradient descent, if that's your question. To the best of my understanding, the minimal example I provided should generalize to whatever actual use cases I may have.

@jatentaki - sorry for the delay!
@jheek - you may be interested in this too.

I think I have a reasonable draft of a "lifted" version of odeint here in a fork of your colab notebook:
https://colab.research.google.com/drive/1WEg4xb2xbFU4KuiGknp52JztvIpWLDis?usp=sharing

I tried documenting some of the ideas to make things -slightly- less opaque and demonstrated how to use an "nn.odeint" on your toy example.

Let me know if this looks useful or if you have other questions; we could probably get a version of this checked in soon if it looks like the right thing.

Sorry that it took so long to get to this, things have been unusually busy!

Hello again, thanks for labeling this as high priority, and sorry for the delay - this is my side project and things have been busy for me as well :( I just tested the code and it worked with my use case, although I had to do some adaptation. My remarks are as follows:

  1. In initialization mode it directly returns the output of (a single evaluation of) fn, whereas at runtime it prepends a leading dimension of size t.shape[0]. This doesn't matter in our toy example since we don't do anything more with the output (integration is the final operation of this model), but in general this will break code downstream of the nn_odeint. A simple solution is to just expand the output, e.g. return y[None].repeat(t.shape[0], axis=0), repack_fn(scope); I don't see any potential downside to this.
  2. I have unfortunately underestimated the complexity of my actual use case and had to adapt the code to make it work. I am integrating a conditional (curried) function with signature dx_dt(x, t, ctx), where I want the two leading arguments to vary across invocations inside odeint while ctx is frozen. In my original code I was just calling odeint on deriv = lambda x, t: dx_dt(x, t, ctx), where ctx was captured by the closure (a plain-JAX sketch of this pattern follows this list). This pattern doesn't work with nn_odeint, since it assumes it is decorating either a class or a method in a class definition. Unfortunately, misusing it on a lambda was passing this test, resulting in flax trying to treat the closure as a method definition and leading to unhelpful stack traces. Only by trying various variations on the code, hit-or-miss style, did I manage to get these checks to fail and see this crucial comment, which put me on the right path. I am not sure whether it's possible to better distinguish these cases of misuse?
  3. The problem of providing the integrand with ctx persisted. I ended up adapting your code to simply unpack one extra value, x, t, ctx = args, and manually pass it around inside core_odeint.inner. This is, however, very use-case-specific. Is it possible to make nn_odeint more general so that it works correctly in situations like y = nn_odeint(partial(self.submodule.some_method, ctx=ctx))(x, t_range)?
  4. The fact that I was able to understand the issues and adapt your code to my usecase within ~2h speaks a lot about its clarity, thanks a lot for your effort!
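
To make points 2 and 3 above concrete, here is the plain-JAX pattern I was using before, where ctx stays frozen while x and t vary inside the solver; the toy dynamics below are made up:

```python
import jax.numpy as jnp
from jax.experimental.ode import odeint


def dx_dt(x, t, ctx):
    # Toy conditional dynamics; in my real code this is a Module method.
    return -x * (1.0 + jnp.tanh(t)) + ctx


ctx = jnp.ones((4,))
x0 = jnp.zeros((4,))
t_range = jnp.linspace(0.0, 1.0, 5)

# Two equivalent ways to freeze ctx with plain odeint: capture it in a
# closure, or pass it through odeint's extra *args.
ys_closure = odeint(lambda x, t: dx_dt(x, t, ctx), x0, t_range)
ys_args = odeint(dx_dt, x0, t_range, ctx)
```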

@levskaya Is there a simpler way to build an "ode-layer" in flax?