google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/

jax.lax.custom_linear_solve / jax.lax.custom_root doesn't respect jax.disable_jit

YouJiacheng opened this issue · comments

After #9938
The document says:

For debugging it is useful to have a mechanism that disables jit() everywhere in a dynamic context. Note that this not only disables explicit uses of jit by the user, but will also remove any implicit JIT compilation used by the JAX library: this includes implicit JIT computation of body and cond functions passed to higher-level primitives like scan() and while_loop(), JIT used in implementations of jax.numpy functions, and any other case where jit is used within an API’s implementation.
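For context, a minimal sketch of the documented behaviour (the function here is illustrative, not from this thread): inside jax.disable_jit(), a jitted function runs eagerly, so Python-side debugging tools like print see concrete arrays rather than tracers.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Under normal jit tracing this prints an abstract tracer;
    # under jax.disable_jit() it prints a concrete array.
    print("x inside f:", x)
    return x * 2.0

with jax.disable_jit():
    y = f(jnp.array(3.0))  # runs eagerly, no compilation

print(y)
```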

However, when I worked on #9714 (comment), I found that jaxprs are generated inside with jax.disable_jit():.
I checked the source of jax.lax.custom_linear_solve and found that it doesn't check config.jax_disable_jit.
jax.lax.custom_root has the same problem.
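A minimal repro sketch of the report, assuming the standard lax.custom_linear_solve signature (the 2x2 system and callables are illustrative):

```python
import jax
import jax.numpy as jnp
from jax import lax

A = jnp.array([[2.0, 0.0], [0.0, 4.0]])
b = jnp.array([1.0, 1.0])

def matvec(x):
    return A @ x

def solve(matvec_fn, rhs):
    # If disable_jit were respected here, `rhs` would be a concrete
    # array; as reported above, it is still a tracer because the
    # callables are traced to jaxprs regardless.
    print("type of rhs inside solve:", type(rhs))
    return jnp.linalg.solve(A, rhs)

with jax.disable_jit():
    x = lax.custom_linear_solve(matvec, b, solve, symmetric=True)
```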

Thanks for raising this. Maybe the docs should say it disables jit in many (but not all) places. Then we can treat it as an enhancement to support this feature in more cases.

@shoyer does it make sense to make custom_linear_solve / custom_root sensitive to disable_jit?

@YouJiacheng jaxprs may still be generated even with disable_jit supported everywhere because we use them both for jit-style staging of computations as well as reverse-mode AD.
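To illustrate that point with a generic example (not tied to custom_linear_solve): jaxprs are the IR that tracing produces, so they show up in tracing-based transforms like reverse-mode AD even when jit never runs.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x

# Reverse-mode AD also goes through tracing, so a jaxpr is built
# here with no jit involved.
print(jax.make_jaxpr(jax.grad(f))(2.0))
```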

Thanks for the reply.
I think we could at least print an INFO or WARN message wherever jit isn't disabled. Since jax.disable_jit is intended for debugging, adding some output seems reasonable.

custom_linear_solve/custom_root trace to a jaxpr using the equivalent of jax.closure_convert. It isn't really tractable to evaluate jaxprs without JIT, because loops become extremely slow (far slower than pure Python).
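For readers unfamiliar with the mechanism: roughly, closure conversion traces a function and hoists the arrays it closes over into explicit arguments. A sketch using the public jax.closure_convert helper (the example function is made up):

```python
import jax
import jax.numpy as jnp

a = jnp.array(3.0)

def f(x):
    return a * x  # `a` is captured from the enclosing scope

# closure_convert traces f at the example argument and returns a
# function that takes the captured values explicitly, plus the
# hoisted values themselves.
converted, consts = jax.closure_convert(f, jnp.array(1.0))

print(converted(jnp.array(2.0), *consts))
```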

It would be nice to rewrite these functions to avoid doing closure conversion automatically, which would make them more flexible and make supporting disable_jit feasible. This might require tweaking the APIs slightly -- I haven't really thought about it yet.

@shoyer thanks for the input. Interesting about not doing closure conversion automatically; we could make them "final-style" primitives but that's pretty tricky.

@YouJiacheng you're right that we should at least log something rather than having silent surprising behavior.

In case it's a helpful point of reference: Diffrax uses its own custom-JVP-through-root-finding code for nonlinear solves, as I didn't want the magic (and, previously, the bugs) of closure_convert-style behaviour. So it has an API that may serve as inspiration.

See here for the explicit commentary, and more broadly the rest of the nonlinear_solver folder.

@patrick-kidger interesting, thanks for sharing! One important difference between your solver and lax.custom_root is that you calculate a full Jacobian and use jnp.linalg.solve, where lax.custom_root allows for a custom (possibly matrix-free) linear solver.
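To make that distinction concrete, here is a hedged sketch of a matrix-free solve with lax.custom_linear_solve, using conjugate gradients so the operator is only ever touched through matvec (the 2x2 SPD system is illustrative):

```python
import jax.numpy as jnp
from jax import lax
from jax.scipy.sparse.linalg import cg

A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])

def matvec(x):
    return A @ x

def solve(matvec_fn, rhs):
    # CG is matrix-free: it only calls matvec_fn, never materializing
    # a dense matrix (here A happens to be dense for the demo).
    x, _ = cg(matvec_fn, rhs)
    return x

x = lax.custom_linear_solve(matvec, b, solve, symmetric=True)
```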

Absolutely. I'm not claiming that it's a perfectly equivalent API, just that it might make a useful point of reference as another take on this, in particular without closure_convert.
(FWIW I'll probably add support for other linear solvers at some point down the line, if/when that becomes necessary.)

Just wanted to chime in that I'm happy to take a look at these, I've had some fun with adding auxiliary arguments to custom_{root,linear_solve} in the past and would be happy to continue.

@nicholasjng Contributions here would be very welcome! I'm not 100% sure it will be possible to remove closure conversion from custom_linear_solve, because we currently need to implement it with a JAX primitive (until #9129 implementing custom_transpose is finished, CC @froystig). It might be worth holding off on any API changes until we can switch both custom_linear_solve and custom_root at once, but in any case starting to investigate how removing closure conversion would work would certainly be valuable.

I am trying to use PyAMG for solving my PDEs. I thought I could wrap the solver inside jax.lax.custom_linear_solve to ensure that I can also perform the adjoint solve.

However, it seems that JAX throws a TracerArrayConversionError when it tries to differentiate through the solver, even though the documentation says the solver need not be differentiable.

I am not sure if I am doing something wrong. I was hoping someone could kindly point me in the right direction. Thanks!

A standalone simple implementation can be found here

You will need to use an external callback to wrap libraries that operate on NumPy arrays, like PyAMG, with custom_linear_solve.
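A hedged sketch of that suggestion using jax.pure_callback, with np.linalg.solve standing in for PyAMG (swap in a pyamg call inside numpy_solve; the system and shapes here are illustrative):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax

A = jnp.array([[3.0, 0.0], [0.0, 5.0]])
b = jnp.array([6.0, 10.0])

def matvec(x):
    return A @ x

def external_solve(matvec_fn, rhs):
    def numpy_solve(rhs_np):
        # NumPy-land solver; a PyAMG solve would go here instead.
        # Cast back so the result dtype matches the declared shape/dtype.
        return np.linalg.solve(np.asarray(A), rhs_np).astype(rhs_np.dtype)

    result_shape = jax.ShapeDtypeStruct(rhs.shape, rhs.dtype)
    # pure_callback lets traced code call out to ordinary NumPy code,
    # while JAX derives the adjoint solve from matvec.
    return jax.pure_callback(numpy_solve, result_shape, rhs)

x = lax.custom_linear_solve(matvec, b, external_solve, symmetric=True)
```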

Thank you so much @shoyer ! This worked

For folks from the future who might be running into a similar issue, here is the implementation.

I may still be missing something when it comes to implementing this efficiently, but I have a working implementation!

Thank you once again!