triton-lang / triton

Development repository for the Triton language and compiler

Home Page:https://triton-lang.org/

How to perform a store operation on a part of a Tensor?

YKTian-x2b opened this issue

I want to store into part of a tensor, like this:

accum_outs = tl.zeros([N], dtype=tl.float32)
for col_off in range(0, N, BLOCK_SIZE):
    cols = col_off + tl.arange(0, BLOCK_SIZE)
    mask = cols < N
    a_eles = tl.load(a_ptr + cols, mask=mask, other=0.0)
    b_eles = tl.load(b_ptr + cols, mask=mask, other=0.0)
    # How to implement the following: 
    accum_outs[col_off: col_off+BLOCK_SIZE] = a_eles * b_eles

I need to use "accum_outs" later in the kernel, and I don't want to write it back to memory and reload it, like the following:

for col_off in range(0, N, BLOCK_SIZE):
    cols = col_off + tl.arange(0, BLOCK_SIZE)
    mask = cols < N
    a_eles = tl.load(a_ptr + cols, mask=mask, other=0.0)
    b_eles = tl.load(b_ptr + cols, mask=mask, other=0.0)
    tl.store(accum_res_ptr + cols, a_eles * b_eles, mask=mask)

# tl.ops that must be done outside the loop

for col_off in range(0, N, BLOCK_SIZE):
    cols = col_off + tl.arange(0, BLOCK_SIZE)
    mask = cols < N
    eles = tl.load(accum_res_ptr + cols, mask=mask, other=0.0)
    ...

As far as I know, you currently can't store into part of a tensor. For your use case, you can instead load a and b with 2D indexing and then flatten the product, something like the snippet below. NUM_BLOCKS should be passed as a tl.constexpr and should be the next power of 2 of (N + BLOCK_SIZE - 1) // BLOCK_SIZE.

    num_blocks = tl.cdiv(N, BLOCK_SIZE)
    block_offs = tl.arange(0, NUM_BLOCKS)
    per_block_offs = tl.arange(0, BLOCK_SIZE)
    # Row i holds the offsets of block i, so the shape is [NUM_BLOCKS, BLOCK_SIZE]
    all_offs = block_offs[:, None] * BLOCK_SIZE + per_block_offs[None, :]
    # Combine both conditions elementwise with & (each side parenthesized)
    mask = (all_offs < N) & (block_offs[:, None] < num_blocks)

    # 2d pointers of shape [NUM_BLOCKS, BLOCK_SIZE]
    a = tl.load(a_ptr + all_offs, mask=mask, other=0.0)
    b = tl.load(b_ptr + all_offs, mask=mask, other=0.0)

    # [NUM_BLOCKS, BLOCK_SIZE] --> NUM_BLOCKS * BLOCK_SIZE elements; the first N are filled, the rest are 0
    accum_outs = tl.ravel(a * b)
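
To sanity-check the index arithmetic on the host, here is a minimal pure-Python sketch (no Triton or GPU needed). It assumes the flatten step is row-major, with block i occupying row i of a [NUM_BLOCKS, BLOCK_SIZE] grid. The helper names cdiv, next_power_of_2 and flat_offsets are hypothetical and only for illustration; Triton also ships triton.next_power_of_2, which could be used to compute NUM_BLOCKS when launching the kernel.

```python
def cdiv(n, d):
    """Ceiling division for positive integers."""
    return (n + d - 1) // d


def next_power_of_2(x):
    """Smallest power of two that is >= x, for x >= 1."""
    return 1 << (x - 1).bit_length()


def flat_offsets(n, block_size):
    """Emulate the 2D offsets and mask, then flatten them row by row."""
    num_blocks = next_power_of_2(cdiv(n, block_size))
    # offs[i][j] = i * block_size + j, i.e. shape [num_blocks, block_size]
    offs = [[i * block_size + j for j in range(block_size)]
            for i in range(num_blocks)]
    # Row-major flatten: block 0 first, then block 1, and so on
    flat = [o for row in offs for o in row]
    # Only offsets below n refer to real elements
    mask = [o < n for o in flat]
    return num_blocks, flat, mask


num_blocks, flat, mask = flat_offsets(n=10, block_size=4)
print(num_blocks)               # 4 (3 blocks needed, rounded up to a power of 2)
print(flat == list(range(16)))  # True: the flattened offsets are contiguous
print(sum(mask))                # 10 valid elements, the other 6 are padding
```

With block-major rows the flattened offsets come out as 0, 1, 2, ..., so element k of the flattened result corresponds to element k of the input. If the block index varied along the columns instead, the flattened order would be interleaved.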