google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home page: http://jax.readthedocs.io/

unjitted jnp.where evaluates both branches

GJBoth opened this issue.

The following function unexpectedly fails with a division-by-zero error, even though, according to the documentation, it should return 0:

import jax.numpy as jnp

def f(x):
    print(jnp.abs(x) < 1.)
    return jnp.where(jnp.abs(x) < 1., 1. / (x - 1.), 0.)

f(1.)  # prints False, then raises ZeroDivisionError

Interestingly enough, the jitted version does run correctly:

import jax
import jax.numpy as jnp

@jax.jit
def f_jit(x):
    return jnp.where(jnp.abs(x) < 1., 1. / (x - 1.), 0.)

f_jit(1.)  # returns 0

It seems that the unjitted jnp.where evaluates both branches. Is this expected behaviour? If so, should a warning be added to the docs?

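Thanks for the question! Yes, it's expected behavior that jnp.where evaluates both branches. Actually, the jitted function does too. The difference is that without jax.jit, x is a plain Python float, so 1. / (x - 1.) is Python's built-in floating-point division, which doesn't involve JAX at all and raises ZeroDivisionError. Under jax.jit, x is a traced JAX value, so the division follows IEEE floating-point semantics and produces inf, which jnp.where then discards.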
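The difference described in this reply can be checked directly. The sketch below compares Python's float division with JAX's division at x = 1.0, and uses a jitted function that returns the raw true-branch value alongside the jnp.where result, to show the branch is still computed (as inf) under jit. The name f_jit_debug is hypothetical, introduced only for illustration.

```python
import jax
import jax.numpy as jnp

x_py = 1.0  # plain Python float, as in the unjitted call f(1.)

# Python's built-in float division raises on division by zero.
try:
    python_result = 1. / (x_py - 1.)
except ZeroDivisionError:
    python_result = "ZeroDivisionError"

# JAX follows IEEE floating-point semantics: the same division yields inf.
jax_result = 1. / (jnp.array(x_py) - 1.)

@jax.jit
def f_jit_debug(x):
    # Hypothetical debug variant of f_jit: the true branch is computed even
    # under jit, so we return it next to the value jnp.where selects.
    branch = 1. / (x - 1.)
    return branch, jnp.where(jnp.abs(x) < 1., branch, 0.)

raw_branch, selected = f_jit_debug(1.)
print(python_result)  # ZeroDivisionError
print(jax_result)     # inf
```

So the branch that jnp.where does not select is still evaluated in both cases; only the kind of division (Python float versus JAX array) decides whether that evaluation raises.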

To involve JAX in the version without jit, pass a JAX array instead of a Python float:

import jax
import jax.numpy as jnp

def f(x):
    print(jnp.abs(x) < 1.)
    return jnp.where(jnp.abs(x) < 1., 1. / (x - 1.), 0.)
f(jnp.array(1.))  # notice `jnp.array(1.)` instead of `1.`
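A variation on the same idea, which is our own suggestion rather than something from the thread, is to convert the input with jnp.asarray at the top of the function, so callers can keep passing plain Python floats. The function name f_safe is hypothetical.

```python
import jax.numpy as jnp

def f_safe(x):
    # Hypothetical variant of f: convert the input up front so that even a
    # plain Python float goes through JAX division (inf, not an exception).
    x = jnp.asarray(x)
    return jnp.where(jnp.abs(x) < 1., 1. / (x - 1.), 0.)

print(f_safe(1.))  # untaken branch evaluates to inf and is discarded
print(f_safe(0.))  # inside the region: 1 / (0 - 1)
print(f_safe(jnp.array([0.5, 1., 2.])))  # works elementwise on arrays too
```

Both branches are still evaluated here; the conversion only guarantees that the evaluation uses JAX semantics.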

jax.numpy.where behaves just like numpy.where: it is an ordinary Python function, so all of its arguments, including both branch values, are evaluated before the function is called.
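To make the eager argument evaluation visible, the sketch below wraps each branch in a logging helper before passing it to jnp.where, then repeats the experiment with a plain-Python selector to show this is a property of function calls in general, not of JAX. The helpers branch and pick are hypothetical, introduced only for illustration.

```python
import jax.numpy as jnp

calls = []  # records which argument expressions were evaluated

def branch(value, label):
    # Hypothetical helper: log that this argument was computed, then return it.
    calls.append(label)
    return value

x = jnp.array(1.)
result = jnp.where(jnp.abs(x) < 1.,
                   branch(1. / (x - 1.), "true branch"),
                   branch(0., "false branch"))
where_calls = list(calls)
print(where_calls)  # ['true branch', 'false branch']

def pick(cond, a, b):
    # Hypothetical plain-Python "where": it only receives already-evaluated values.
    return a if cond else b

calls.clear()
pick(False, branch(1., "a"), branch(0., "b"))
pick_calls = list(calls)
print(pick_calls)  # ['a', 'b']
```

Python evaluates call arguments left to right before entering the function, so no ordinary function, jnp.where included, can skip computing a branch it ends up not selecting.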