[Feature] No torch.sqrt support in Hidet?
xxzh12 opened this issue
I'm trying to optimize a `SelfAttention` module with Hidet, but `torch.sqrt` is not supported. The code is as follows:
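```python
import torch
import hidet  # importing hidet registers the 'hidet' backend for torch.compile

# SelfAttention is my own module (definition not shown); it contains the LayerNorm below
hidet.option.cache_dir('./outs/cache')
model = SelfAttention(num_attention_heads=12, input_size=768, hidden_size=768, attention_probs_dropout_prob=0.5, hidden_dropout_prob=0.5).cuda().eval()
x = torch.rand(1, 128, 768).cuda()
# print(model)
model_opt = torch.compile(model, backend='hidet')
y = model_opt(x)
```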
where my `LayerNorm` module uses:

```python
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
```
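For reference, the line above is the usual BERT-style layer normalization. The framework-free sketch below (plain Python lists stand in for tensors; this is not Hidet or PyTorch code, and `eps=1e-12` is just an example value) computes it two ways: dividing by a square root, and multiplying by `(s + eps) ** -0.5`. The second form is the identity a workaround would rely on if `torch.sqrt(v)` were rewritten as `v ** 0.5` or replaced with `torch.rsqrt`; whether Hidet's frontend accepts those ops is an assumption not confirmed in this thread.

```python
import math


def layer_norm_sqrt(x, eps=1e-12):
    """Normalize one row as in the issue: (x - u) / sqrt(s + eps)."""
    u = sum(x) / len(x)                        # mean over the last dimension
    s = sum((v - u) ** 2 for v in x) / len(x)  # biased variance, like (x - u).pow(2).mean(-1)
    return [(v - u) / math.sqrt(s + eps) for v in x]


def layer_norm_pow(x, eps=1e-12):
    """Same result, multiplying by the reciprocal square root instead of dividing by sqrt."""
    u = sum(x) / len(x)
    s = sum((v - u) ** 2 for v in x) / len(x)
    inv_std = (s + eps) ** -0.5                # the value torch.rsqrt(s + eps) would produce
    return [(v - u) * inv_std for v in x]


row = [1.0, 2.0, 3.0, 4.0]
print(layer_norm_sqrt(row))
print(layer_norm_pow(row))
```

Both functions agree up to floating-point rounding, so swapping the op changes nothing numerically; it only matters for which operators the compiler backend has to recognize.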
The error message is:

```
The following modules/functions are not supported by hidet yet: torch.sqrt
```
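The wording of this error suggests that the compiler frontend checks every traced PyTorch call against a table of functions it knows how to translate, and reports the ones it cannot. The sketch below illustrates that general pattern only; `SUPPORTED`, `register`, and `check_graph` are hypothetical names, not Hidet's actual API, and where Hidet keeps its real registry is not shown in this thread.

```python
# Hypothetical registry mapping a PyTorch function name to a converter.
SUPPORTED = {}


def register(name):
    """Decorator that records a converter for the given function name."""
    def wrap(fn):
        SUPPORTED[name] = fn
        return fn
    return wrap


@register("torch.add")
def convert_add(a, b):
    return ("add", a, b)


def check_graph(called_functions):
    """Raise if any traced call has no registered converter, like the error in the issue."""
    missing = sorted({f for f in called_functions if f not in SUPPORTED})
    if missing:
        raise NotImplementedError(
            "The following modules/functions are not supported by hidet yet: "
            + ", ".join(missing)
        )


# Before a converter exists, the check reports torch.sqrt as unsupported.
try:
    check_graph(["torch.add", "torch.sqrt"])
except NotImplementedError as err:
    print(err)


# Registering a converter for torch.sqrt is what makes the same check pass.
@register("torch.sqrt")
def convert_sqrt(x):
    return ("sqrt", x)


check_graph(["torch.add", "torch.sqrt"])
```

Under this reading, the fix belongs in the frontend's function table rather than in user code.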
I'm wondering whether there is any way to support the `torch.sqrt` function. I noticed that there is a relevant abstraction in the IR for the `sqrt` function. However, the `sqrt` function in `hidet\python\hidet\ir\primitives\math.py` just contains `raise NotImplementedError()`.
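One plausible reading, not confirmed from the Hidet source here: a `raise NotImplementedError()` body in a shared math module often marks an abstract method on a base class, while concrete per-dtype subclasses provide the real implementation. The minimal sketch below shows that pattern; the class names and the dispatch table are hypothetical, and `math.sqrt` stands in for a device intrinsic.

```python
import math


class BaseMath:
    """Hypothetical base class: declares the interface but implements nothing."""

    def sqrt(self, a):
        raise NotImplementedError()


class Float32Math(BaseMath):
    """Hypothetical concrete set; a real backend would emit a device intrinsic here."""

    def sqrt(self, a):
        return math.sqrt(a)


# Hypothetical dispatch table from dtype name to its function set.
_registry = {"float32": Float32Math()}


def sqrt(a, dtype="float32"):
    """Use the dtype's function set, falling back to the abstract base if none is registered."""
    return _registry.get(dtype, BaseMath()).sqrt(a)


print(sqrt(9.0))  # dispatches to Float32Math.sqrt
```

If the real module follows this pattern, the `NotImplementedError` only fires for a dtype with no registered implementation, and the missing piece for this issue is the PyTorch-frontend mapping rather than the IR primitive itself.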
Hi @yaoyaoding,
Thanks for your kind reply! I will have a try.