google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home page: http://jax.readthedocs.io/

XLA: log-indexing should be optimised to indexing-log

ayaka14732 opened this issue

log-indexing:

np.log(x)[y]

should be optimised to indexing-log:

np.log(x[y])
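
Because log is elementwise, the two orders give the same values but do very different amounts of work. The sketch below checks this with plain NumPy rather than jax.numpy (an assumed substitute so it runs without an accelerator). It uses the issue's shapes, draws the random values arbitrarily, and counts how many log evaluations each order performs.

```python
import numpy as np

# Shapes from the issue; the random values themselves are arbitrary.
rng = np.random.default_rng(0)
x = rng.uniform(1e-3, 5.0, size=(16384, 10))  # strictly positive, so log is defined
y = rng.integers(0, 16384, size=16)           # 16 row indices into x

log_then_index = np.log(x)[y]  # log-indexing: log over every element, then gather
index_then_log = np.log(x[y])  # indexing-log: gather 16 rows, then log

# log is elementwise, so both orders compute the same result.
assert np.allclose(log_then_index, index_then_log)

# Number of log evaluations each order performs.
cost_log_first = x.size                  # 16384 * 10
cost_index_first = y.size * x.shape[1]   # 16 * 10
print(cost_log_first, cost_index_first)
```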

Test:

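import timeit

import jax
import jax.numpy as np
import jax.random as rand

key = rand.PRNGKey(42)
key_x, key_y = rand.split(key)  # independent keys for x and y instead of reusing one

# Samples truncated to [1e-5, 5] are strictly positive, so log is well defined.
x = rand.truncated_normal(key_x, 1e-5, 5., (16384, 10))  # (16384, 10)
y = rand.randint(key_y, (16,), 0, 10)  # (16,) row indices into x

@jax.jit
def f1(x, y):
    # log-indexing: log over all 16384 x 10 elements, then gather 16 rows
    return np.log(x)[y]

@jax.jit
def f2(x, y):
    # indexing-log: gather 16 rows, then log over only 16 x 10 elements
    return np.log(x[y])

# The first calls also trigger compilation, so the timings below exclude it.
a = f1(x, y)
b = f2(x, y)
assert np.array_equal(a, b)  # both orders produce identical results

# Seconds for 50000 calls; the trailing numbers are from the author's run.
print(timeit.timeit('f1(x, y).block_until_ready()', globals=globals(), number=50000))  # 14.493794055997569
print(timeit.timeit('f2(x, y).block_until_ready()', globals=globals(), number=50000))  # 9.56337349399837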
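
Converting the reported totals into per-call times makes the gap easier to read. This is plain arithmetic on the two numbers printed above; the hardware they were measured on is not stated in the issue.

```python
# Totals reported in the issue: seconds for 50000 calls of each function.
NUM_CALLS = 50000
total_f1_s = 14.493794055997569  # f1: log-indexing
total_f2_s = 9.56337349399837    # f2: indexing-log

per_call_f1_us = total_f1_s / NUM_CALLS * 1e6  # roughly 290 microseconds per call
per_call_f2_us = total_f2_s / NUM_CALLS * 1e6  # roughly 191 microseconds per call
speedup = total_f1_s / total_f2_s              # roughly 1.5x in favour of f2

print(f"f1: {per_call_f1_us:.1f} us/call, f2: {per_call_f2_us:.1f} us/call, speed-up {speedup:.2f}x")
```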

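Generated HLO (in f1 the log runs on the full f32[16384,10] input; in f2 it runs on the gathered f32[16,10] result):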

log-indexing:

ENTRY %xla_computation_f1.15 (parameter.1: f32[16384,10], parameter.2: s32[16]) -> (f32[16,10]) {
  %parameter.2 = s32[16] parameter(1)
  %parameter.1 = f32[16384,10] parameter(0)
  %fusion.1 = s32[16] fusion(s32[16] %parameter.2), calls=%fused_computation.1
  %fusion.2 = s32[1024] fusion(s32[16] %fusion.1), calls=%fused_computation.2
  %copy = f32[16384,10] copy(f32[16384,10] %parameter.1)
  %log.4 = f32[16384,10] log(f32[16384,10] %copy)
  %fusion = f32[16,10] fusion(f32[16384,10] %log.4, s32[1024] %fusion.2), calls=%fused_computation
  ROOT %tuple.14 = (f32[16,10]) tuple(f32[16,10] %fusion)
}

indexing-log:

ENTRY %xla_computation_f2.15 (parameter.1: f32[16384,10], parameter.2: s32[16]) -> (f32[16,10]) {
  %parameter.2 = s32[16] parameter(1)
  %parameter.1 = f32[16384,10] parameter(0)
  %fusion.1 = s32[16] fusion(s32[16] %parameter.2), calls=%fused_computation.1
  %fusion.2 = s32[1024] fusion(s32[16] %fusion.1), calls=%fused_computation.2
  %copy = f32[16384,10] copy(f32[16384,10] %parameter.1)
  %fusion = f32[16,10] fusion(f32[16384,10] %copy, s32[1024] %fusion.2), calls=%fused_computation
  %log.13 = f32[16,10] log(f32[16,10] %fusion)
  ROOT %tuple.14 = (f32[16,10]) tuple(f32[16,10] %log.13)
}
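
The difference between the two dumps can be quantified by reading the result shape of each `log` instruction. The helper below is a hypothetical, minimal parser for shapes written like `f32[16384,10]`; it is not part of any XLA tooling. It only counts log evaluations; the issue's measured speed-up (about 1.5x) is far smaller than this ratio, presumably because other work, such as the full-array `copy` that appears in both dumps, dominates the runtime.

```python
import math
import re

# The `log` instructions copied from the two HLO dumps above.
LOG_F1 = "%log.4 = f32[16384,10] log(f32[16384,10] %copy)"
LOG_F2 = "%log.13 = f32[16,10] log(f32[16,10] %fusion)"


def result_elements(hlo_line: str) -> int:
    """Element count of an instruction's result shape, e.g. f32[16,10] -> 160."""
    match = re.search(r"=\s*\w+\[([\d,]*)\]", hlo_line)
    if match is None:
        raise ValueError(f"no result shape in: {hlo_line!r}")
    dims = [int(d) for d in match.group(1).split(",") if d]
    return math.prod(dims)  # an empty dims list (a scalar) gives 1


f1_logs = result_elements(LOG_F1)  # log evaluations in f1
f2_logs = result_elements(LOG_F2)  # log evaluations in f2
print(f1_logs, f2_logs, f1_logs // f2_logs)
```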

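When len(x) < len(y), log-indexing will be faster, because the gather then produces more rows than x has, so taking the log first touches fewer elements. But XLA shapes are static, so both sizes are known at compile time and the cheaper order can be chosen then.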

Reply:

> When len(x) < len(y), log-indexing will be faster, but XLA has static shape...

You are right. log-indexing should be optimised to indexing-log when len(x) > len(y).
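
To make the static-shape argument concrete, here is a minimal sketch of the rewrite as an algebraic simplification on a hypothetical toy IR. `Param`, `Log`, `Gather`, `shape_of` and `simplify` are invented for illustration; they are not XLA's HLO classes or its algebraic simplifier, and the gather is reduced to selecting rows along axis 0. The rule fires only when the gathered result has fewer elements than x, i.e. when len(y) < len(x).

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class Param:
    """Hypothetical parameter node with a static shape, e.g. x: f32[16384,10]."""
    name: str
    shape: Tuple[int, ...]


@dataclass(frozen=True)
class Log:
    """Hypothetical elementwise natural-log node."""
    operand: "Expr"


@dataclass(frozen=True)
class Gather:
    """Hypothetical node selecting `num_indices` rows (axis 0), like x[y]."""
    operand: "Expr"
    num_indices: int


Expr = Union[Param, Log, Gather]


def shape_of(e: Expr) -> Tuple[int, ...]:
    """Static shape inference for the toy IR."""
    if isinstance(e, Param):
        return e.shape
    if isinstance(e, Log):
        return shape_of(e.operand)  # elementwise: shape unchanged
    return (e.num_indices,) + shape_of(e.operand)[1:]  # gather replaces axis 0


def num_elements(s: Tuple[int, ...]) -> int:
    n = 1
    for d in s:
        n *= d
    return n


def simplify(e: Expr) -> Expr:
    """Rewrite Gather(Log(x)) to Log(Gather(x)) when that shrinks the log's input.

    Only the top-level node is inspected; a real pass would visit every node.
    """
    if isinstance(e, Gather) and isinstance(e.operand, Log):
        x = e.operand.operand
        gathered = Gather(x, e.num_indices)
        # Shapes are static, so this comparison is decided entirely at compile time.
        if num_elements(shape_of(gathered)) < num_elements(shape_of(x)):
            return Log(gathered)
    return e


x = Param("x", (16384, 10))
print(simplify(Gather(Log(x), 16)))        # rewritten: the log moves after the gather
small = Param("s", (8, 10))
print(simplify(Gather(Log(small), 100)))   # unchanged: gathering 100 rows grows the data
```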

Thanks, I filed a bug with the XLA folks (Google bug b/229671708).