How to add a pow function in python.triton.language.core?
arelkeselbri opened this issue.
I tried to use the power operator inside a `@triton.jit` function:

    output = x + y**3

but got `AttributeError("'tensor' object has no attribute '__pow__'")`.
In `python/triton/language/core.py`, line 279, class `constexpr` defines `def __pow__(self, other)`, but class `tensor` (line 728) has no such method, nothing like `def pow(self) -> tensor:`.
Is there a particular reason `tensor` has no `pow`? Where should I look to add it?
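A plain-Python sketch of why the error appears and what adding the method would change. It assumes, as the error message suggests, that the Triton frontend resolves `**` by looking up `__pow__` on the left operand by name. `ToyTensor`, `ToyTensorWithPow` and `apply_binary` are hypothetical stand-ins, not Triton's classes or compiler code. The sketch also shows that a constant non-negative integer exponent can be built from multiplications alone, which is why writing `output = x + y * y * y` inside the kernel should sidestep the missing operator.

```python
class ToyTensor:
    """Hypothetical stand-in for a tensor class with __add__/__mul__ but no __pow__."""

    def __init__(self, values):
        self.values = [float(v) for v in values]

    def _binary(self, other, op):
        # Broadcast a Python scalar against every element; otherwise pair elements.
        rhs = other.values if isinstance(other, ToyTensor) else [other] * len(self.values)
        return type(self)(op(a, b) for a, b in zip(self.values, rhs))

    def __add__(self, other):
        return self._binary(other, lambda a, b: a + b)

    def __mul__(self, other):
        return self._binary(other, lambda a, b: a * b)


class ToyTensorWithPow(ToyTensor):
    """Same toy tensor plus a __pow__ for non-negative integer exponents."""

    def __pow__(self, exponent):
        if not isinstance(exponent, int) or exponent < 0:
            raise NotImplementedError("sketch handles non-negative int exponents only")
        # Exponentiation by squaring, built only from the existing __mul__.
        result = type(self)([1.0] * len(self.values))
        base = self
        while exponent:
            if exponent & 1:
                result = result * base
            base = base * base
            exponent >>= 1
        return result


def apply_binary(method_name, lhs, rhs):
    # Assumed dispatch style: fetch the dunder method by name from the left operand.
    return getattr(lhs, method_name)(rhs)


x = ToyTensorWithPow([1.0, 2.0, 3.0])
y = ToyTensorWithPow([2.0, 2.0, 2.0])
print((x + y ** 3).values)  # [9.0, 10.0, 11.0]

try:
    apply_binary("__pow__", ToyTensor([2.0]), 3)
except AttributeError as exc:
    print(exc)  # 'ToyTensor' object has no attribute '__pow__'
```

The class without `__pow__` reproduces the shape of the reported error; the subclass that defines it makes `x + y ** 3` evaluate normally.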
The tutorial `triton/python/tutorials/07-extern-functions.py` contains an extern-function example; adapted to `pow`, it becomes:
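
    import torch
    import triton
    import triton.language as tl
    from triton.language.extra import libdevice

    @triton.jit
    def pow_kernel(
        x_ptr,
        y_ptr,
        n_elements,
        BLOCK_SIZE: tl.constexpr,
    ):
        # Each program instance handles one block of BLOCK_SIZE elements.
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        # Mask lanes past n_elements so the last block never reads or writes out of bounds.
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        # Elementwise x ** (x + 1) through libdevice's pow.
        x = libdevice.pow(x, x + 1)
        tl.store(y_ptr + offsets, x, mask=mask)

    x = torch.tensor([0.0, 1.1, 2.0, 3.0], dtype=torch.float32, device='cuda')
    y = torch.zeros_like(x)
    # A single program (grid of (1,)) with BLOCK_SIZE=4 covers all 4 elements.
    pow_kernel[(1,)](x, y, 4, BLOCK_SIZE=4)
    print(y)
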
Output:

    tensor([ 0.0000, 1.2216, 8.0000, 81.0000], device='cuda:0')
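As a cross-check that needs no GPU, the same elementwise formula x ** (x + 1) can be evaluated in plain Python on the kernel's input values. This reference runs in double precision rather than float32, so only the rounded values are expected to agree with the printed tensor.

```python
import math

# Same inputs as the kernel launch in the example.
inputs = [0.0, 1.1, 2.0, 3.0]

# CPU reference for libdevice.pow(x, x + 1), computed in double precision.
reference = [math.pow(v, v + 1) for v in inputs]

print([round(v, 4) for v in reference])  # [0.0, 1.2216, 8.0, 81.0]
```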