Support in-place operators in TorchBlade
Yancey1989 opened this issue · comments
Yan Xu commented
PyTorch in-place operators are widely used in model training and inference, e.g. the KV cache in LLMs, so TorchBlade should support these
in-place operators.
Consider the following simple PyTorch program with the in-place add operator aten.add_:
def func(add: Tensor, value: Tensor) -> Tensor:
    add.add_(value)
    return add
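The in-place call mutates the caller's buffer, and that aliasing is exactly what TorchBlade must preserve. As a minimal sketch of the two semantics (using plain Python lists in place of tensors, purely for illustration):

```python
def add_(buf, value):
    # In-place: write results back into the caller's buffer (aliases it).
    for i in range(len(buf)):
        buf[i] += value[i]
    return buf

def add(buf, value):
    # Out-of-place (value-semantic): allocate a fresh buffer for the result.
    return [a + b for a, b in zip(buf, value)]

x = [1.0, 2.0, 3.0]
y = [10.0, 10.0, 10.0]

z = add(x, y)   # x is untouched; z is a new buffer
add_(x, y)      # x itself is mutated
```

After this runs, `z` and `x` hold equal values but are distinct buffers; only the in-place version changed what the caller sees through `x`.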
TorchDynamo can trace the above program into value-semantic operators and insert aten.copy_
at the end:
add: f32[8, 64] = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg1_1 = None
# No stacktrace found for the following nodes
copy_: f32[8, 64] = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = None
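This functionalization step can be sketched as a pass that rewrites each in-place node into its value-semantic form and appends a copy_ writing the fresh result back into the mutated argument. The toy IR below (tuples of op name, inputs, output) is an assumption for illustration, not TorchDynamo's actual graph representation:

```python
def functionalize(graph):
    """Rewrite each in-place op (name ending in '_') into its
    value-semantic form, then append a copy_ that writes the fresh
    result back into the mutated first argument."""
    value_ops, copy_backs = [], []
    fresh_id = 0
    for op, args in graph:
        if op.endswith("_") and op != "copy_":
            fresh = f"tmp{fresh_id}"  # fresh SSA value for the result
            fresh_id += 1
            value_ops.append((op.rstrip("_"), args, fresh))
            # copy_(dst, src): mirror aten.copy_.default(arg0_1, add)
            copy_backs.append(("copy_", (args[0], fresh), args[0]))
        else:
            value_ops.append((op, args, None))
    return value_ops + copy_backs

# add_(arg0, arg1)  →  tmp0 = add(arg0, arg1); copy_(arg0, tmp0)
traced = functionalize([("add_", ("arg0", "arg1"))])
```

The trailing copy_ is what restores the original mutation semantics after every op has been made value-semantic.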
The naive lowering at the buffer level can be represented as:
lmhlo.add %arg0, %arg1, %1
lmhlo.copy %1, %arg0
but in practice the lmhlo.copy
incurs overhead because of the required device-to-device (d2d) memcpy; we can remove this copy and store the result directly into the input buffer:
lmhlo.add %arg0, %arg1, %arg0
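This rewrite — folding the op-plus-copy pair so the op writes straight into the destination buffer — can be sketched as a peephole pass over a toy buffer-level IR. The tuple encoding (op, input buffers, output buffer) and the single-use check are assumptions for illustration, not TorchBlade's actual pass:

```python
def elide_copy_back(ops):
    """Fold  op(..., out=tmp); copy(tmp -> dst)  into  op(..., out=dst)
    when tmp has no other uses, removing the d2d memcpy."""
    result = list(ops)
    i = 0
    while i + 1 < len(result):
        op, ins, out = result[i]
        nop, nins, nout = result[i + 1]
        # Count how many other ops read our output buffer.
        uses = sum(out in o[1] for o in result if o is not result[i])
        if nop == "copy" and nins == (out,) and uses == 1:
            # Safe to alias: write the result directly into dst.
            result[i] = (op, ins, nout)
            del result[i + 1]
        else:
            i += 1
    return result

# lmhlo.add %arg0, %arg1, %1 ; lmhlo.copy %1, %arg0
ir = [("add", ("arg0", "arg1"), "buf1"),
      ("copy", ("buf1",), "arg0")]
optimized = elide_copy_back(ir)  # lmhlo.add %arg0, %arg1, %arg0
```

The single-use check matters: if the temporary result is read anywhere else, redirecting the store into the input buffer would clobber a value that is still live.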