triton-lang / triton

Development repository for the Triton language and compiler

Home Page: https://triton-lang.org/

How to add a pow function in python.triton.language.core?

arelkeselbri opened this issue:

I tried to use the pow operation in a @triton.jit function:

output = x + y**3
                ^

However, I got AttributeError("'tensor' object has no attribute '__pow__'").

In python/triton/language/core.py, class constexpr (line 279) defines def __pow__(self, other), but class tensor (line 728) has no corresponding method, nothing like def __pow__(self, other) -> tensor.

Is there any special reason for not having a pow for tensor? Where should I look to include it?
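For context on what the missing method means: in plain Python, the ** operator dispatches to __pow__ on the left operand. The sketch below uses a hypothetical ToyTensor class to show that dispatch. It is not Triton's tensor class, whose methods would have to emit compiler IR rather than compute values eagerly, so treat it only as an illustration of the operator protocol.

```python
class ToyTensor:
    """Hypothetical stand-in for a tensor type; NOT Triton's real class."""

    def __init__(self, values):
        self.values = list(values)

    def _broadcast(self, other):
        # Accept either another ToyTensor or a Python scalar.
        if isinstance(other, ToyTensor):
            return other.values
        return [other] * len(self.values)

    def __add__(self, other):
        # Called for `self + other`.
        return ToyTensor(a + b for a, b in zip(self.values, self._broadcast(other)))

    def __pow__(self, exponent):
        # Called for `self ** exponent`; without this method the
        # expression `y ** 3` cannot be evaluated for this type.
        return ToyTensor(a ** b for a, b in zip(self.values, self._broadcast(exponent)))


x = ToyTensor([1.0, 2.0])
y = ToyTensor([2.0, 3.0])
output = x + y ** 3
print(output.values)  # [9.0, 29.0]
```

Plain Python raises TypeError when __pow__ is absent. The AttributeError in this issue suggests that Triton's code generator looks the method up by name on the tensor object instead.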

triton/python/tutorials/07-extern-functions.py contains an example which, modified to use pow, becomes:

import torch
import triton
import triton.language as tl
from triton.language.extra import libdevice

@triton.jit
def pow_kernel(
    x_ptr,
    y_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # Element-wise x ** (x + 1) via the libdevice extern function
    x = libdevice.pow(x, x + 1)
    tl.store(y_ptr + offsets, x, mask=mask)

x = torch.tensor([0.0, 1.1, 2.0, 3.0], dtype=torch.float32, device='cuda')
y = torch.zeros_like(x)

pow_kernel[(1,)](x, y, 4, BLOCK_SIZE=4)
print(y)

Output:

tensor([ 0.0000, 1.2216, 8.0000, 81.0000], device='cuda:0')
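As a sanity check, the printed values can be reproduced on the CPU with the standard library, since the kernel computes x ** (x + 1) for each element. This is a minimal sketch in double precision, so the last digits may differ slightly from the float32 GPU result.

```python
import math

# Same inputs as the tensor passed to pow_kernel above.
inputs = [0.0, 1.1, 2.0, 3.0]

# The kernel computes libdevice.pow(x, x + 1) element-wise.
expected = [math.pow(v, v + 1) for v in inputs]

print([round(e, 4) for e in expected])
```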