triton-lang / triton

Development repository for the Triton language and compiler

Home Page: https://triton-lang.org/


How to use transpose in PyTorch

ihaterecursion opened this issue


I want to replace PyTorch's transpose op with my own implementation written in Triton:
```python
import torch

# `transpose` is the Triton-based implementation being registered
# (its definition is not shown in the issue).

aten_lib = torch.library.Library("aten", "IMPL")

def enable(lib=aten_lib):
    lib.impl("transpose", transpose, "CUDA")

class use_triton:
    def __init__(self):
        self.lib = torch.library.Library("aten", "IMPL")

    def __enter__(self):
        enable(self.lib)

    def __exit__(self, exc_type, exc_val, exc_tb):
        del self.lib

def test():
    with use_triton():
        M = 1024
        N = 1024
        x = torch.rand((M, N), device='cuda', dtype=torch.float32)
        y = torch.zeros((N, M), device='cuda', dtype=torch.float32)
        print("hello")
        y = torch.transpose(x, 1, 0)
        print(y)

test()
```
But my Triton-based transpose did not work. How do I fix it?

Call `.contiguous()`.
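For background on that answer: `torch.transpose` returns a view that shares storage with its input and only swaps the strides, so no data is moved. Code that reads memory assuming a row-major (contiguous) layout will therefore see the original element order unless the tensor is first copied with `.contiguous()`. The minimal CPU sketch below illustrates only this view-versus-copy behavior in plain PyTorch; it does not involve Triton or the custom `torch.library` registration from the question.

```python
import torch

x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
y = torch.transpose(x, 1, 0)  # a view: same storage, swapped strides

print(x.stride(), y.stride())        # (3, 1) (1, 3)
print(y.is_contiguous())             # False
print(y.data_ptr() == x.data_ptr())  # True: no data was moved

z = y.contiguous()                   # copies into a new row-major buffer
print(z.stride(), z.is_contiguous()) # (2, 1) True
print(torch.equal(y, z))             # True: same values, new layout
```

Because `.contiguous()` allocates and copies, it is usually called once right before handing a tensor to code that requires a contiguous layout, rather than after every view operation.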