Support masking in atomic_cas?
bailuding opened this issue
Bailu Ding commented
Hi team,
It seems that the atomic_cas API does not take a mask, unlike atomic_xchg. We observe that atomic_cas can be 20x slower than atomic_xchg when the mask is selective. From the code of atomic_xchg (see below), the masking appears to be handled outside of the core atomic API, so it looks like atomic_cas could support masking in a similar way without a fundamental change to the function. Could you add mask support to atomic_cas?
Here are the code snippets for atomic_xchg and atomic_cas, taken from https://github.com/triton-lang/triton/blob/main/python/triton/language/semantic.py
def atomic_xchg(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str,
                builder: ir.builder) -> tl.tensor:
    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xchg', builder)
    sem = _str_to_sem(sem)
    scope = _str_to_scope(scope)
    return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle, sem, scope),
                     val.type)


def atomic_cas(ptr: tl.tensor, cmp: tl.tensor, val: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
    sem = _str_to_sem(sem)
    scope = _str_to_scope(scope)
    element_ty = ptr.type.scalar.element_ty
    if element_ty.primitive_bitwidth not in [16, 32, 64]:
        raise ValueError("atomic_cas only supports elements with width {16, 32, 64}")
    return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle, sem, scope), val.type)
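To make the requested semantics concrete, here is a minimal pure-Python model of what a masked compare-and-swap could mean. This is only a sketch and not Triton code: the helpers `cas`, `masked_cas_reference`, and `masked_cas_emulated` are hypothetical names, lanes are processed one at a time rather than concurrently, and returning `None` for masked-off lanes is an assumption made for illustration. It also shows one possible emulation with an unmasked CAS, where masked-off lanes pass `cmp` as the new value so that any store rewrites the value already in memory.

```python
from typing import List, Optional


def cas(memory: List[int], addr: int, cmp: int, val: int) -> int:
    """Scalar compare-and-swap: store val only if the old value equals cmp."""
    old = memory[addr]
    if old == cmp:
        memory[addr] = val
    return old


def masked_cas_reference(memory: List[int], addrs: List[int], cmps: List[int],
                         vals: List[int], mask: List[bool]) -> List[Optional[int]]:
    """Hypothetical masked CAS: lanes whose mask is False do nothing at all."""
    out: List[Optional[int]] = []
    for a, c, v, m in zip(addrs, cmps, vals, mask):
        # Assumption: masked-off lanes return a placeholder (None here).
        out.append(cas(memory, a, c, v) if m else None)
    return out


def masked_cas_emulated(memory: List[int], addrs: List[int], cmps: List[int],
                        vals: List[int], mask: List[bool]) -> List[Optional[int]]:
    """Emulation on top of an unmasked CAS.

    Masked-off lanes use cmp as the new value, so a successful swap writes
    back the value that was already stored and memory is left unchanged.
    """
    out: List[Optional[int]] = []
    for a, c, v, m in zip(addrs, cmps, vals, mask):
        old = cas(memory, a, c, v if m else c)
        out.append(old if m else None)
    return out


# Small demonstration: both versions leave memory in the same state.
addrs = [0, 1, 2, 3]
cmps = [0, 5, 0, 0]
vals = [1, 1, 1, 1]
mask = [True, False, True, False]

mem_ref = [0, 5, 0, 7]
mem_emu = [0, 5, 0, 7]
res_ref = masked_cas_reference(mem_ref, addrs, cmps, vals, mask)
res_emu = masked_cas_emulated(mem_emu, addrs, cmps, vals, mask)
print(mem_ref, mem_emu, res_ref, res_emu)
```

In Triton terms, the emulation would presumably correspond to passing something like `tl.where(mask, val, cmp)` as the `val` argument of `tl.atomic_cas`, though that is an untested assumption. Note that this emulation still issues an atomic on every lane and still needs a valid pointer on masked-off lanes, so it would not be expected to recover the performance gap described above; a real mask argument would be needed for that.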
Best,
Bailu