Xilinx / brevitas

Brevitas: neural network quantization in PyTorch

Home Page: https://xilinx.github.io/brevitas/

`AssertionError` when combining `BREVITAS_JIT=1` and `torch.compile` under PyTorch `v2.0.1`

nickfraser opened this issue.

While working on #785, I noticed that it appears to be impossible to combine a `torch.ScriptMethod` with `torch.compile` under PyTorch v2.0.1. An `AssertionError` is raised during tracing, at the point where the type of a module's `forward` method is checked: if `forward` is a `torch.ScriptMethod`, the assertion fails. Strictly speaking, this is a bug in PyTorch v2.0.x. It has already been fixed in PyTorch v2.1.1, which checks the module's `_call_impl` instead; `_call_impl` is a `types.MethodType`, not a `torch.ScriptMethod`, so the check no longer fails.

Currently, the only workaround is to advise users not to combine `BREVITAS_JIT=1` and `torch.compile` under PyTorch v2.0.1, and instead to use only one or the other.

This issue seems to be separate from the other PyTorch issue related to TorchScript and TorchDynamo.