Unexpected speedup from wrapping function call in trivial jax.lax.cond statement

TonyZhou729 opened this issue



We noticed a strange speed-up when a function is called through a trivial lax.cond statement rather than being called directly.

In the reproduction below, we apply JIT to a main() function that contains a lax.scan() loop. In each loop step, we either call the function directly or wrap the call in a lax.cond() whose predicate is that the loop index i (which runs from 0 to Ny-1, for Ny steps) is greater than -1. That predicate is always true. This seemingly unnecessary wrapper somehow produces a speed-up.
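The pattern described above can be reduced to a minimal sketch. Here a toy per-step function stands in for interp_A_from_B; `step_body`, `skip_body` and `N` are hypothetical names chosen for illustration. Both variants run the same computation at every step, and the only difference is the always-true lax.cond wrapper, so their results should agree.

```python
import jax
import jax.numpy as jnp
from jax import lax

N = 1000  # hypothetical number of scan steps

def step_body(x):
    # Hypothetical per-step work standing in for interp_A_from_B.
    return jnp.sin(x) + 1.0

def skip_body(x):
    # Hypothetical fill-in branch; never taken because i > -1 always holds.
    return jnp.zeros_like(x)

@jax.jit
def run_direct(x0):
    def body(carry, i):
        # Case 1: call the step function directly.
        return step_body(carry), None
    out, _ = lax.scan(body, x0, jnp.arange(N))
    return out

@jax.jit
def run_with_cond(x0):
    def body(carry, i):
        # Case 2: trivial predicate, always True, so step_body always runs.
        return lax.cond(i > -1, step_body, skip_body, carry), None
    out, _ = lax.scan(body, x0, jnp.arange(N))
    return out

x0 = jnp.zeros(8, dtype=jnp.float32)
print(jnp.allclose(run_direct(x0), run_with_cond(x0)))
```

This only shows that the two forms are semantically equivalent; it does not by itself reproduce the timing difference, which in the report depends on the heavier per-step work.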

```python
import time

import jax.numpy as jnp
from jax import jit, lax
from jax.scipy.ndimage import map_coordinates

Nx = 300
Ny = 100000
x_axis = jnp.linspace(5., 12.75, Nx)

@jit
def main():
    y_axis = jnp.linspace(0, 1, Ny)

    # Initial value of B is just an (Nx, Ny) array of zeros.
    B = jnp.zeros((Nx, Ny), dtype="float32")

    def loop_in_main(carry, i):
        B = carry
        y = y_axis[i]

        # Obtain an array A using interp_A_from_B(), picking one of the two cases below.
        # Case 1: We simply run interp_A_from_B() every step.
        A = interp_A_from_B((y, y_axis, B))

        # Case 2: We use a seemingly trivial lax.cond wrapper, which still always runs
        # interp_A_from_B since index i is always greater than -1.
        # For some reason we observe a speed-up over case 1.
        #A = lax.cond(i > -1, interp_A_from_B, false_func, (y, y_axis, B))

        # Update B array with values of A from this loop.
        B = set_B_to_A(i, B, A)

        return B, None

    # Use lax.scan to run the loop and update B Ny times.
    # Index i runs through jnp.arange(Ny) = (0, 1, 2, ..., Ny-1).
    B, _ = lax.scan(loop_in_main, B, jnp.arange(Ny))

    return B

def interp_A_from_B(params):
    # B is an (Nx, Ny) array.
    # A is an (Nx,) array.

    y, y_axis, B    = params
    # Precise values of y to interpolate at.
    y_prime         = y - jnp.log(x_axis[1:Nx] / x_axis[:Nx-1])
    # Convert to index positions within y_axis, for use with ndimage.map_coordinates.
    y_prime_indices = jnp.interp(y_prime, y_axis, jnp.arange(Ny))
    # Interpolated version of A from B via 2D map_coordinates (linear, order=1).
    interp          = map_coordinates(B, [jnp.arange(1, Nx), y_prime_indices], order=1)
    # Only use the interpolated result where y_prime is at least the smallest y in y_axis;
    # condition marks the entries that get a fill-in value instead.
    condition       = y_prime < y_axis[0]

    # Put the A array together, with fill-in values where we don't want the interpolated value.
    A               = condition * jnp.exp(-x_axis[:Nx-1]) \
                    + (1-condition) * interp
    A               = jnp.append(A, jnp.exp(-x_axis[-1]))

    return A

def set_B_to_A(i, B, A):
    # Update column i of B with the current value of A (functional update).
    B = B.at[:, i].set(A)
    return B

def false_func(params):
    # Trivial false branch: sets all entries of A to fill-in values if called.
    A = jnp.exp(-x_axis)
    return A

# Run main() a few times to see the speed.
for i in range(5):
    s = time.time()
    # block_until_ready() makes sure the timing covers the computation itself,
    # not just the (possibly asynchronous) dispatch.
    B = main().block_until_ready()
    print(time.time() - s)
```
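When comparing the two cases, it may help to separate the first (compiling) call from steady-state calls and to block on the result, since JAX dispatches work asynchronously. The following is a sketch of such a timing loop; `benchmark` and the toy `scan_sum` (a stand-in for main()) are hypothetical names, not part of the reproduction above.

```python
import time

import jax
import jax.numpy as jnp
from jax import lax

def benchmark(fn, *args, repeats=5):
    """Hypothetical helper: returns (first_call_seconds, list of steady-state seconds)."""
    # The first call includes tracing and XLA compilation.
    t0 = time.perf_counter()
    fn(*args).block_until_ready()
    first = time.perf_counter() - t0

    times = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        # Wait for the computation to finish so the interval measures execution,
        # not only dispatch.
        fn(*args).block_until_ready()
        times.append(time.perf_counter() - t0)
    return first, times

@jax.jit
def scan_sum(x):
    # Toy stand-in for main(): accumulate over 10_000 scan steps.
    def body(carry, i):
        return carry + jnp.sin(i.astype(carry.dtype)), None
    out, _ = lax.scan(body, x, jnp.arange(10_000))
    return out

first, steady = benchmark(scan_sum, jnp.zeros((), jnp.float32))
print(f"first call: {first:.4f}s, steady-state min: {min(steady):.4f}s")
```

Reporting the minimum (or median) of the steady-state runs for case 1 and case 2 makes the comparison less sensitive to one-off noise.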

Using case 1 in loop_in_main() and calling main() 5 times, we observe the following runtimes (in seconds):


Switching to case 2, we see:


In both cases, the first run takes longer because it includes JIT compilation. We checked that the speed-up scales with Ny, the number of steps in lax.scan. In our full code, which does more computation in each step, the speed-up is even more significant.
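One way to confirm that the first-run overhead is compilation, and to see how pure execution time scales with the number of scan steps, is JAX's ahead-of-time path (`jax.jit(f).lower(...).compile()`), which compiles without executing. The sketch below uses a hypothetical toy loop (`make_loop`, with arbitrarily chosen step counts) rather than the reproduction above.

```python
import time

import jax
import jax.numpy as jnp
from jax import lax

def make_loop(n_steps):
    # Hypothetical toy scan whose length is fixed at trace time (like Ny above).
    def loop(x):
        def body(carry, i):
            return carry * 0.999 + jnp.cos(i.astype(carry.dtype)), None
        out, _ = lax.scan(body, x, jnp.arange(n_steps))
        return out
    return loop

x = jnp.zeros((), jnp.float32)
timings = {}
for n_steps in (1_000, 10_000, 100_000):
    t0 = time.perf_counter()
    # Trace, lower to XLA and compile, without running the computation.
    compiled = jax.jit(make_loop(n_steps)).lower(x).compile()
    t_compile = time.perf_counter() - t0

    t0 = time.perf_counter()
    compiled(x).block_until_ready()
    t_run = time.perf_counter() - t0

    timings[n_steps] = (t_compile, t_run)
    print(f"n_steps={n_steps}: compile {t_compile:.4f}s, run {t_run:.4f}s")
```

Because lax.scan compiles its body once, compile time should stay roughly flat as the step count grows, while run time grows with it; applying the same split to case 1 and case 2 would show whether the difference lies in compilation or in execution.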

Thank you in advance for your help and comments!

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.16
jaxlib: 0.4.16
numpy:  1.24.3
python: 3.10.10 (main, Mar 21 2023, 18:45:11) [GCC 11.2.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
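The system info above appears to match the format printed by `jax.print_environment_info()`. A short sketch for regenerating it, for example after upgrading jax/jaxlib to retest:

```python
import jax

# Prints jax/jaxlib/numpy/python versions, devices and process count.
jax.print_environment_info()

# Or capture the same text as a string, e.g. to paste into a bug report.
info = jax.print_environment_info(return_string=True)
```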