pytorch / torchtitan

A native PyTorch Library for large model training

Make fused RMSNorm a registered op

lessw2020 opened this issue

Adding this as a tracking issue to unblock #181 from landing.
Per @wanchaol:
IMO, we should also register the fwd/bwd RMSNorm kernels as PyTorch ops, so that:

- Making it a custom op makes it compatible with PT2. I believe the FusedRMSNorm path currently causes a graph break when torch.compile is turned on.
- It allows other components (e.g., DTensor) to provide a sharding rule for this custom op, so that it is compatible with tensor parallelism.

Update: hit IMA (illegal memory access) issues with both my implementation (#296) and @wconstab's (#303). Working on debugging this with @lessw2020.

Closing this, as fused RMSNorm is now supported with tensor parallelism (#404).