google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/

XLA errors with splash attention mask functions that use modulo

fding opened this issue

Description

I am trying to use splash attention with a custom ComputableMask that uses the modulo operation. My first attempt is:

class ModuloMask(splash_attention_mask._ComputableMask):
  def __init__(
      self,
      input_size: int,
      shard_count: int = 1,
  ):
    def mask_function(q_ids, kv_ids):
        q_ids = q_ids % 48
        kv_ids = kv_ids % 48
        return q_ids <= kv_ids 

    super().__init__(
        shape=(input_size, input_size),
        mask_function=mask_function,
        shard_count=shard_count,
    )

but this gives me a LoweringException: Exception while lowering eqn: a:bool[128,128] = ne b c, with NotImplementedError: Mixed dtype operands in cmp.
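
As an aside, a hedged diagnostic sketch (my own, not something posted in the thread) of where the ne equation likely comes from: for integer arrays, the % operator expands into lax.rem plus a floor-mod sign correction, and that expansion contains the ne/lt/select comparisons that appear in the lowering error. Printing a jaxpr makes the expansion visible:

import jax
import jax.numpy as jnp

# The jaxpr for an integer % contains rem plus ne/lt/and/select_n from the
# sign correction; the `a:bool[...] = ne b c` equation in the error above
# presumably comes from this expansion.
print(jax.make_jaxpr(lambda x: x % 48)(jnp.arange(4)))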

I modified this to

class ModuloMask(splash_attention_mask._ComputableMask):
  def __init__(
      self,
      input_size: int,
      shard_count: int = 1,
  ):
    width_np = cast(48, np.uint32)
    def mask_function(q_ids, kv_ids):
      if isinstance(q_ids, np.ndarray):
        q_ids = q_ids % 48
        kv_ids = kv_ids % 48
        return q_ids <= kv_ids
      else:
        q_ids = jnp.mod(q_ids, width_np)
        kv_ids = jnp.mod(kv_ids, width_np)
        return q_ids <= kv_ids

    super().__init__(
        shape=(input_size, input_size),
        mask_function=mask_function,
        shard_count=shard_count,
    )

but then I get ValueError: safe_map() argument 3 is shorter than argument 1 in ~/.venv/lib/python3.11/site-packages/jax/_src/pallas/pallas_call.py:418 (corresponding to ~/.venv/lib/python3.11/site-packages/jax/_src/pallas/pallas_call.py:560).
If I get rid of the modulo and just return q_ids <= kv_ids in the mask function, it works.

The full source to reproduce is:

import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
import numpy as np

def cast(x, dtype):
  return (x * jnp.ones([])).astype(dtype)

class ModuloMask(splash_attention_mask._ComputableMask):
  def __init__(
      self,
      input_size: int,
      shard_count: int = 1,
  ):
    width_np = cast(48, np.uint32)
    def mask_function(q_ids, kv_ids):
      if isinstance(q_ids, np.ndarray):
        q_ids = q_ids % 48
        kv_ids = kv_ids % 48
        return q_ids <= kv_ids
      else:
        q_ids = jnp.mod(q_ids, width_np)
        kv_ids = jnp.mod(kv_ids, width_np)
        return q_ids <= kv_ids

    super().__init__(
        shape=(input_size, input_size),
        mask_function=mask_function,
        shard_count=shard_count,
    )

  def __eq__(self, other: object):
    if not isinstance(other, type(self)):
        return NotImplemented

    return (
        self.shape == other.shape
        and np.array_equal(self.q_sequence, other.q_sequence)
    )

  def __hash__(self):
    return hash(
        (
            type(self),
            self.shape,
            self.q_sequence.tobytes() if self.q_sequence is not None else None,
        )
    )

B, H, N, D = 4, 4, 768, 256
q = k = v = jnp.zeros([B, H, N, D], jnp.bfloat16)
mask = ModuloMask(768)

masks = splash_attention_mask.MultiHeadMask(masks=[mask for i in range(q.shape[1])])
splash_kernel = splash_attention_kernel.make_splash_mha_single_device(
    mask=masks
)
splash_attention_fn = jax.vmap(splash_kernel)
splash_attention_fn(
    q, k, v
)

This bug occurs on both jax==0.4.25 and jax==0.4.28.

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.25
jaxlib: 0.4.25
numpy:  1.26.0
python: 3.11.9 (main, Apr  6 2024, 17:59:24) [GCC 9.4.0]
jax.devices (8 total, 8 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1) ... TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0) TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
process_count: 1
platform: uname_result(system='Linux', node='t1v-n-d8febb2a-w-0', release='5.13.0-1027-gcp', version='#32~20.04.1-Ubuntu SMP Thu May 26 10:53:08 UTC 2022', machine='x86_64')

Can you try using jax.lax.rem?
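
For reference, a minimal sketch of what that change could look like (an adaptation, not code posted in the thread; the explicit cast of the modulus to the ids' dtype is a defensive assumption). jax.lax.rem computes a plain truncated remainder without the floor-mod sign correction that jnp.mod / % add, and for non-negative ids the two agree, so it can replace jnp.mod here:

class ModuloMask(splash_attention_mask._ComputableMask):
  def __init__(
      self,
      input_size: int,
      shard_count: int = 1,
  ):
    def mask_function(q_ids, kv_ids):
      if isinstance(q_ids, np.ndarray):
        # Host-side (NumPy) path used when the mask is materialized.
        return (q_ids % 48) <= (kv_ids % 48)
      # Traced path inside the kernel: lax.rem lowers directly, avoiding the
      # extra comparisons introduced by jnp.mod / the % operator.
      q_ids = jax.lax.rem(q_ids, jnp.asarray(48, q_ids.dtype))
      kv_ids = jax.lax.rem(kv_ids, jnp.asarray(48, kv_ids.dtype))
      return q_ids <= kv_ids

    super().__init__(
        shape=(input_size, input_size),
        mask_function=mask_function,
        shard_count=shard_count,
    )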

Yes, this works, thanks!