google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/

jnp.argsort much slower than the numpy version

fbartolic opened this issue · comments

Here's a comparison of the JAX and numpy versions of argsort on a CPU:

import numpy as np
import jax.numpy as jnp
from jax import config, random

config.update('jax_platform_name', 'cpu')  # force JAX onto CPU for a fair comparison

key = random.PRNGKey(42)
key, subkey = random.split(key)

x_jnp = random.uniform(subkey, (100, 10000))
x_np = np.array(x_jnp)  # copy to a host-side NumPy array

%%timeit
np.argsort(x_np, axis=0)

%%timeit
jnp.argsort(x_jnp, axis=0).block_until_ready()

In this case jnp.argsort is ~5x slower than np.argsort. I'm seeing a >20x difference with more realistic arrays. Why is there such a large difference in performance between the two implementations?

You might find this FAQ entry helpful: Is JAX faster than NumPy?

Thanks! I read the FAQ, but I didn't expect that the difference in performance could get so large.

@jakevdp It seems to be a pure computational-efficiency problem with the sort primitive on CPU.
I find that the sort primitive's performance on GPU is satisfactory, and the primitive shares the same MLIR lowering on all platforms. Maybe XLA uses a parallelism-friendly sort algorithm that is inefficient on CPU.

import numpy as np
import jax.numpy as jnp
from jax import config, random
config.update('jax_platform_name', 'cpu')

key = random.PRNGKey(42)
key, subkey = random.split(key)

x_jnp = random.uniform(subkey, (1000000,))
x_np = np.array(x_jnp)

jnp.argsort(x_jnp, axis=0).block_until_ready() # compile
jnp.sort(x_jnp, axis=0).block_until_ready() # compile
from timeit import timeit
print(timeit('np.argsort(x_np, axis=0)', globals=globals(), number=10)) # 1.1s
print(timeit('jnp.argsort(x_jnp, axis=0).block_until_ready()', globals=globals(), number=10)) # 4.2s
print(timeit('jnp.sort(x_jnp, axis=0).block_until_ready()', globals=globals(), number=10)) # 3.7s

Yes, in general the XLA project has put much less effort into optimizing operations on CPU than on other backends.

I also note that the slowness is specific to floating-point values. Sorting int32 values is significantly faster. The only difference between the two as far as I can tell is the comparison function.
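A quick way to check that observation is to time argsort on float32 and int32 arrays of the same size (the array size and iteration count below are arbitrary choices, not the thread's original benchmark):

```python
import numpy as np
import jax.numpy as jnp
from jax import config, random
from timeit import timeit

config.update('jax_platform_name', 'cpu')  # run on CPU, where the slowdown appears

key = random.PRNGKey(0)
x_f32 = random.uniform(key, (1_000_000,))              # float32 values
x_i32 = random.randint(key, (1_000_000,), 0, 10**6)    # int32 values

# Warm up first so compilation time is excluded from the measurement.
jnp.argsort(x_f32).block_until_ready()
jnp.argsort(x_i32).block_until_ready()

t_f32 = timeit(lambda: jnp.argsort(x_f32).block_until_ready(), number=10)
t_i32 = timeit(lambda: jnp.argsort(x_i32).block_until_ready(), number=10)
print(f'float32: {t_f32:.3f}s  int32: {t_i32:.3f}s')
```

If the comparison function is indeed the culprit, the int32 timing should come out noticeably lower.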

Running into the same issue, I created a workaround where argsort runs under NumPy when only the CPU is available.

https://gist.github.com/sjdv1982/803695055c78b62e5d5dc92a004efa77

It seems to be compatible with jax.grad, but only after disabling a certain assertion in the JAX code.

I am a beginner in JAX, so criticism is welcome; use with care.

That's a nice solution! To make it as compatible as possible with JAX transformations, I'd suggest doing the call to NumPy via jax.pure_callback instead.
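For reference, a minimal sketch of that approach (the helper name `np_argsort` and the int32 cast are illustrative assumptions, not the gist's actual code):

```python
import numpy as np
import jax
import jax.numpy as jnp

def np_argsort(x, axis=-1):
    # Tell JAX the shape and dtype the callback will return, so it can be traced.
    result_shape = jax.ShapeDtypeStruct(x.shape, jnp.int32)
    # pure_callback runs host-side NumPy code even inside jit-compiled functions.
    return jax.pure_callback(
        lambda arr: np.argsort(arr, axis=axis).astype(np.int32),
        result_shape,
        x,
    )

x = jnp.array([3.0, 1.0, 2.0])
print(jax.jit(np_argsort)(x))  # [1 2 0]
```

Because pure_callback declares the output's shape and dtype up front, the wrapped function composes with jit, vmap, and grad without any changes to JAX internals.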

Thank you! I didn't know about pure_callback; I have updated the gist as you suggested. It now runs under unmodified JAX.

I am glad to see that when calling jax.value_and_grad there are three identical calls into the function, but JAX is smart enough to coalesce them into one.