Xilinx / brevitas

Brevitas: neural network quantization in PyTorch

Home Page: https://xilinx.github.io/brevitas/

Importing quantized models after bias correction

i-colbert opened this issue

When quantizing floating-point models whose layers have no bias (e.g., nn.Linear(in_features=10, out_features=2, bias=False)), bias correction currently adds a bias to the layer. The new bias is then exported with the state dictionary. However, when loading this modified state dictionary into a new instance of the original model, PyTorch raises a missing keys error, because neither the floating-point model nor the quantized model has a bias until bias correction has been run.
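
For illustration, a minimal standalone sketch of the failure mode; here the effect of bias correction is simulated by manually registering a bias parameter on the bias-free layer:

import torch
import torch.nn as nn

layer = nn.Linear(in_features=10, out_features=2, bias=False)
# Simulate bias correction adding a bias to the bias-free layer
layer.bias = nn.Parameter(torch.zeros(2))
state_dict = layer.state_dict()  # now contains a 'bias' key

# Loading into a fresh bias-free instance fails on the 'bias' key
fresh = nn.Linear(in_features=10, out_features=2, bias=False)
fresh.load_state_dict(state_dict)  # RuntimeError: key mismatch on 'bias'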

The issue can be resolved by running bias correction before loading the modified state dictionary (see below), but a more flexible solution may be to add support for this into the state dictionary loading mechanism itself.

import torch
import torch.nn as nn

from brevitas.graph.calibrate import bias_correction_mode


def _prepare_bias_corrected_quant_model(model: nn.Module):
    model.eval()
    dtype = next(model.parameters()).dtype
    device = next(model.parameters()).device
    # Dummy calibration batch matching the model's expected input shape
    images = torch.randn(10, 3, 32, 32)
    images = images.to(device)
    images = images.to(dtype)
    # A forward pass under bias correction registers the new bias
    # parameters, so the modified state dictionary can be loaded afterwards
    with torch.no_grad():
        with bias_correction_mode(model):
            model(images)
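
The resulting load order then looks like the following sketch (QuantCNN and the checkpoint path are placeholders):

quant_model = QuantCNN()  # fresh quantized model, no bias parameters yet

# Running bias correction first registers the bias parameters...
_prepare_bias_corrected_quant_model(quant_model)

# ...so the bias-corrected state dictionary now loads without key errors.
state_dict = torch.load('quant_model_bias_corrected.pth')
quant_model.load_state_dict(state_dict)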