triton-lang / triton

Development repository for the Triton language and compiler

Home Page: https://triton-lang.org/


triton cache does not invalidate cache correctly when dynamically choosing a function to call

christopherhesse opened this issue · comments

Run the following program, then fix the `add` function (it intentionally computes `x * y` instead of `x + y`), then run it again: the result will still be wrong until you clear the Triton cache.

import torch

import triton
import triton.language as tl

@triton.jit
def add(x, y):
    return x * y  # intentionally wrong: change to x + y after the first run to reproduce the bug

@triton.jit
def sub(x, y):
    return x - y

@triton.jit
def binary_kernel(x_ptr,
                  y_ptr,
                  fn_name: tl.constexpr,
                  output_ptr,
                  n_elements,
                  BLOCK_SIZE: tl.constexpr,
                  ):
    
    if fn_name == "add":
        FN: tl.constexpr = add
    elif fn_name == "sub":
        FN: tl.constexpr = sub
    else:
        tl.static_assert(False, f"Invalid {fn_name=}")
        
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = FN(x, y)
    tl.store(output_ptr + offsets, output, mask=mask)

def binary(x: torch.Tensor, y: torch.Tensor, fn_name: str):
    output = torch.empty_like(x)
    assert x.is_cuda and y.is_cuda and output.is_cuda
    n_elements = output.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    binary_kernel[grid](x, y, fn_name, output, n_elements, BLOCK_SIZE=1024)
    return output

def main():
    torch.manual_seed(0)
    size = 1
    x = torch.rand(size, device='cuda')
    y = torch.rand(size, device='cuda')
    output_torch = x + y
    output_triton = binary(x, y, "add")
    print("torch", output_torch)
    print("triton", output_triton)


if __name__ == "__main__":
    main()

Hm, this is a bad bug. Thank you, I'll see if I can have a look.

I was not able to reproduce the bug on Triton 2.3.0 or 2.3.1. Which version are you using? I am happy to help with this. cc @jlebar

I can reproduce at HEAD.

@ByronHsu if you want to take this, that works for me! But please lmk because I'd like to make sure this gets fixed one way or another, especially if it's a regression.

I will try to take a stab today and tomorrow, and I will get back to you. Thank you! @jlebar

It's because a function is being bound to a `constexpr`. You can bypass the problem by calling the function directly instead:

    if fn_name == "add":
        output = add(x, y)

@Jokeren any guidance of how to fix this issue? I haven't had experience in triton code base (mainly just a user), but i am eager to learn!

You can take a look at DependenciesFinder
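For intuition, the idea behind a dependency finder is to walk the caller's AST, collect the functions it calls, and fold their sources into the cache key recursively. The sketch below (operating on plain source strings, with a hypothetical `dependency_hash` helper; not Triton's actual `DependenciesFinder` implementation) shows why fixing a callee should change the caller's key:

```python
import ast
import hashlib

def dependency_hash(name, sources):
    """Hash a function's source plus, recursively, the source of every
    function it calls (looked up by name in `sources`, a dict of
    name -> source text). No cycle handling beyond skipping self-calls."""
    src = sources[name]
    h = hashlib.sha256(src.encode())

    class Finder(ast.NodeVisitor):
        def visit_Call(self, node):
            callee = getattr(node.func, "id", None)  # simple `f(...)` calls only
            if callee in sources and callee != name:
                h.update(dependency_hash(callee, sources).encode())
            self.generic_visit(node)

    Finder().visit(ast.parse(src))
    return h.hexdigest()

sources = {
    "add": "def add(x, y):\n    return x * y\n",
    "kernel": "def kernel(x, y):\n    return add(x, y)\n",
}
before = dependency_hash("kernel", sources)
sources["add"] = "def add(x, y):\n    return x + y\n"  # fix the callee
after = dependency_hash("kernel", sources)
print(before != after)  # True: fixing add changes kernel's cache key
```

The bug in this issue suggests that a callee reached only through a `constexpr` binding is being missed by this kind of traversal, so its source never makes it into the key.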

@Jokeren thank you! I am taking a look

Update 1: I am able to reproduce at HEAD.