onnx / models

A collection of pre-trained, state-of-the-art models in the ONNX format

Home Page: http://onnx.ai/models/

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

Accuracy divergence for mobilenetv2-12-int8.onnx

Johansmm opened this issue · comments

Bug Report

Which model does this pertain to?

mobilenetv2-12-int8.onnx

Describe the bug

I am trying to reproduce the reported accuracy of mobilenetv2-12-int8.onnx, using PyTorch to read the ImageNet validation set, ONNX Runtime to run the model, and torchmetrics to compute accuracy. However, mobilenetv2-12-int8.onnx is the only model that shows a significant accuracy drop for me: it reaches 64.346% top-1 accuracy (vs. 68.30% reported in the table at https://github.com/onnx/models/tree/main/vision/classification/mobilenet#model).

Reproduction instructions

System Information

OS Platform and Distribution: Linux Ubuntu 20.04.4 LTS
ONNX version: 1.13.1
Backend/Runtime version: ONNX Runtime 1.14.1, PyTorch 2.0.0

Provide a code snippet to reproduce your errors.

import os
import io
import tarfile
from PIL import Image
from tqdm import tqdm

import torch
from torchvision import transforms as T
import torchmetrics

import onnxruntime

# Device for model inputs: the first CUDA GPU when one is available, else the CPU.
if torch.cuda.is_available():
    _TORCH_DEVICE = "cuda:0"
else:
    _TORCH_DEVICE = "cpu"


class ImagenetValDataset(torch.utils.data.Dataset):
    """ImageNet-2012 validation images read directly from ``ILSVRC2012_img_val.tar``.

    ``img_dir`` must contain:

    * ``ILSVRC2012_img_val.tar`` -- the validation images;
    * ``imagenet_2012_validation_synset_labels.txt`` -- one synset id per line,
      in the same order as the sorted image file names in the archive.

    Each item is ``(image, class_index)``, where ``image`` is the RGB PIL image
    (after ``transform`` if one is given) and ``class_index`` is the position of
    the image's synset id in the sorted list of distinct synset ids.
    """

    def __init__(self, img_dir, transform=None):
        """Index the archive and load the labels.

        Raises:
            ValueError: if the archive is missing/unreadable, or if the number of
                labels does not match the number of images.
        """
        images_path = os.path.join(img_dir, 'ILSVRC2012_img_val')
        # Only the path is kept; the archive is opened lazily per process in
        # ``_archive`` so DataLoader workers never share one file handle.
        self._tar_path = images_path + '.tar'
        self._tf = None
        self._tf_pid = None
        try:
            with tarfile.open(self._tar_path) as tf:
                # Index regular files only: directory/link members are not images
                # and would misalign image names with labels.
                self._img_names = sorted(m.name for m in tf.getmembers() if m.isfile())
        except (OSError, EOFError, tarfile.TarError) as e:
            raise ValueError(f"{img_dir} have not 'ILSVRC2012_img_val.tar' "
                             "file or it is corrupted.") from e

        # Read labels, ignoring blank lines (e.g. an extra newline at the end of
        # the file) so they are not counted as labels.
        labels_path = os.path.join(img_dir, 'imagenet_2012_validation_synset_labels.txt')
        with open(labels_path, encoding='utf-8') as f:
            self._labels = [line.strip() for line in f if line.strip()]
        self._label_names = sorted(set(self._labels))
        # O(1) label -> class-index lookup instead of list.index() per sample.
        self._label_to_index = {name: i for i, name in enumerate(self._label_names)}
        if len(self._img_names) != len(self._labels):
            raise ValueError("Incomplete labels!")

        self.transform = transform

    def __getstate__(self):
        # An open TarFile cannot be pickled (required when DataLoader workers are
        # spawned), so drop the handle; the receiving process reopens it.
        state = self.__dict__.copy()
        state['_tf'] = None
        state['_tf_pid'] = None
        return state

    def _archive(self):
        """Return this process's open handle to the tar archive, opening it on demand.

        The handle is tied to the PID that opened it. A forked worker inheriting
        the parent's handle would share its file offset, so it opens its own.
        """
        pid = os.getpid()
        if self._tf is None or self._tf_pid != pid:
            self._tf = tarfile.open(self._tar_path)
            self._tf_pid = pid
        return self._tf

    def _get_image(self, name):
        """Decode archive member ``name`` into an RGB PIL image."""
        member = self._archive().extractfile(name)
        # Read the member fully into memory before handing it to PIL.
        return Image.open(io.BytesIO(member.read())).convert('RGB')

    def __len__(self):
        return len(self._img_names)

    def __getitem__(self, index):
        """Return ``(image, class_index)`` for the ``index``-th sorted image."""
        image = self._get_image(self._img_names[index])

        if self.transform is not None:
            image = self.transform(image)

        return image, self._label_to_index[self._labels[index]]

class OnnxInferencePipeline:
    """Callable that runs a torch batch through an ONNX Runtime session.

    Only the session's first input and first output are used.
    """

    def __init__(self, onnx_path, providers=None):
        """Create the ONNX Runtime session for ``onnx_path``.

        Args:
            onnx_path: Path to the ``.onnx`` model file.
            providers: Optional list of execution providers forwarded to
                ``onnxruntime.InferenceSession`` (e.g. ``["CPUExecutionProvider"]``).
                ``None`` keeps ONNX Runtime's default. Builds with non-CPU
                providers (e.g. onnxruntime-gpu >= 1.9) require it to be set.
        """
        self._ort_session = onnxruntime.InferenceSession(onnx_path, providers=providers)
        # The input name is fixed for the session; look it up once, not per batch.
        self._input_name = self.inputs.name

    @property
    def inputs(self):
        """Metadata (name, shape, type) of the model's first input."""
        return self._ort_session.get_inputs()[0]

    @staticmethod
    def to_numpy(tensor):
        """Return ``tensor`` as a host NumPy array, detached from autograd."""
        # detach() is cheap for tensors that do not require grad, so no branch.
        return tensor.detach().cpu().numpy()

    def __call__(self, inputs: torch.Tensor):
        """Run ``inputs`` through the model and return the first output as a tensor."""
        ort_inputs = {self._input_name: self.to_numpy(inputs)}
        ort_outputs = self._ort_session.run(None, ort_inputs)
        # Return the result on the same device as the caller's batch.
        return torch.from_numpy(ort_outputs[0]).to(inputs.device)


def get_imagenet_dataset(data_path, batch_size=128, image_size=224, num_workers=0):
    """Build an unshuffled DataLoader over the ImageNet validation set.

    Each image is resized with ``T.Resize(int(image_size * 1.1429))`` (256 for
    the default 224), center-cropped to ``image_size``, converted to a tensor,
    and normalized with the ImageNet channel mean/std.
    """
    resize_to = int(image_size * 1.1429)
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    preprocessing = T.Compose([
        T.Resize(resize_to),
        T.CenterCrop(image_size),
        T.ToTensor(),
        T.Normalize(channel_mean, channel_std),
    ])

    val_set = ImagenetValDataset(data_path, transform=preprocessing)
    return torch.utils.data.DataLoader(
        val_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )


def evaluate_model(model, dataset, num_classes=None):
    """Compute multiclass accuracy of ``model`` over every batch of ``dataset``.

    Args:
        model: Callable mapping an image batch to per-class scores; the predicted
            class is the argmax over the last dimension.
        dataset: Iterable of ``(images, labels)`` batches. When ``num_classes``
            is omitted it must expose ``dataset.dataset._label_names`` (as the
            DataLoader from ``get_imagenet_dataset`` does).
        num_classes: Number of classes; defaults to
            ``len(dataset.dataset._label_names)``.

    Returns:
        The accuracy accumulated over all batches, as a Python float.
    """
    print("Starting evaluation...")
    if num_classes is None:
        num_classes = len(dataset.dataset._label_names)
    accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)

    # Inference only: no_grad avoids building autograd graphs when ``model`` is
    # a torch module, and is harmless for the ONNX Runtime pipeline.
    with torch.no_grad():
        for images, gt_labels in (barprog := tqdm(dataset)):
            images = images.to(_TORCH_DEVICE)
            pred_labels = model(images).argmax(-1)
            # Accumulates into the metric and returns this batch's accuracy.
            acc_step = accuracy(pred_labels.cpu(), gt_labels)
            barprog.set_postfix({'acc': acc_step.item()})

    final_accuracy = accuracy.compute()
    print(f"[INFO] Accuracy: {final_accuracy}")
    return final_accuracy.item()


if __name__ == "__main__":
    # Hard-coded locations for this reproduction; adjust to the local setup.
    onnx_file = "mobilenetv2-12-int8.onnx"
    data_root = "/imagenet/dataset"

    pipeline = OnnxInferencePipeline(onnx_file)
    loader = get_imagenet_dataset(data_root, num_workers=8)
    evaluate_model(pipeline, loader)

Notes

ImagenetValDataset needs the ordered list of validation labels (imagenet_2012_validation_synset_labels.txt, one synset ID per line in image order) to work. I can provide it if needed.