XLA: log-indexing should be optimised to indexing-log
ayaka14732 opened this issue
Ayaka commented
log-indexing:
np.log(x)[y]
should be optimised to indexing-log:
np.log(x[y])
Test:
import jax
import jax.numpy as np
import jax.random as rand
key = rand.PRNGKey(42)
x = rand.truncated_normal(key, 1e-5, 5., (16384, 10)) # (16384, 10)
y = rand.randint(key, (16,), 0, 10) # (16,)
@jax.jit
def f1(x, y):
    return np.log(x)[y]
@jax.jit
def f2(x, y):
    return np.log(x[y])
a = f1(x, y)
b = f2(x, y)
assert np.array_equal(a, b)
import timeit
print(timeit.timeit('f1(x, y).block_until_ready()', globals=globals(), number=50000)) # 14.493794055997569
print(timeit.timeit('f2(x, y).block_until_ready()', globals=globals(), number=50000)) # 9.56337349399837
Generated HLO
log-indexing:
ENTRY %xla_computation_f1.15 (parameter.1: f32[16384,10], parameter.2: s32[16]) -> (f32[16,10]) {
  %parameter.2 = s32[16] parameter(1)
  %parameter.1 = f32[16384,10] parameter(0)
  %fusion.1 = s32[16] fusion(s32[16] %parameter.2), calls=%fused_computation.1
  %fusion.2 = s32[1024] fusion(s32[16] %fusion.1), calls=%fused_computation.2
  %copy = f32[16384,10] copy(f32[16384,10] %parameter.1)
  %log.4 = f32[16384,10] log(f32[16384,10] %copy)
  %fusion = f32[16,10] fusion(f32[16384,10] %log.4, s32[1024] %fusion.2), calls=%fused_computation
  ROOT %tuple.14 = (f32[16,10]) tuple(f32[16,10] %fusion)
}
indexing-log:
ENTRY %xla_computation_f2.15 (parameter.1: f32[16384,10], parameter.2: s32[16]) -> (f32[16,10]) {
  %parameter.2 = s32[16] parameter(1)
  %parameter.1 = f32[16384,10] parameter(0)
  %fusion.1 = s32[16] fusion(s32[16] %parameter.2), calls=%fused_computation.1
  %fusion.2 = s32[1024] fusion(s32[16] %fusion.1), calls=%fused_computation.2
  %copy = f32[16384,10] copy(f32[16384,10] %parameter.1)
  %fusion = f32[16,10] fusion(f32[16384,10] %copy, s32[1024] %fusion.2), calls=%fused_computation
  %log.13 = f32[16,10] log(f32[16,10] %fusion)
  ROOT %tuple.14 = (f32[16,10]) tuple(f32[16,10] %log.13)
}
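For anyone wanting to reproduce a dump like the two above: the issue does not say how the HLO was obtained, so the following is only a sketch of one possible way, assuming a JAX version that provides the ahead-of-time `lower()` / `compile()` API on jitted functions. The input values are placeholders; only the shapes and dtypes match the test. The exact text printed will differ between JAX versions and backends.

```python
import jax
import jax.numpy as jnp

def f1(x, y):
    return jnp.log(x)[y]  # log-indexing

def f2(x, y):
    return jnp.log(x[y])  # indexing-log

# Placeholder inputs with the same shapes and dtypes as in the test.
x = jnp.ones((16384, 10), dtype=jnp.float32)
y = jnp.arange(16, dtype=jnp.int32)

# lower() traces the function for these concrete shapes, compile() runs the
# XLA compiler, and as_text() returns the compiled module as a string.
hlo_f1 = jax.jit(f1).lower(x, y).compile().as_text()
hlo_f2 = jax.jit(f2).lower(x, y).compile().as_text()

print(hlo_f1)
print(hlo_f2)
```

Comparing the operand shape of the `log` instruction in the two dumps shows whether the logarithm is applied to the full array or only to the gathered rows.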
You Jiacheng commented
When len(x) < len(y), log-indexing will be faster, but XLA has static shape...
Ayaka commented
> When len(x) < len(y), log-indexing will be faster, but XLA has static shape...
You are right. log-indexing should be optimised to indexing-log only when len(x) > len(y).
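Until XLA performs this rewrite itself, the same decision can be made by hand. Array shapes are static under `jax.jit`, so a plain Python `if` on them is resolved at trace time and costs nothing at run time. A minimal sketch, where `log_gather` is a hypothetical helper name and not part of JAX:

```python
import jax
import jax.numpy as jnp

@jax.jit
def log_gather(x, y):
    # x.shape and y.shape are ordinary Python ints while tracing, so this
    # branch is chosen once per input shape, not on every call.
    if y.shape[0] < x.shape[0]:
        # Fewer indices than rows: gather first, then take the log of the
        # small result (indexing-log).
        return jnp.log(x[y])
    # At least as many indices as rows: take the log once over x, then
    # gather (log-indexing).
    return jnp.log(x)[y]

x = jnp.linspace(1.0, 2.0, 40).reshape(4, 10)
few = jnp.array([0, 3])                     # len(y) < len(x)
many = jnp.array([0, 1, 2, 3, 0, 1, 2, 3])  # len(y) > len(x)

print(log_gather(x, few).shape)
print(log_gather(x, many).shape)
```

Each distinct pair of input shapes triggers its own trace and compilation, so both branches can coexist in one program.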
Peter Hawkins commented
Thanks, I filed a bug with the XLA folks (Google bug b/229671708).