hidet-org / hidet

An open-source, efficient deep learning framework/compiler, written in Python.

Home Page: https://hidet.org


[Feature] No torch.sqrt support in Hidet?

xxzh12 opened this issue · comments


I'm trying to optimize a SelfAttention module, but there is no support for the torch.sqrt function. The code is as follows:

import torch
import hidet

hidet.option.cache_dir('./outs/cache')
# SelfAttention is my own module (definition omitted here).
model = SelfAttention(num_attention_heads=12, input_size=768, hidden_size=768,
                      attention_probs_dropout_prob=0.5, hidden_dropout_prob=0.5).cuda().eval()
x = torch.rand(1, 128, 768).cuda()
model_opt = torch.compile(model, backend='hidet')
y = model_opt(x)

where I use

x = (x - u) / torch.sqrt(s + self.variance_epsilon)

in the LayerNorm module.
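
For context, the module is essentially a standard LayerNorm; here is a minimal sketch of it (a simplified reconstruction, since I've trimmed the full definition):

import torch
import torch.nn as nn

class LayerNorm(nn.Module):
    # Simplified reconstruction of the LayerNorm described above.
    def __init__(self, hidden_size, variance_epsilon=1e-12):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = variance_epsilon

    def forward(self, x):
        u = x.mean(-1, keepdim=True)               # mean over the hidden dim
        s = (x - u).pow(2).mean(-1, keepdim=True)  # variance over the hidden dim
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)  # the line that needs torch.sqrt
        return self.weight * x + self.bias
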
The error message is as follows:

The following modules/functions are not supported by hidet yet: torch.sqrt

I'm wondering if there is any way to add support for the torch.sqrt function. I noticed that there is a relevant abstraction for sqrt in the IR. However, the sqrt function in hidet/python/hidet/ir/primitives/math.py just raises NotImplementedError().
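
Looking at how other torch functions are mapped, I'd guess the fix is to register a mapping in the torch frontend, roughly like this (module paths and names are my assumptions from skimming the source, untested):

import torch
import hidet
from hidet.graph.frontend.torch.registry import register_function  # assumed import path

@register_function(torch.sqrt)
def torch_sqrt(x):
    # Map torch.sqrt to hidet's elementwise sqrt operator.
    return hidet.ops.sqrt(x)
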

Hi @xxzh12,

#387 adds the operator mapping for torch.sqrt. I do not have the definition of SelfAttention, so I did not test it on your use case. Feel free to open another issue if other operators are not mapped. Thanks.
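
Once you pick up a build that includes the mapping, a quick sanity check on your script would be (the model is in eval mode, so dropout is inactive):

# Compare eager PyTorch output against the hidet-compiled output.
y_ref = model(x)        # eager PyTorch
y_opt = model_opt(x)    # hidet backend
torch.testing.assert_close(y_opt, y_ref, rtol=1e-3, atol=1e-3)
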

commented

Hi @yaoyaoding,

Thanks for your kind reply! I will give it a try.