pytorch / glow

Compiler for Neural Network hardware accelerators

-quantization-precision=Int16 causes no quantization at all

jakubmiernik opened this issue

Hello.

There seems to be an issue when compiling a model with model-compiler and the -quantization-precision=Int16 option. When compiling with int16 precision, the resulting model is fully float; all quantization is ignored.

Reproduction:

  1. Export the model to ONNX:
import torch
import torch.nn as nn
import numpy as np

class test_CNN(nn.Module):
    def __init__(self, drop=0.25):
        super(test_CNN, self).__init__()
        self.conv1 = nn.Conv2d(2, 16, (3, 3), padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu1 = torch.nn.ReLU()
        self.dropout = nn.Dropout(drop)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.bn1(x)
        x = self.dropout(x)
        x = x.reshape([-1, 1024])
        return x

model = test_CNN()
x = torch.rand(1, 2, 8, 8, requires_grad=True)

onnx_path = "./test_nn.onnx"

torch.onnx.export(model,
                  x,
                  onnx_path,
                  export_params=True,
                  opset_version=10,
                  do_constant_folding=True,
                  input_names=['input'],
                  output_names=['output'])
  2. Profile the model (first prepare some random input files of 128 values each in the nn_profile directory; see the sketch after this list):
./bin/model-profiler -model=../test_nn.onnx -dump-profile="profile.yml" -input-dataset="input,rawtxt,dir,nn_profile"
  3. Compile the model:
./bin/model-compiler --verbose-compilation -backend=CPU -model=../test_nn.onnx -emit-bundle=build_test -model-input="input,float,[1,2,8,8]" --bundle-api-verbose -onnx-define-symbol="batch_size,1" --load-profile=profile.yml -quantization-schema=asymmetric -quantization-precision=Int16
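
For reference, this is roughly how I generated the random inputs in nn_profile for step 2. It's a minimal sketch: the exact file layout expected by model-profiler's rawtxt format (one file per sample, comma-separated float values) is an assumption on my side, so double-check it against Glow's AOT documentation.

import os
import numpy as np

# Generate a handful of random samples matching the model input [1, 2, 8, 8]
# (128 floats each) and write them as comma-separated text files.
os.makedirs("nn_profile", exist_ok=True)
for i in range(10):
    sample = np.random.rand(1, 2, 8, 8).astype(np.float32).flatten()
    with open(os.path.join("nn_profile", f"input_{i}.txt"), "w") as f:
        f.write(",".join(f"{v:.6f}" for v in sample))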

After compilation the model is fully floating point:

// ---------------------------------------------------------------
//                          Bundle API
// ---------------------------------------------------------------
// Model name: "test_nn"
// Total data size: 18624 (bytes)
// Activations allocation efficiency: 1.0000
// Placeholders:
//
//   Name: "input"
//   Type: float<1 x 2 x 8 x 8>
//   Size: 128 (elements)
//   Size: 512 (bytes)
//   Offset: 0 (bytes)
//
//   Name: "output"
//   Type: float<1 x 1024>
//   Size: 1024 (elements)
//   Size: 4096 (bytes)
//   Offset: 512 (bytes)
//
// Constants:
//
//   Name: "conv1_bias"
//   Type: float<16>
//   Size: 16 (elements)
//   Size: 64 (bytes)
//   Offset: 0 (bytes)
//
//   Name: "conv1_weight__1"
//   Type: float<16 x 3 x 3 x 2>
//   Size: 288 (elements)
//   Size: 1152 (bytes)
//   Offset: 64 (bytes)
//
//   Name: "BatchNormalization_2__1_muBroadcasted_tile2__1_constfold"
//   Type: float<1 x 8 x 8 x 16>
//   Size: 1024 (elements)
//   Size: 4096 (bytes)
//   Offset: 1216 (bytes)
//
//   Name: "BatchNormalization_2__1_coefBroadcasted_tile2__1_constfold"
//   Type: float<1 x 8 x 8 x 16>
//   Size: 1024 (elements)
//   Size: 4096 (bytes)
//   Offset: 5312 (bytes)
//
// NOTE: Placeholders are allocated within the "mutableWeight"
// buffer and are identified using an offset relative to base.
// ---------------------------------------------------------------

For comparison, the same model compiled without -quantization-precision=Int16 is correctly quantized to int8:

// ---------------------------------------------------------------
//                          Bundle API
// ---------------------------------------------------------------
// Model name: "test_nn"
// Total data size: 12416 (bytes)
// Activations allocation efficiency: 1.0000
// Placeholders:
//
//   Name: "input"
//   Type: float<1 x 2 x 8 x 8>
//   Size: 128 (elements)
//   Size: 512 (bytes)
//   Offset: 0 (bytes)
//
//   Name: "output"
//   Type: float<1 x 1024>
//   Size: 1024 (elements)
//   Size: 4096 (bytes)
//   Offset: 512 (bytes)
//
// Constants:
//
//   Name: "conv1_bias"
//   Type: i32[S:0.000014470 O:0][-31073.324,31073.324]<16>
//   Size: 16 (elements)
//   Size: 64 (bytes)
//   Offset: 0 (bytes)
//
//   Name: "bn1_weight"
//   Type: i8[S:0.003921569 O:-128][0.000,1.000]<16>
//   Size: 16 (elements)
//   Size: 16 (bytes)
//   Offset: 64 (bytes)
//
//   Name: "conv1_weight__1"
//   Type: i8[S:0.001844879 O:-1][-0.234,0.236]<16 x 3 x 3 x 2>
//   Size: 288 (elements)
//   Size: 288 (bytes)
//   Offset: 128 (bytes)
//
//   Name: "BatchNormalization_2__1_coefBroadcasted_tile2__1_quantize__5_constfold"
//   Type: i8[S:0.003921549 O:-128][0.000,1.000]<1 x 8 x 8 x 16>
//   Size: 1024 (elements)
//   Size: 1024 (bytes)
//   Offset: 448 (bytes)
//
//   Name: "BatchNormalization_2__1_coefBroadcasted_tile1__1_quantize__5_constfold"
//   Type: i8[S:0.003921549 O:-128][0.000,1.000]<1 x 8 x 1 x 16>
//   Size: 128 (elements)
//   Size: 128 (bytes)
//   Offset: 1472 (bytes)
//
//   Name: "BatchNormalization_2__1_muBroadcasted_tile2__1_constfold"
//   Type: i8[S:0.100000001 O:0][-12.800,12.700]<1 x 8 x 8 x 16>
//   Size: 1024 (elements)
//   Size: 1024 (bytes)
//   Offset: 1600 (bytes)
//
//   Name: "BatchNormalization_2__1_sqrt_var_plus_eps__1_constfold"
//   Type: float<16>
//   Size: 16 (elements)
//   Size: 64 (bytes)
//   Offset: 2624 (bytes)
//
// NOTE: Placeholders are allocated within the "mutableWeight"
// buffer and are identified using an offset relative to base.
// ---------------------------------------------------------------

@jakubmiernik The CPU backend doesn't really have support for Int16 quantized types. IIRC, if the backend doesn't support a specific op at that precision, then it will skip quantizing that op to that precision.
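
To make that concrete, the decision is conceptually something like the following (a simplified, self-contained Python sketch with hypothetical names, not Glow's actual C++ code): the quantizer asks the backend whether it supports each op at the requested precision, and any op the backend rejects stays in float, which is why you end up with a fully float bundle.

# Illustrative only: a made-up support table and helper, not Glow's API.
SUPPORTED = {
    ("Convolution", "Int8"): True,
    ("Convolution", "Int16"): False,  # e.g. the CPU backend today
}

def choose_precision(op_kind, requested):
    # Return the precision to use for this op, or None to keep it in float.
    if SUPPORTED.get((op_kind, requested), False):
        return requested
    return None  # unsupported -> the op is skipped and stays float

print(choose_precision("Convolution", "Int8"))   # Int8
print(choose_precision("Convolution", "Int16"))  # None, so the conv stays float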

@jfix71 Thanks for the answer.

A few questions on this topic:

  1. Is there some target that has int16 nodes implemented?
  2. Is quantization to int16 fully implemented? I mean the whole process of quantizing weights, types, etc.
  3. Would implementing int16 layers make this option work?

Is there some target that has int16 nodes implemented?

Nothing that is in the mainline Glow open source repo, AFAIK. I believe int16 is currently only used in a few specialized places, e.g. for the bias of some ops.

Is quantization to int16 fully implemented? I mean the whole process of quantizing weights, types, etc.

The int16 quantization flow uses the same infra as the int8 flow. They should behave the same, but again this process queries the backend to try to determine what to quantize based on what is supported.

Would implementing int16 layers make this option work?

Yes, you would need to add kernel support to libjit, as well as update isOpSupported in the CPUBackend/LLVMBackend to indicate that these precisions are supported for the ops you add.

@jfix71 I started digging into it. I added an int16 convolution, but something doesn't seem to work correctly.

I dug some more, and I'm wondering whether int16 quantization is actually fully implemented. I found functions like quantizeScaleOffset32To8 that calculate the scale for outputs, but they only exist for int8, not for int16.
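
For what it's worth, the basic asymmetric scale/offset math itself generalizes straightforwardly to wider integer types. Here's a sketch of computing the parameters from a profiled [min, max] range for an arbitrary bit width, assuming the usual float = scale * (q - offset) convention; this is a hypothetical helper for illustration, not Glow's actual API:

def asymmetric_params(min_val, max_val, num_bits):
    qmin = -(2 ** (num_bits - 1))     # -128 for int8, -32768 for int16
    qmax = 2 ** (num_bits - 1) - 1    #  127 for int8,  32767 for int16
    # Extend the range to include 0.0 so that zero is exactly representable.
    min_val = min(min_val, 0.0)
    max_val = max(max_val, 0.0)
    scale = (max_val - min_val) / (qmax - qmin) or 1.0  # avoid a zero scale
    offset = round(qmin - min_val / scale)
    return scale, offset

print(asymmetric_params(0.0, 1.0, 8))   # ~(0.0039216, -128), like the int8 dump above
print(asymmetric_params(0.0, 1.0, 16))  # ~(0.0000153, -32768)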

Could you also point me to some tests of quantized layers, if there are any? (I found only TEST_P(BackendCorrectnessTest, quantizedConvTest), but maybe there is something better.)