Divergence of accuracy in mobilenetv2_12_int8.onnx
Johansmm opened this issue · comments
Bug Report
Which model does this pertain to? `mobilenetv2-12-int8.onnx`
Describe the bug
I am trying to reproduce the reported accuracy of `mobilenetv2-12-int8.onnx`, using PyTorch to read the ImageNet validation dataset, ONNX Runtime to run the model, and torchmetrics to compute the accuracy. However, this is the only model for which I see a significant accuracy drop: I get 64.346% (vs. 68.30% reported in the table at https://github.com/onnx/models/tree/main/vision/classification/mobilenet#model).
Reproduction instructions
System Information
OS Platform and Distribution: Linux Ubuntu 20.04.4 LTS
ONNX version: 1.13.1
Backend/Runtime version: ONNX Runtime 1.14.1, PyTorch 2.0.0
Provide a code snippet to reproduce your errors.
import os
import io
import tarfile
from PIL import Image
from tqdm import tqdm
import torch
from torchvision import transforms as T
import torchmetrics
import onnxruntime
# Device that evaluate_model moves each image batch to: the first CUDA GPU
# when one is available, otherwise the CPU.
_TORCH_DEVICE = "cpu"
if torch.cuda.is_available():
    _TORCH_DEVICE = "cuda:0"
class ImagenetValDataset(torch.utils.data.Dataset):
    """ImageNet-2012 validation images read directly from ``ILSVRC2012_img_val.tar``.

    ``img_dir`` must contain ``ILSVRC2012_img_val.tar`` and
    ``imagenet_2012_validation_synset_labels.txt`` (one label per line). The
    i-th label is paired with the i-th image name in sorted order.

    Each item is ``(image, class_index)``. ``image`` is an RGB PIL image, or
    whatever ``transform`` returns for it. ``class_index`` is the label's
    position in the sorted set of distinct labels.
    """

    def __init__(self, img_dir, transform=None):
        self._tar_path = os.path.join(img_dir, 'ILSVRC2012_img_val') + '.tar'
        # The archive is opened lazily, once per process (see _get_tar).
        self._tf = None
        self._tf_pid = None
        try:
            with tarfile.open(self._tar_path) as tf:
                # Keep regular files only: directory entries carry no image data.
                img_names = [m.name for m in tf.getmembers() if m.isfile()]
        except Exception as e:
            raise ValueError(f"{img_dir} have not 'ILSVRC2012_img_val.tar' "
                             "file or it is corrupted.") from e
        # NOTE(review): pairing relies on the labels file following the sorted
        # image-name order (zero-padded ILSVRC2012_val_* names) -- confirm.
        self._img_names = sorted(img_names)

        labels_path = os.path.join(img_dir, 'imagenet_2012_validation_synset_labels.txt')
        with open(labels_path) as f:
            labels = [line.strip() for line in f]
        # Drop trailing blank lines (e.g. an extra newline at EOF) so they do not
        # turn into a bogus "" label. Blank lines elsewhere are kept, so they
        # still trip the count check below instead of silently misaligning labels.
        while labels and not labels[-1]:
            labels.pop()
        self._labels = labels
        self._label_names = sorted(set(self._labels))
        # O(1) label -> class-index lookup instead of list.index per item.
        self._label_to_index = {name: i for i, name in enumerate(self._label_names)}
        if len(self._img_names) != len(self._labels):
            # Raise instead of assert: asserts are stripped under ``python -O``.
            raise ValueError("Incomplete labels!")
        self.transform = transform

    def __getstate__(self):
        # An open TarFile cannot be pickled (e.g. when DataLoader workers are
        # spawned); drop it so the receiving process reopens the archive lazily.
        state = self.__dict__.copy()
        state['_tf'] = None
        state['_tf_pid'] = None
        return state

    def _get_tar(self):
        """Return this process's open handle to the image archive, opening it on first use.

        A handle inherited through fork shares its file offset with the parent
        process, so concurrent reads could interleave; each process opens its own.
        """
        pid = os.getpid()
        if self._tf is None or self._tf_pid != pid:
            self._tf = tarfile.open(self._tar_path)
            self._tf_pid = pid
        return self._tf

    def _get_image(self, name):
        """Read archive member ``name`` and decode it as an RGB PIL image."""
        with self._get_tar().extractfile(name) as member:
            data = member.read()
        return Image.open(io.BytesIO(data)).convert('RGB')

    def __len__(self):
        return len(self._img_names)

    def __getitem__(self, index):
        """Return ``(image, class_index)`` for the ``index``-th image in sorted name order."""
        image = self._get_image(self._img_names[index])
        if self.transform is not None:
            image = self.transform(image)
        return image, self._label_to_index[self._labels[index]]
class OnnxInferencePipeline:
    """Callable wrapper that runs a single-input ONNX model with ONNX Runtime on torch tensors."""

    def __init__(self, onnx_path, providers=None):
        """Create the ONNX Runtime session.

        Args:
            onnx_path: path to the ``.onnx`` model file.
            providers: optional execution-provider list forwarded to
                ``onnxruntime.InferenceSession``. ``None`` keeps ONNX Runtime's
                default; GPU-enabled ORT builds (>= 1.9) require it to be set
                explicitly, e.g. ``["CPUExecutionProvider"]``.
        """
        self._ort_session = onnxruntime.InferenceSession(onnx_path, providers=providers)

    @property
    def inputs(self):
        """Metadata (``NodeArg``) of the model's first graph input."""
        return self._ort_session.get_inputs()[0]

    @staticmethod
    def to_numpy(tensor):
        """Return a NumPy copy/view of ``tensor`` on the CPU, detached from autograd.

        ``detach()`` is a no-op for tensors that do not require grad, so one
        code path covers both cases.
        """
        return tensor.detach().cpu().numpy()

    def __call__(self, inputs: torch.Tensor):
        """Run the model on ``inputs`` and return its first output as a tensor on ``inputs.device``."""
        ort_inputs = {self.inputs.name: self.to_numpy(inputs)}
        # ``None`` requests every graph output; only the first one is returned.
        ort_outputs = self._ort_session.run(None, ort_inputs)
        return torch.from_numpy(ort_outputs[0]).to(inputs.device)
def get_imagenet_dataset(data_path, batch_size=128, image_size=224, num_workers=0,
                         resize_size=None):
    """Return an in-order ``DataLoader`` over the ImageNet validation images in ``data_path``.

    Each image is resized so its shorter side equals ``resize_size``, then
    center-cropped to ``image_size`` x ``image_size``. It is converted to a
    float tensor and normalized with the ImageNet per-channel mean/std.

    Args:
        data_path: directory accepted by ``ImagenetValDataset``.
        batch_size: images per batch.
        image_size: side length of the final square crop.
        num_workers: ``DataLoader`` worker processes (0 loads in the main process).
        resize_size: shorter-side length before cropping. Defaults to
            ``int(image_size * 1.1429)``, i.e. 256 for the default 224, which
            matches the previous hard-coded ratio. Exposed so the preprocessing
            can be matched to the pipeline a given model was evaluated with.
    """
    if resize_size is None:
        # 1.1429 ~= 256 / 224, the conventional resize-then-crop ratio.
        resize_size = int(image_size * 1.1429)
    transform = T.Compose([T.Resize(resize_size),
                           T.CenterCrop(image_size),
                           T.ToTensor(),
                           T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
    imagenet_data = ImagenetValDataset(data_path, transform=transform)
    # shuffle=False keeps batches in dataset order, so runs are reproducible.
    return torch.utils.data.DataLoader(imagenet_data,
                                       batch_size=batch_size,
                                       shuffle=False,
                                       num_workers=num_workers)
def evaluate_model(model, dataset):
    """Compute multiclass accuracy of ``model`` over every batch yielded by ``dataset``.

    Args:
        model: callable mapping a batch of images (moved to ``_TORCH_DEVICE``)
            to per-class scores; the prediction is the argmax over the last dim.
        dataset: a ``DataLoader`` yielding ``(images, labels)`` whose
            ``.dataset`` exposes ``_label_names`` (as ``ImagenetValDataset`` does).

    Returns:
        The accuracy aggregated over all batches, as returned by
        ``torchmetrics.Accuracy.compute()``. It is also printed.
    """
    print("Starting evaluation...")
    num_classes = len(dataset.dataset._label_names)
    accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
    # Inference only: avoid building autograd graphs when ``model`` is a torch module.
    with torch.no_grad():
        for images, gt_labels in (barprog := tqdm(dataset)):
            images = images.to(_TORCH_DEVICE)
            pred_labels = model(images).argmax(-1)
            # Calling the metric accumulates this batch and returns its own accuracy.
            acc_step = accuracy(pred_labels.cpu(), gt_labels)
            barprog.set_postfix({'acc': acc_step.item()})
    total_accuracy = accuracy.compute()
    print(f"[INFO] Accuracy: {total_accuracy}")
    return total_accuracy
if __name__ == "__main__":
    # Environment-specific paths: point these at your model file and ImageNet directory.
    model_file = "mobilenetv2-12-int8.onnx"
    dataset_dir = "/imagenet/dataset"

    onnx_model = OnnxInferencePipeline(model_file)
    loader = get_imagenet_dataset(dataset_dir, num_workers=8)
    evaluate_model(onnx_model, loader)
Notes
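`ImagenetValDataset` needs the ordered list of validation labels (`imagenet_2012_validation_synset_labels.txt`, one label per line, in the same order as the sorted image file names). I can provide it if needed.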