XLA errors with splash attention mask functions that use modulo
fding opened this issue
Description
I am trying to use splash attention with a custom _ComputableMask whose mask function uses the modulo operation. My first attempt is:
class ModuloMask(splash_attention_mask._ComputableMask):
    def __init__(
        self,
        input_size: int,
        shard_count: int = 1,
    ):
        def mask_function(q_ids, kv_ids):
            q_ids = q_ids % 48
            kv_ids = kv_ids % 48
            return q_ids <= kv_ids

        super().__init__(
            shape=(input_size, input_size),
            mask_function=mask_function,
            shard_count=shard_count,
        )
but this gives me
LoweringException: Exception while lowering eqn: a:bool[128,128] = ne b c
and
NotImplementedError: Mixed dtype operands in cmp
I modified this to:
class ModuloMask(splash_attention_mask._ComputableMask):
    def __init__(
        self,
        input_size: int,
        shard_count: int = 1,
    ):
        width_np = cast(48, np.uint32)

        def mask_function(q_ids, kv_ids):
            if isinstance(q_ids, np.ndarray):
                q_ids = q_ids % 48
                kv_ids = kv_ids % 48
                return q_ids <= kv_ids
            else:
                q_ids = jnp.mod(q_ids, width_np)
                kv_ids = jnp.mod(kv_ids, width_np)
                return q_ids <= kv_ids

        super().__init__(
            shape=(input_size, input_size),
            mask_function=mask_function,
            shard_count=shard_count,
        )
but then I get
ValueError: safe_map() argument 3 is shorter than argument 1
in ~/.venv/lib/python3.11/site-packages/jax/_src/pallas/pallas_call.py:418 (corresponding to ~/.venv/lib/python3.11/site-packages/jax/_src/pallas/pallas_call.py:560).
If I get rid of the modulo and just return q_ids <= kv_ids from the mask function, it works.
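For reference, this is the variant that does lower successfully, a minimal sketch that simply drops the modulo from the mask function above:

def mask_function(q_ids, kv_ids):
    # No wrapping of the ids: the bare comparison lowers without error.
    return q_ids <= kv_ids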
The full source to reproduce is:
import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
import numpy as np


def cast(x, dtype):
    return (x * jnp.ones([])).astype(dtype)


class ModuloMask(splash_attention_mask._ComputableMask):
    def __init__(
        self,
        input_size: int,
        shard_count: int = 1,
    ):
        width_np = cast(48, np.uint32)

        def mask_function(q_ids, kv_ids):
            if isinstance(q_ids, np.ndarray):
                q_ids = q_ids % 48
                kv_ids = kv_ids % 48
                return q_ids <= kv_ids
            else:
                q_ids = jnp.mod(q_ids, width_np)
                kv_ids = jnp.mod(kv_ids, width_np)
                return q_ids <= kv_ids

        super().__init__(
            shape=(input_size, input_size),
            mask_function=mask_function,
            shard_count=shard_count,
        )

    def __eq__(self, other: object):
        if not isinstance(other, type(self)):
            return NotImplemented
        return (
            self.shape == other.shape
            and np.array_equal(self.q_sequence, other.q_sequence)
        )

    def __hash__(self):
        return hash(
            (
                type(self),
                self.shape,
                self.q_sequence.tobytes() if self.q_sequence is not None else None,
            )
        )


B, H, N, D = 4, 4, 768, 256
q = k = v = jnp.zeros([B, H, N, D], jnp.bfloat16)
mask = ModuloMask(768)
masks = splash_attention_mask.MultiHeadMask(masks=[mask for i in range(q.shape[1])])
splash_kernel = splash_attention_kernel.make_splash_mha_single_device(
    mask=masks
)
splash_attention_fn = jax.vmap(splash_kernel)
splash_attention_fn(
    q, k, v
)
This bug occurs on jax==0.4.25 and also on jax==0.4.28.
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.25
jaxlib: 0.4.25
numpy: 1.26.0
python: 3.11.9 (main, Apr 6 2024, 17:59:24) [GCC 9.4.0]
jax.devices (8 total, 8 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1) ... TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0) TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
process_count: 1
platform: uname_result(system='Linux', node='t1v-n-d8febb2a-w-0', release='5.13.0-1027-gcp', version='#32~20.04.1-Ubuntu SMP Thu May 26 10:53:08 UTC 2022', machine='x86_64')
Can you try using jax.lax.rem?
Yes, this works, thanks!
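For anyone hitting the same error, here is a minimal sketch of the mask function rewritten with jax.lax.rem, following the suggestion above. The width of 48 and the numpy branch are kept from the reproducer; passing the plain Python constant to jax.lax.rem assumes standard weak-type promotion to the ids' integer dtype, and this is an illustration rather than a tested patch.

import jax
import numpy as np

def mask_function(q_ids, kv_ids):
    # At mask-construction time the ids arrive as numpy arrays, so the
    # plain Python modulo is fine on that path.
    if isinstance(q_ids, np.ndarray):
        return (q_ids % 48) <= (kv_ids % 48)
    # Inside the Pallas kernel, jax.lax.rem lowers where jnp.mod with a
    # cast constant hit the safe_map() error above.
    return jax.lax.rem(q_ids, 48) <= jax.lax.rem(kv_ids, 48)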