google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/


[pallas] Interpreter mismatch for masked OOB indexing

oliverdutton opened this issue

Description

In Triton (if I have read this correctly), masked elements of a load/store are never touched, so you can request a load/store at an index that is out of bounds for the ref as long as that element is masked. The current interpreter instead uses dynamic_slice/dynamic_update_slice and applies the mask on top. In line with JAX's 'always be in bounds' design, a slice that overruns the edge of the array has its start index shifted back so the slice stays valid (where possible). This leads to a disconnect between the interpreter's output and the compiled Pallas output.
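
To make the interpreter's behaviour concrete, here is a minimal jax.lax example of my own (not from the repro below) showing the clamping that dynamic_slice performs on an overrunning start index:

import jax.numpy as jnp
from jax import lax

x = jnp.arange(16)
# Requesting a 16-element slice starting at index 3 would overrun the array,
# so the start index is clamped back to 0 and the whole array is returned.
print(lax.dynamic_slice(x, (3,), (16,)))
# [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15]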

I know Triton is not Pallas: have you deliberately changed the desired behaviour for these cases in Pallas? In that case this isn't a bug, but it needs documenting.

I've added a pull request fixing this, with some tests: #21144.

Here is a Colab minimal reproduction showing the shift in load indices.

import jax
from jax import numpy as jnp, jit
from jax.experimental import pallas as pl
from functools import partial

def masked_load_pallas_kernel(x_ref, o_ref):
  # Load a full-length window starting at i=3; the last three indices fall out
  # of bounds for x_ref and are masked, so they should read `other` (-1).
  i = jnp.array(3)
  mask = jnp.arange(x_ref.shape[0]) + i < x_ref.shape[0]
  x = pl.load(x_ref, pl.dslice(i, mask.shape[0]), mask=mask, other=-1)
  o_ref[:] = x

@partial(jit, static_argnames=('interpret',))
def masked_load(x: jax.Array, interpret: bool = True):
  return pl.pallas_call(masked_load_pallas_kernel,
                        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
                        interpret=interpret,
                        )(x)

x = jnp.arange(16)
print(f'Input:\nx:\n{x}\n\nOutput:')
for interpret in (True, False):
  print(f'Interpret: {interpret}\n{masked_load(x, interpret=interpret)}')
Input:
x:
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15]

Output:
Interpret: True
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 -1 -1 -1]
Interpret: False
[ 3  4  5  6  7  8  9 10 11 12 13 14 15 -1 -1 -1]
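
For reference, here is a plain jax.numpy sketch of my own (not part of the repro) of the semantics I expect, matching the compiled output above: masked out-of-bounds elements are never read and take the value of `other`.

from jax import numpy as jnp

def masked_load_reference(x, start, other=-1):
  # Expected semantics: read indices start .. start + len(x) - 1; any index
  # that falls outside x is masked out and replaced with `other`.
  idx = jnp.arange(x.shape[0]) + start
  in_bounds = idx < x.shape[0]
  return jnp.where(in_bounds, x[jnp.clip(idx, 0, x.shape[0] - 1)], other)

print(masked_load_reference(jnp.arange(16), 3))
# [ 3  4  5  6  7  8  9 10 11 12 13 14 15 -1 -1 -1]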

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.25.2
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='b70fe499e42d', release='6.1.58+', version='#1 SMP PREEMPT_DYNAMIC Sat Nov 18 15:31:17 UTC 2023', machine='x86_64')


$ nvidia-smi
Thu May  9 08:55:06 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA L4                      Off | 00000000:00:03.0 Off |                    0 |
| N/A   63C    P0              30W /  72W |  17235MiB / 23034MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
+---------------------------------------------------------------------------------------+