Importing quantized models after bias correction
i-colbert opened this issue
When quantizing a floating-point model whose layers have no bias (e.g., `nn.Linear(in_features=10, out_features=2, bias=False)`), bias correction currently adds a bias to the layer. As a result, the new bias is exported with the state dictionary. However, when that modified state dictionary is loaded into a fresh instance of the original model, PyTorch raises a missing-keys error, because neither the floating-point model nor the quantized model has a bias parameter until bias correction has been run.
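For illustration, here is a minimal plain-PyTorch sketch of the same key mismatch; the zero bias is just a stand-in for the one bias correction injects:

```python
import torch
import torch.nn as nn

# Export a state dict that, like the bias-corrected one, carries an
# extra bias entry the original architecture never declared.
trained = nn.Linear(10, 2, bias=False)
state_dict = dict(trained.state_dict())
state_dict["bias"] = torch.zeros(2)  # stand-in for the injected bias

# Loading into a fresh bias-less instance fails with a key mismatch
# (here reported as an unexpected "bias" key).
fresh = nn.Linear(10, 2, bias=False)
fresh.load_state_dict(state_dict)
```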
The issue can be worked around by running bias correction before loading the modified state dictionary (see below), but a more flexible solution might be to add support for this case to the state-dictionary loading mechanism itself.
```python
import torch
import torch.nn as nn

from brevitas.graph.calibrate import bias_correction_mode


def _prepare_bias_corrected_quant_model(model: nn.Module):
    model.eval()
    dtype = next(model.parameters()).dtype
    device = next(model.parameters()).device
    # Dummy calibration batch (CIFAR-10-sized images in this example)
    images = torch.randn(10, 3, 32, 32)
    images = images.to(device)
    images = images.to(dtype)
    # Running bias correction materializes the bias parameters,
    # so the modified state dictionary can be loaded afterwards.
    with torch.no_grad():
        with bias_correction_mode(model):
            model(images)
```
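With that helper in place, a round trip might look roughly like the sketch below; `make_quant_model` and the checkpoint path are hypothetical placeholders, not real API:

```python
import torch

quant_model = make_quant_model()  # placeholder for your quantized model constructor
# Run bias correction first so the injected bias parameters exist...
_prepare_bias_corrected_quant_model(quant_model)
# ...then loading the bias-corrected state dictionary no longer mismatches keys.
state_dict = torch.load("bias_corrected_checkpoint.pth")  # placeholder path
quant_model.load_state_dict(state_dict)
```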