`AssertionError` when combining `BREVITAS_JIT=1` and `torch.compile` under PyTorch `v2.0.1`
nickfraser opened this issue
While working on #785, I noticed that it appears to be impossible to combine a `torch.ScriptMethod` with `torch.compile` under PyTorch `v2.0.1`. An `AssertionError` occurs during tracing, at the point where the type of a module's `forward` method is checked: if `forward` is a `torch.ScriptMethod`, the assertion fails. Technically, this is an issue with PyTorch `v2.0.x`, and it has already been fixed in PyTorch `v2.1.1`, where the module's `_call_impl` is checked instead. `_call_impl` is a `types.MethodType`, not a `torch.ScriptMethod`.
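The type difference behind the failing check can be observed directly. The sketch below uses a hypothetical toy module, `AddOne`, scripts it with `torch.jit.script`, and inspects the resulting attribute types. It only illustrates why a check on `forward` and a check on `_call_impl` behave differently; it does not reproduce the TorchDynamo assertion itself.

```python
import types

import torch


class AddOne(torch.nn.Module):
    """Hypothetical toy module, used only to inspect method types."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1


eager = AddOne()
scripted = torch.jit.script(AddOne())

# In eager mode, forward is an ordinary bound Python method.
eager_forward_is_method = isinstance(eager.forward, types.MethodType)

# After scripting, forward is a compiled TorchScript method, not a Python method.
scripted_forward_is_script_method = isinstance(scripted.forward, torch._C.ScriptMethod)
scripted_forward_is_method = isinstance(scripted.forward, types.MethodType)

# _call_impl is inherited from nn.Module, so it remains a bound Python method.
scripted_call_impl_is_method = isinstance(scripted._call_impl, types.MethodType)

print(eager_forward_is_method, scripted_forward_is_script_method)
print(scripted_forward_is_method, scripted_call_impl_is_method)
```

A check of the form "is `forward` a `types.MethodType`?" therefore rejects scripted modules, while the same check applied to `_call_impl` passes for both eager and scripted modules.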
Currently, the only workaround is for users not to combine `BREVITAS_JIT=1` and `torch.compile` under PyTorch `v2.0.1`, and to use only one or the other.
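One way to apply this workaround in user code is a guard that detects the conflicting combination before calling `torch.compile`. The sketch below is pure Python and entirely hypothetical: `parse_torch_version` and `jit_compile_conflict` are illustrative helpers, not part of Brevitas or PyTorch. It assumes that JIT mode is enabled exactly when the environment variable `BREVITAS_JIT` equals `"1"`. It also conservatively treats every release from `2.0.0` up to, but not including, `2.1.1` as affected, since the issue only reports the fix in `2.1.1`.

```python
import os
import re


def parse_torch_version(version: str) -> tuple:
    """Extract (major, minor, patch) from strings such as '2.0.1+cu117'."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return tuple(int(part) for part in match.groups())


def jit_compile_conflict(torch_version: str, env=None) -> bool:
    """Return True if BREVITAS_JIT=1 would be combined with an affected PyTorch.

    Assumption: only the exact value "1" enables Brevitas JIT mode here.
    """
    env = os.environ if env is None else env
    jit_enabled = env.get("BREVITAS_JIT", "0") == "1"
    # Conservative assumption: every release from 2.0.0 up to (but not
    # including) 2.1.1 is treated as affected.
    affected = (2, 0, 0) <= parse_torch_version(torch_version) < (2, 1, 1)
    return jit_enabled and affected
```

In practice, a caller would evaluate `jit_compile_conflict(torch.__version__)` and skip `torch.compile` when it returns `True`, running the Brevitas model with JIT alone rather than both together.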
This issue seems to be separate from the other PyTorch issue related to TorchScript and TorchDynamo.