triton-lang / triton

Development repository for the Triton language and compiler

Home Page: https://triton-lang.org/

tl.cumsum seems to emit an internal error.

codewdy opened this issue

# pip freeze | grep triton
triton==2.3.1
tritonclient==2.44.0
# pip freeze | grep torch 
torch==2.1.2+cu121
torchaudio==2.1.2+cu121
torchtyping==0.1.4
torchvision==0.16.2+cu121

Code to reproduce:

import torch
import triton
import triton.language as tl

@triton.jit
def block_count_kernel(i_ptr, o_ptr,
                       BS, K: tl.constexpr, EXPERTS: tl.constexpr, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offset_ins = tl.arange(0, BLOCK) + BLOCK * pid
    offset_k = tl.arange(0, K)
    offset_e = tl.arange(0, EXPERTS)
    i_block = tl.load(i_ptr + (offset_ins[:, None] * K + offset_k[None, :]))
    i_block = tl.reshape(i_block, [K * BLOCK, 1])
    e_block = tl.reshape(tl.arange(0, EXPERTS), [1, EXPERTS])
    mask = (i_block == e_block).to(tl.int32)
    # tl.device_print("mask", mask)
    mask = tl.sum(tl.reshape(mask, [BLOCK, K, EXPERTS]), axis=1)
    mask = tl.cumsum(mask, 0)
    tl.store(o_ptr + offset_ins[:, None] * EXPERTS + offset_e[None, :], mask)

block_count_impl = block_count_kernel[lambda meta: (meta["BS"] // meta["BLOCK"], )]

def block_count(i, experts, block):
    bs, k = i.shape
    output = torch.empty([bs, experts], dtype = i.dtype, device = i.device)
    block_count_impl(i, output, bs, k, experts, block)
    return output


torch.manual_seed(0)
x=torch.randint(16, [128 * 512, 4], device="cuda:0")
# x=torch.randint(16, [1, 4], device="cuda:0")
torch.cuda.synchronize()

mask = block_count(x, 16, 64)
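
For context, here is a plain-PyTorch sketch of what I expect block_count to produce (my reading of the kernel: per-row expert counts, cumulatively summed within each group of BLOCK rows; block_count_ref is only a reference name, not part of the reproducer):

import torch.nn.functional as F

def block_count_ref(i, experts, block):
    bs, k = i.shape
    # one-hot over experts, summed over the K picks per row -> [bs, experts]
    counts = F.one_hot(i, experts).sum(dim=1)
    # running sum within each group of `block` rows
    counts = counts.view(bs // block, block, experts).cumsum(dim=1)
    return counts.view(bs, experts)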

If I remove mask = tl.cumsum(mask, 0), everything works fine. If I keep it, it reports the following error (line breaks added for readability):

python3.10: /root/.triton/llvm/llvm-5e5a22ca-centos-x64/include/llvm/Support/Casting.h:566: 
decltype(auto) llvm::cast(const From&) 
[with To = mlir::triton::gpu::BlockedEncodingAttr; From = mlir::Attribute]: 
Assertion `isa<To>(Val) && "cast<Ty>() argument of incompatible type!"' failed.

I tried a number of kernels that use tl.cumsum, and all of them report this error.
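
A minimal sketch of the kind of kernel I mean (placeholder names; it only loads a [ROWS, COLS] block, cumulative-sums it along axis 0, and stores it back):

@triton.jit
def cumsum_kernel(x_ptr, y_ptr, ROWS: tl.constexpr, COLS: tl.constexpr):
    offs_r = tl.arange(0, ROWS)
    offs_c = tl.arange(0, COLS)
    x = tl.load(x_ptr + offs_r[:, None] * COLS + offs_c[None, :])
    y = tl.cumsum(x, 0)  # cumulative sum along axis 0
    tl.store(y_ptr + offs_r[:, None] * COLS + offs_c[None, :], y)

a = torch.randint(16, [64, 16], dtype=torch.int32, device="cuda:0")
b = torch.empty_like(a)
cumsum_kernel[(1,)](a, b, 64, 16)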

I believe this is fixed in the latest version; try the nightly build.
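
If it helps, a quick way to check which build actually got installed before re-running the reproducer (just prints the version string):

import triton
print(triton.__version__)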