Need to add a "lifted while_loop"
avital opened this issue
Originally requested by @sharadmv: "Is there a `flax.while_loop` like there is `scan`?"
@levskaya responded with "not yet, there should be. or something more general that doesn't assume what you're trying to do..."
In particular, this would be useful (or even necessary) if people use `Module.sow` for logging/printing from deep within a jitted function.
`sow` in particular could be problematic, because we cannot statically allocate an array for the outputs of a dynamic loop (e.g. a `while` loop).
@jheek, is the problem that you'd get a different output shape and thus a different compiled XLA program? Or is it something else? Could you elaborate?
There's just no way in JAX to create a dynamically sized array at the moment. Effectively you want to turn the following into XLA:
```python
array = []
while some_condition(...):
    array.append(...)
```
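For illustration only, a common workaround in plain JAX (not a Flax API, and not something proposed in this thread) is to assume a known upper bound on the number of iterations and carry a fixed-size buffer plus a counter through `jax.lax.while_loop`, whose body must return a carry with exactly the same shapes and dtypes it received. The function `collect_while`, the halving loop, and the `max_len` bound are hypothetical:

```python
import jax
import jax.numpy as jnp
from jax import lax


def collect_while(x0, max_len):
    """Halve x0 until it drops below 1, logging each value into a fixed buffer.

    Hypothetical example: max_len is an assumed static upper bound on the
    number of iterations, needed because the carry shape must not change.
    """
    buf = jnp.zeros((max_len,), dtype=jnp.float32)

    def cond_fun(carry):
        i, x, _ = carry
        # Stop on the data-dependent condition OR when the buffer is full.
        return (x >= 1.0) & (i < max_len)

    def body_fun(carry):
        i, x, buf = carry
        buf = buf.at[i].set(x)  # write into a preallocated slot, shape unchanged
        return i + 1, x / 2.0, buf

    init = (jnp.int32(0), jnp.asarray(x0, dtype=jnp.float32), buf)
    n, _, buf = lax.while_loop(cond_fun, body_fun, init)
    return buf, n  # only buf[:n] holds logged values


# max_len determines the buffer shape, so it must be static under jit.
collect_while_jit = jax.jit(collect_while, static_argnums=1)
```

Only the first `n` entries are meaningful; the rest stay zero. The price is that `max_len` must be static and large enough, which is exactly the assumption a truly dynamic `while` loop avoids.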
With `scan` we know the number of iterations, so we can preallocate an array of zeros and update a slice of it in each iteration, which is what `scan` does under the hood.
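A simplified sketch of that preallocate-and-update idea, comparing `jax.lax.scan` with a hand-written `jax.lax.fori_loop` that fills a zero buffer via `jax.lax.dynamic_update_slice`. This illustrates the principle only and is not JAX's actual lowering of `scan`; the cumulative-sum task and the function names are hypothetical:

```python
import jax.numpy as jnp
from jax import lax


def cumsum_scan(xs):
    # lax.scan stacks per-step outputs; their count equals len(xs),
    # which is known at trace time.
    def step(total, x):
        total = total + x
        return total, total  # (new carry, per-step output)

    _, ys = lax.scan(step, jnp.zeros((), xs.dtype), xs)
    return ys


def cumsum_prealloc(xs):
    n = xs.shape[0]          # static length, so the output can be sized up front
    ys = jnp.zeros_like(xs)  # preallocated buffer of zeros

    def body(i, carry):
        total, ys = carry
        total = total + xs[i]
        # Overwrite the length-1 slice at position i; the buffer shape never changes.
        ys = lax.dynamic_update_slice(ys, total[None], (i,))
        return total, ys

    _, ys = lax.fori_loop(0, n, body, (jnp.zeros((), xs.dtype), ys))
    return ys
```

Both versions work because `xs.shape[0]` is available at trace time; a `while_loop` offers no such number to size `ys` with.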