daquexian / onnx-simplifier

Simplify your onnx model

Support conv-bn fold with QDQ node inserted and bn-conv fold

tp-nan opened this issue

commented

Is it possible to support conv-BN folding when QDQ nodes are inserted, and also BN-conv folding (BN before the conv)?

import torch.nn as nn
from pytorch_quantization import quant_modules
from pytorch_quantization import nn as quant_nn

# Replace torch.nn layers (e.g. nn.Conv2d -> QuantConv2d) with quantized versions;
# this must run before the model is constructed.
quant_modules.initialize()

class PreBN(nn.Module):
    def __init__(self, num_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(num_channels)

        self.relu1 = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=True)

        # self.residual_quantizer = quant_nn.TensorQuantizer(quant_nn.QuantConv2d.default_quant_desc_input)

    def forward(self, x):
        out = self.bn1(self.conv1(x))                # conv -> BN (conv-bn fold candidate)
        out = self.relu1(x + out)                    # residual add
        out = self.relu2(self.conv2(self.bn2(out)))  # BN -> conv (bn-conv fold candidate)
        return out
[Screenshot (2023-07-03): the exported ONNX graph of the model above]
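The second pattern in PreBN (bn2 feeding conv2) is the less common "bn-conv" fold. A minimal pure-Python sketch of the arithmetic, assuming an eval-mode BN and a 1x1 convolution evaluated at a single pixel, so the conv reduces to a matrix-vector product. All helper names are hypothetical; this is not what any exporter or onnx-simplifier does internally. Because BN acts on the conv's input channels here, its scale multiplies weight columns and its shift is pushed through the weights into a new bias.

```python
import math


def bn_affine(gamma, beta, mean, var, eps=1e-5):
    """Eval-mode BN as a per-channel affine map: BN(x) = scale * x + shift."""
    scale = [g / math.sqrt(v + eps) for g, v in zip(gamma, var)]
    shift = [b - m * s for b, m, s in zip(beta, mean, scale)]
    return scale, shift


def batchnorm(x, gamma, beta, mean, var, eps=1e-5):
    """Reference eval-mode BN on one pixel (one value per channel)."""
    return [g * (xi - m) / math.sqrt(v + eps) + b
            for xi, g, b, m, v in zip(x, gamma, beta, mean, var)]


def conv1x1(weight, bias, x):
    """1x1 conv at one pixel: weight is [out_channels][in_channels]."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def fold_bn_into_next_conv1x1(weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Hypothetical helper: return (W', b') with conv(W', b', x) == conv(W, b, BN(x))."""
    scale, shift = bn_affine(gamma, beta, mean, var, eps)
    # BN scale multiplies each input-channel column of the weight
    new_weight = [[w * s for w, s in zip(row, scale)] for row in weight]
    # BN shift passes through the weights and lands in the bias
    new_bias = [b + sum(w * t for w, t in zip(row, shift)) for row, b in zip(weight, bias)]
    return new_weight, new_bias


gamma, beta = [1.5, 0.5], [0.1, -0.2]
mean, var = [0.3, -1.0], [2.0, 0.5]
W = [[0.2, -0.4], [1.0, 0.7]]
b = [0.0, 0.0]  # conv2 has bias=False; folding introduces a bias
x = [0.9, -0.3]

W_fold, b_fold = fold_bn_into_next_conv1x1(W, b, gamma, beta, mean, var)
reference = conv1x1(W, b, batchnorm(x, gamma, beta, mean, var))
folded = conv1x1(W_fold, b_fold, x)
```

One caveat for conv2 in the question (3x3, padding=1): its zero padding is applied after bn2, so in a folded conv the padded border only matches exactly when the BN shift is zero. The unpadded 1x1 case above sidesteps that.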
commented

Hi @ShiyangZhang, if you export a PyTorch model to ONNX in TrainingMode.EVAL mode (the default), the exporter fuses Conv+BN:
torch.onnx.export(..., training=torch.onnx.TrainingMode.EVAL, ...)
See also the related discussion in pytorch/pytorch#49226 (comment).
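For reference, the fold done at export time is the standard one: a BN that follows a conv is a per-output-channel affine map, so it scales each output channel's filter and moves the mean and shift into the bias. A minimal pure-Python sketch with hypothetical helper names; it shows only the arithmetic, not torch.onnx's actual pass. Each weight row can be read as a flattened in_channels*kh*kw filter, so the same folding holds for 3x3 kernels and padding does not matter; the 1x1 reference conv is just for checking.

```python
import math


def batchnorm(y, gamma, beta, mean, var, eps=1e-5):
    """Reference eval-mode BN on one pixel (one value per channel)."""
    return [g * (yi - m) / math.sqrt(v + eps) + b
            for yi, g, b, m, v in zip(y, gamma, beta, mean, var)]


def conv1x1(weight, bias, x):
    """1x1 conv at one pixel: weight is [out_channels][in_channels]."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def fold_bn_into_prev_conv(weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Hypothetical helper: return (W', b') with conv(W', b', x) == BN(conv(W, b, x))."""
    new_weight, new_bias = [], []
    for row, b, g, bt, m, v in zip(weight, bias, gamma, beta, mean, var):
        s = g / math.sqrt(v + eps)               # one multiplier per output channel
        new_weight.append([w * s for w in row])  # scale that channel's whole filter
        new_bias.append((b - m) * s + bt)        # fold mean and shift into the bias
    return new_weight, new_bias


W = [[0.2, -0.4, 0.1], [1.0, 0.7, -0.3]]
b = [0.0, 0.0]  # conv1 in the question has bias=False
gamma, beta = [1.5, 0.5], [0.1, -0.2]
mean, var = [0.3, -1.0], [2.0, 0.5]
x = [0.9, -0.3, 2.0]

W_fold, b_fold = fold_bn_into_prev_conv(W, b, gamma, beta, mean, var)
reference = batchnorm(conv1x1(W, b, x), gamma, beta, mean, var)
folded = conv1x1(W_fold, b_fold, x)
```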

Was the above ONNX model generated by pytorch_quantization?

commented

Was the above ONNX model generated by pytorch_quantization?

The above ONNX model was generated by torch.onnx.export.

We changed the PyTorch model's definition.
pytorch_quantization replaces torch.nn.Conv2d with a quantized version, QuantConv2d, which inserts QuantizeLinear and DequantizeLinear nodes (exported from torch.fake_quantize_per_tensor_affine, or its per-channel variant for the weight). As a result, both the Conv weight and its input are quantized.
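The QDQ pair that ends up in front of the Conv input (and weight) is just quantize-then-dequantize. A minimal pure-Python sketch, assuming int8 with range [-128, 127] and a single per-tensor scale; it follows the ONNX QuantizeLinear/DequantizeLinear formulas, not pytorch_quantization's own code, whose exact range and rounding settings may differ.

```python
def quantize_linear(x, scale, zero_point=0, qmin=-128, qmax=127):
    """Like ONNX QuantizeLinear with int8 output: saturate(round(x / scale) + zero_point).

    Python's round() rounds half to even, which matches the ONNX spec.
    """
    return [min(qmax, max(qmin, round(v / scale) + zero_point)) for v in x]


def dequantize_linear(q, scale, zero_point=0):
    """Like ONNX DequantizeLinear: (q - zero_point) * scale."""
    return [(v - zero_point) * scale for v in q]


def fake_quantize(x, scale, zero_point=0):
    """Q followed by DQ: output stays float but is snapped to the int8 grid."""
    return dequantize_linear(quantize_linear(x, scale, zero_point), scale, zero_point)


codes = quantize_linear([0.26, -1.0, 20.0], scale=0.1)   # 20.0 saturates at 127
approx = fake_quantize([0.26, -1.0, 20.0], scale=0.1)
```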

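import torch.nn.functional as F
from torch.nn.modules.utils import _pair
from pytorch_quantization import tensor_quant
from pytorch_quantization.nn.modules import _utils
from pytorch_quantization.nn.modules.quant_conv import _QuantConvNd

class QuantConv2d(_QuantConvNd):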
    """Quantized 2D conv"""

    default_quant_desc_weight = tensor_quant.QUANT_DESC_8BIT_CONV2D_WEIGHT_PER_CHANNEL

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 padding_mode='zeros',
                 **kwargs):

        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)

        quant_desc_input, quant_desc_weight = _utils.pop_quant_desc_in_kwargs(self.__class__, **kwargs)
        super(QuantConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, False,
                                          _pair(0), groups, bias, padding_mode,
                                          quant_desc_input=quant_desc_input, quant_desc_weight=quant_desc_weight)

    def forward(self, input):
        # the actual quantization happens in the next level of the class hierarchy
        quant_input, quant_weight = self._quant(input)

        if self.padding_mode == 'circular':
            expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2,
                                (self.padding[0] + 1) // 2, self.padding[0] // 2)
            output = F.conv2d(F.pad(quant_input, expanded_padding, mode='circular'),
                              quant_weight, self.bias, self.stride,
                              _pair(0), self.dilation, self.groups)
        else:
            output = F.conv2d(quant_input, quant_weight, self.bias, self.stride, self.padding, self.dilation,
                              self.groups)

        return output
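default_quant_desc_weight is QUANT_DESC_8BIT_CONV2D_WEIGHT_PER_CHANNEL, so the weight gets one scale per output channel instead of one per tensor. A sketch of that idea, assuming symmetric quantization over [-127, 127] and scale = amax / 127 with amax the channel's max |w|; these calibration details are assumptions, not read from pytorch_quantization's source. Filters are flattened to one row per output channel, and the helper names are hypothetical.

```python
def per_channel_scales(weight, num_bits=8):
    """One symmetric scale per output channel (axis 0), assuming max calibration."""
    bound = 2 ** (num_bits - 1) - 1          # 127 for 8 bits
    scales = []
    for row in weight:
        amax = max(abs(w) for w in row)
        scales.append(amax / bound if amax > 0 else 1.0)  # avoid dividing by zero
    return scales


def fake_quantize_weight_per_channel(weight, num_bits=8):
    """Quantize-dequantize each output channel's flattened filter with its own scale."""
    bound = 2 ** (num_bits - 1) - 1
    scales = per_channel_scales(weight, num_bits)
    out = [[max(-bound, min(bound, round(w / s))) * s for w in row]
           for row, s in zip(weight, scales)]
    return out, scales


# Two output channels with very different magnitudes keep separate resolutions
W = [[0.5, -1.27, 0.33], [0.02, 0.011, -0.015]]
W_fq, scales = fake_quantize_weight_per_channel(W)
```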
commented

I see. It seems the torch.onnx Conv+BN fusion pass inspects the Conv's inputs and folds the BN into the Conv only if the Conv weight is a constant. The expected pattern is:

             conv_weight
        \     /
        [ Conv ]
           |
        conv_out
           |
         [ BN ]
           |
   conv_out_after_bn
           |
          ...

Once QDQ nodes are inserted, the Conv weight becomes the output tensor of a DequantizeLinear node, which no longer matches this pattern.
Conv+BN fusion rewrites the weight using the BN parameters. Once QDQ is added, however, the weight depends on the Q/DQ nodes, so it makes sense not to fuse: folding BN into a weight that feeds a Q/DQ pair would change the values that pair's scale was calibrated for.
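The matching rule described above can be sketched on a toy graph. Nodes are plain dicts and initializers is a set of constant tensor names; this is a hypothetical representation, not the onnx protobuf API and not torch.onnx's actual fusion pass. It only shows why a DequantizeLinear-produced weight fails the "constant weight" check.

```python
def conv_bn_foldable(nodes, initializers):
    """Return Conv nodes matching Conv -> BatchNormalization with a constant
    (initializer) weight and no other consumer of the Conv output."""
    producer = {out: n for n in nodes for out in n["outputs"]}
    consumers = {}
    for n in nodes:
        for name in n["inputs"]:
            consumers[name] = consumers.get(name, 0) + 1

    foldable = []
    for bn in nodes:
        if bn["op_type"] != "BatchNormalization":
            continue
        conv = producer.get(bn["inputs"][0])
        if conv is None or conv["op_type"] != "Conv":
            continue
        weight_is_constant = conv["inputs"][1] in initializers
        single_consumer = consumers.get(conv["outputs"][0], 0) == 1
        if weight_is_constant and single_consumer:
            foldable.append(conv["name"])
    return foldable


def node(name, op_type, inputs, outputs):
    return {"name": name, "op_type": op_type, "inputs": inputs, "outputs": outputs}


bn = node("bn", "BatchNormalization", ["conv_out", "gamma", "beta", "mean", "var"], ["y"])

# Plain export: the Conv weight is an initializer
plain = [node("conv", "Conv", ["x", "conv_weight"], ["conv_out"]), bn]

# With QDQ: the Conv weight is produced by DequantizeLinear
qdq = [
    node("q_w", "QuantizeLinear", ["conv_weight", "w_scale", "w_zp"], ["w_q"]),
    node("dq_w", "DequantizeLinear", ["w_q", "w_scale", "w_zp"], ["w_dq"]),
    node("conv", "Conv", ["x", "w_dq"], ["conv_out"]),
    bn,
]
```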

To fuse BN + Conv, how about exporting the PyTorch model to ONNX first (the Conv+BN fusion happens at this stage) and then quantizing it with the ONNX quantization tools?

commented

To fuse BN + Conv, how about exporting the PyTorch model to ONNX first (the Conv+BN fusion happens at this stage) and then quantizing it with the ONNX quantization tools?

Thanks for your reply! That's good advice for Post-Training Quantization (PTQ). In the context of Quantization Aware Training (QAT), can the Batch Normalization scale factor be merged into the scale factor of the quantize and dequantize layers for the conv weight during inference?
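One way to reason about this question numerically, under explicit assumptions: per-output-channel symmetric weight quantization with max calibration, BN after the conv, and a positive BN multiplier gamma / sqrt(var + eps). Folding that multiplier into the weight then scales each channel's amax, and therefore its Q/DQ scale, by the same factor, so the integer weight codes stay the same; the BN mean and shift would go into a (typically unquantized) bias. This is a sketch with hypothetical helpers, not a statement about what pytorch_quantization, ONNX tools, or TensorRT actually do, and it does not cover per-tensor weight scales, where different per-channel multipliers cannot share one scale.

```python
import math


def channel_scale(row, bound=127):
    """Symmetric per-channel scale from max |w| (assumed max calibration)."""
    return max(abs(w) for w in row) / bound


def codes(row, scale, bound=127):
    """Integer codes a symmetric int8 quantizer would produce."""
    return [max(-bound, min(bound, round(w / scale))) for w in row]


def fold_bn_scale(weight, gamma, var, eps=1e-5):
    """Multiply each output channel's filter by the BN multiplier gamma / sqrt(var + eps)."""
    mults = [g / math.sqrt(v + eps) for g, v in zip(gamma, var)]
    if any(m <= 0 for m in mults):
        raise ValueError("this sketch assumes positive BN multipliers")
    return [[w * m for w in row] for row, m in zip(weight, mults)], mults


W = [[0.5, -1.27, 0.33], [0.02, 0.011, -0.015]]
gamma, var = [2.0, 0.5], [3.0, 0.25]

W_fold, mults = fold_bn_scale(W, gamma, var)
scales = [channel_scale(row) for row in W]
scales_fold = [channel_scale(row) for row in W_fold]
```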

commented

In the context of Quantization Aware Training (QAT), can the Batch Normalization scale factor be merged into the scale factor of the quantize and dequantize layers for the conv weight during inference?

I faced the same problem. Did you get an answer?

Also, I think TensorRT automatically fuses QDQ nodes with Conv and BN ops, so I'm not sure what the benefit of doing this would be.