Fused RMSNorm incompatible with PP tracing (dynamic stride)
wconstab opened this issue
The incompatibility is that, during the backward pass, fused_rmsnorm performs dynamic control flow over strides, which isn't safe for the export tracing used by PP.
dy = dy.view(-1, dy.shape[-1])
if dy.stride(-1) != 1:
    dy = dy.contiguous()
Which leads to a stacktrace ending in
File "/data/users/whc/pytorch/torch/_dynamo/variables/tensor.py", line 326, in var_getattr
unimplemented(f"Illegal getattr invocation {name} in strict mode")
File "/data/users/whc/pytorch/torch/_dynamo/exc.py", line 204, in unimplemented
raise Unsupported(msg)
torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode
Would it be possible to refactor this in a more export friendly way, or is that difficult?
cc @lessw2020, @kwen2501
Short term, the stride check can be removed to explore tracing (the check is rarely needed; confirmed on llama_7b).
Longer term, this will need either a refactor to support dynamic strides (harder) or, given how rarely the case occurs, just a simple assert stating that non-contiguous inputs are not supported.
I did not look into this closely, but could we rely on .contiguous() being a no-op if already contiguous and remove the stride check? (There might be ever-so-slightly more CPU overhead if there is a Python <> C++ switch from .contiguous(), but I think this should be okay for our purpose.)
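A minimal sketch of that suggestion, for illustration only. The helper name `flatten_grad` is hypothetical and is not part of the fused_rmsnorm code; it only mirrors the two lines quoted above with the stride branch dropped, and it has not been checked against the PP tracer.

```python
import torch


def flatten_grad(dy: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: flatten leading dims and make the result contiguous.

    Mirrors the snippet from the issue, but calls .contiguous() unconditionally
    instead of branching on dy.stride(-1) in Python.
    """
    dy = dy.view(-1, dy.shape[-1])
    # For an already-contiguous tensor this does not copy the data.
    return dy.contiguous()


# Already contiguous input: no copy, the storage pointer is unchanged.
a = torch.randn(2, 3, 4)
assert flatten_grad(a).data_ptr() == a.data_ptr()

# Transposed input: last-dim stride is 4, so .contiguous() makes a copy.
b = torch.arange(12.0).reshape(3, 4).t()
out_b = flatten_grad(b)
assert b.stride(-1) != 1 and out_b.stride(-1) == 1
assert out_b.data_ptr() != b.data_ptr()
```

One behavioral difference to keep in mind: the original check copies only when the last-dimension stride is not 1, while an unconditional .contiguous() also copies inputs such as a column slice (for example `x[:, :2]`), where the last-dimension stride is 1 but the rows are not densely packed.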