google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

Need to add a "lifted while_loop"

avital opened this issue

Originally requested by @sharadmv: "Is there a flax.while_loop like there is scan?"

@levskaya responded with "not yet, there should be. or something more general that doesn't assume what you're trying to do..."

In particular, this would be useful/necessary if people use Module.sow for logging/printing from deep within a jitted function.

commented

sow in particular could be problematic, because we cannot statically allocate an array for the outputs of a dynamic loop (e.g. a while loop).
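To make the constraint concrete, here is a minimal sketch in plain JAX (not Flax, and not what any lifted while_loop does internally) of one possible workaround: if you are willing to assume an upper bound on the number of iterations, the outputs can go into a fixed-size buffer carried through `jax.lax.while_loop`. The names `MAX_STEPS` and `collect_halvings` are hypothetical and exist only for this illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Assumed upper bound on the number of iterations (hypothetical, illustration only).
MAX_STEPS = 8


def collect_halvings(x):
    """Halve x until it is <= 1, recording each value in a fixed-size buffer."""
    # The buffer shape must be known at trace time, hence the assumed bound.
    buf = jnp.zeros((MAX_STEPS,), dtype=x.dtype)

    def cond(state):
        i, val, _ = state
        # Stop when the value is small enough or the buffer is full.
        return jnp.logical_and(val > 1.0, i < MAX_STEPS)

    def body(state):
        i, val, buf = state
        val = val / 2.0
        buf = buf.at[i].set(val)  # write into slot i; the shape never changes
        return i + 1, val, buf

    i, _, buf = lax.while_loop(cond, body, (jnp.asarray(0), x, buf))
    # Only the first i entries of buf are meaningful; the rest stay zero.
    return buf, i


buf, count = jax.jit(collect_halvings)(jnp.asarray(10.0))
```

The caller has to slice `buf[:count]` outside of the jitted function, since the number of valid entries is only known at run time.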

@jheek -- is the problem that you'd get a different output shape and thus a different compiled XLA program? Or is it something else? Could you elaborate?

commented

There's just no way in JAX to create a dynamically sized array at the moment. Effectively you want to turn the following into XLA:

array = []
while some_condition(...):
  array.append(...)

With scan we know the number of iterations, so we can preallocate an array filled with zeros and update one slice of it in each iteration, which is what scan does under the hood.
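The preallocation idea can be sketched by hand in plain JAX. This is only an illustration of the principle described above, not the actual implementation of `jax.lax.scan`; the `step` and `manual_scan` functions are hypothetical names made up for this example.

```python
import jax
import jax.numpy as jnp
from jax import lax


def step(carry, x):
    # Hypothetical per-iteration computation: a running sum.
    new_carry = carry + x
    return new_carry, new_carry


def manual_scan(xs):
    n = xs.shape[0]  # static: known at trace time
    out = jnp.zeros((n,), dtype=xs.dtype)  # preallocated output buffer

    def body(i, state):
        carry, out = state
        carry, y = step(carry, xs[i])
        out = out.at[i].set(y)  # update one slice per iteration
        return carry, out

    init = (jnp.zeros((), dtype=xs.dtype), out)
    return lax.fori_loop(0, n, body, init)


xs = jnp.arange(1.0, 6.0)
manual_carry, manual_out = jax.jit(manual_scan)(xs)

# The built-in scan should produce the same stacked outputs.
scan_carry, scan_out = lax.scan(step, jnp.zeros((), dtype=xs.dtype), xs)
```

The key point is that `n` comes from a static shape, so the output buffer has a fixed size in the compiled program. A while loop offers no such static length.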

commented

This was implemented in #1949