deBUGger404 / cifar-pytorch_model

Train Basic Model on CIFAR10 Dataset - 🎨🖥️ Uses the CIFAR-10 dataset of 60000 32x32 color images in 10 classes. Demonstrates loading the data with torchvision and training pretrained models such as ResNet18, AlexNet, VGG16, DenseNet161, and Inception v3. A notebook is available for experimentation.

Train Basic Model on CIFAR10-Dataset

Introduction

The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.
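As a quick sanity check of these numbers, the sketch below counts how many labels fall into each class. The class-name list follows the label order 0 to 9 used by CIFAR-10 (the same order as torchvision's CIFAR10.classes attribute); the helper class_counts is hypothetical and not part of this repository.

```python
from collections import Counter

# CIFAR-10 class names, indexed by label 0..9
CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck"]

# 10 classes x 6000 images = 60000 = 50000 train + 10000 test
assert len(CLASSES) * 6000 == 60000 == 50000 + 10000


def class_counts(labels):
    """Map each class name to how many times its label index appears in labels."""
    counts = Counter(labels)
    return {name: counts.get(i, 0) for i, name in enumerate(CLASSES)}
```

With the real data, class_counts(train_dataset.targets) should report 5000 images per class for the training split, and the test split holds 1000 per class.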

Below are six random images with their respective labels.
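A figure like this can be built by picking random indices and tiling the 32x32 images into one array for matplotlib.pyplot.imshow. This is a minimal sketch, assuming the images come as a uint8 NumPy array shaped (N, 32, 32, 3), which is how torchvision's CIFAR10 dataset stores them in its data attribute; tile_random_images and its parameters are illustrative, not code from this repository.

```python
import numpy as np


def tile_random_images(images, labels, k=6, rows=2, seed=None):
    """Pick k random images and tile them into a (rows x cols) grid.

    images: uint8 array of shape (N, H, W, C); labels: sequence of length N.
    Returns (grid, picked_labels); grid has shape (rows*H, cols*W, C) and
    picked_labels lists the labels in grid order, row by row.
    """
    rng = np.random.default_rng(seed)
    idx = rng.choice(len(images), size=k, replace=False)  # k distinct indices
    cols = -(-k // rows)  # ceiling division so all k images fit
    _, h, w, c = images.shape
    grid = np.zeros((rows * h, cols * w, c), dtype=images.dtype)
    for pos, i in enumerate(idx):
        r, col = divmod(pos, cols)
        grid[r * h:(r + 1) * h, col * w:(col + 1) * w] = images[i]
    return grid, [labels[i] for i in idx]
```

For example, grid, picked = tile_random_images(train_dataset.data, train_dataset.targets) followed by plt.imshow(grid) shows the six images, and picked can be mapped to class names for the title.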

torchvision is a Python package that provides a dataset class for CIFAR10 (torchvision.datasets.CIFAR10) and image transforms (torchvision.transforms); the dataset can then be batched with torch.utils.data.DataLoader.

Below is an example of how to load the CIFAR10 dataset using torchvision:

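import torch
import torchvision
import torchvision.transforms as transforms

## load data CIFAR10
# ToTensor converts the PIL images to tensors so the DataLoader can batch them
transform = transforms.ToTensor()
train_dataset = torchvision.datasets.CIFAR10(root='./train_data', train=True, download=True, transform=transform)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)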
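Images are often normalized per channel before training. The per-channel mean and standard deviation can be computed from the training images themselves; this sketch assumes they are available as a uint8 NumPy array shaped (N, H, W, C), as in train_dataset.data for torchvision's CIFAR10. channel_stats is a hypothetical helper, and it converts one channel at a time to keep memory use moderate.

```python
import numpy as np


def channel_stats(images):
    """Per-channel mean and std of uint8 images shaped (N, H, W, C), scaled to [0, 1]."""
    means, stds = [], []
    for c in range(images.shape[-1]):
        # Convert a single channel to float to avoid copying the whole array
        x = images[..., c].astype(np.float32) / 255.0
        means.append(float(x.mean(dtype=np.float64)))
        stds.append(float(x.std(dtype=np.float64)))
    return tuple(means), tuple(stds)
```

The result can be passed straight to transforms.Normalize(mean, std) after transforms.ToTensor().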

Prerequisites

  • Python >= 3.6
  • PyTorch >= 1.4
  • Other required libraries are listed in requirements.txt
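To check an environment against these minimums, a small script can compare version strings. This is a sketch: parse_version and meets_minimum are hypothetical helpers that compare only the numeric release components and ignore local suffixes such as +cu117; the packaging library's version module is a more complete alternative if it is installed.

```python
import re
import sys


def parse_version(version):
    """Turn a string like '1.13.1+cu117' into a tuple of ints, e.g. (1, 13, 1)."""
    release = version.split("+")[0]  # drop the local build suffix
    parts = []
    for piece in release.split("."):
        m = re.match(r"\d+", piece)
        if not m:  # stop at non-numeric pieces such as 'dev2023...'
            break
        parts.append(int(m.group()))
    return tuple(parts)


def meets_minimum(version, minimum):
    """True if version >= minimum, padding the shorter tuple with zeros."""
    a, b = parse_version(version), parse_version(minimum)
    n = max(len(a), len(b))
    a += (0,) * (n - len(a))
    b += (0,) * (n - len(b))
    return a >= b


if __name__ == "__main__":
    print("Python OK:", sys.version_info >= (3, 6))
    try:
        import torch
        print("PyTorch OK:", meets_minimum(torch.__version__, "1.4"))
    except ImportError:
        print("PyTorch is not installed")
```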

Training

I used a pretrained ResNet18 for model training. You can use any other pretrained model that suits your problem, for example:

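import torchvision.models as models
# weights="DEFAULT" loads ImageNet-pretrained weights (torchvision >= 0.13;
# older versions use pretrained=True instead)
resnet18 = models.resnet18(weights="DEFAULT")
alexnet = models.alexnet(weights="DEFAULT")
vgg16 = models.vgg16(weights="DEFAULT")
densenet = models.densenet161(weights="DEFAULT")
# Inception v3 expects 299x299 inputs, so CIFAR10 images must be resized for it
inception = models.inception_v3(weights="DEFAULT")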

There are two ways to train the PyTorch model:

  1. Notebook: download it and experiment with it.
  2. Python scripts:
    # Start training with:
    python main.py
    
    # You can manually pass the arguments for the training:
    python main.py --lr=0.01 --epoch 20 --model_path './cifar_model.pth'
    
    # Start inference with:
    python prediction.py --model_path './cifar_model.pth'
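The README does not show what main.py does internally. As an illustration only, a script accepting the same --lr, --epoch and --model_path options might parse them with argparse, run a standard cross-entropy training loop, and save the weights with torch.save. Everything beyond the three flag names (the defaults, SGD with momentum 0.9, and the helper names parse_args, train_one_epoch and train) is an assumption.

```python
import argparse

import torch
import torch.nn as nn


def parse_args(argv=None):
    """Parse the options shown in the README (the defaults are assumptions)."""
    parser = argparse.ArgumentParser(description="Train a model on CIFAR10")
    parser.add_argument("--lr", type=float, default=0.01, help="learning rate")
    parser.add_argument("--epoch", type=int, default=20, help="number of epochs")
    parser.add_argument("--model_path", default="./cifar_model.pth",
                        help="where to save the trained weights")
    return parser.parse_args(argv)


def train_one_epoch(model, loader, optimizer, device="cpu"):
    """Run one pass over loader and return the mean training loss per sample."""
    criterion = nn.CrossEntropyLoss()
    model.train()
    total_loss, total_samples = 0.0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * images.size(0)
        total_samples += images.size(0)
    return total_loss / total_samples


def train(model, loader, args, device="cpu"):
    """Train for args.epoch epochs, then save the state dict to args.model_path."""
    model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
    for epoch in range(args.epoch):
        avg_loss = train_one_epoch(model, loader, optimizer, device)
        print(f"epoch {epoch + 1}/{args.epoch}: loss {avg_loss:.4f}")
    torch.save(model.state_dict(), args.model_path)
    return model
```

On the inference side, the saved weights can be restored with model.load_state_dict(torch.load(path)) on a model built the same way, followed by model.eval() before predicting.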
    

Give a ⭐ to this Repository!
