triton-lang / triton

Development repository for the Triton language and compiler

Home Page: https://triton-lang.org/

Support masking in atomic_cas?

bailuding opened this issue

Hi team,

It seems that the atomic_cas API does not take a mask, unlike atomic_xchg. We observe that atomic_cas can be 20x slower than atomic_xchg when the mask is selective (i.e., when only a small fraction of lanes is active). From the code of atomic_xchg (see below), the masking is handled outside of the core atomic API, so it looks like atomic_cas could support masking in a similar way without a fundamental change to the function. Could you add support for masking in atomic_cas?

Here are the code snippets for atomic_xchg and atomic_cas, from https://github.com/triton-lang/triton/blob/main/python/triton/language/semantic.py:

def atomic_xchg(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str,
                builder: ir.builder) -> tl.tensor:
    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xchg', builder)
    sem = _str_to_sem(sem)
    scope = _str_to_scope(scope)
    return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle, sem, scope),
                     val.type)

def atomic_cas(ptr: tl.tensor, cmp: tl.tensor, val: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
    sem = _str_to_sem(sem)
    scope = _str_to_scope(scope)
    element_ty = ptr.type.scalar.element_ty
    if element_ty.primitive_bitwidth not in [16, 32, 64]:
        raise ValueError("atomic_cas only supports elements with width {16, 32, 64}")
    return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle, sem, scope), val.type)
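
To make the request concrete, the behavior we have in mind can be modeled in plain Python. This is only a sketch of the intended semantics, not Triton code: `masked_cas_reference` is a hypothetical helper, the lanes are processed sequentially (one valid ordering of the atomics), and returning 0 for masked-off lanes is an assumption, since the result for inactive lanes may well be left unspecified.

```python
def masked_cas_reference(mem, offsets, cmp, val, mask):
    """Sequential model of a masked compare-and-swap over a flat buffer.

    Hypothetical helper for illustration; not part of Triton.
    mem is modified in place; the returned list holds the old values.
    """
    # Assumption: lanes that are masked off report 0 as their "old" value.
    old = [0] * len(offsets)
    for lane, addr in enumerate(offsets):
        if not mask[lane]:
            # Masked-off lane: no load, no compare, no store.
            continue
        old[lane] = mem[addr]
        if mem[addr] == cmp[lane]:
            mem[addr] = val[lane]
    return old


mem = [0, 0, 5, 0]
offsets = [0, 1, 2, 3]
cmp = [0, 0, 0, 0]
val = [10, 11, 12, 13]
mask = [True, False, True, False]

old = masked_cas_reference(mem, offsets, cmp, val, mask)
print(mem)  # lane 0 swaps, lane 2 fails the compare, lanes 1 and 3 are skipped
print(old)
```

The point is that masked-off lanes skip the atomic entirely, which is where we would expect the speedup to come from when the mask is selective.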


Best,
Bailu