scotwilli / image-gpt

Pytorch Implementation of OpenAI's Image GPT, trained on MNIST and Fashion MNIST

Image GPT

PyTorch implementation of Image GPT, based on the paper *Generative Pretraining from Pixels* (Chen et al.) and the accompanying code.


Model-generated completions of half-images from the test set. The first column is the input; the last column is the original image.

Differences from original paper:

  • Uses 4-bit grayscale images instead of 9-bit RGB
  • 28x28 images are used instead of 32x32
  • Quantization is done naively using division, not *k*-means clustering
  • Model is much smaller and can be trained with much less compute
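The naive quantization mentioned above can be sketched in a few lines. This is an illustrative sketch, not code from this repo; `quantize_4bit` and `dequantize_4bit` are hypothetical names:

```python
import numpy as np

def quantize_4bit(images):
    """Naively quantize 8-bit grayscale pixels (0-255) to 4-bit tokens (0-15)
    by integer division, rather than clustering as in the original paper."""
    return (np.asarray(images) // 16).astype(np.uint8)

def dequantize_4bit(tokens):
    """Map 4-bit tokens back to approximate 8-bit pixel values (bin centers)."""
    return np.asarray(tokens, dtype=np.uint8) * 16 + 8
```

Division gives a coarser palette than a learned one, but for grayscale digits the 16 levels are more than enough to keep images recognizable.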

According to their blog post, the largest model, iGPT-L (1.4 B parameters), was trained for 2500 V100-days. By greatly reducing the number of attention heads, the number of layers, and the input size (which affects the cost of attention quadratically), we can train our own model (26 K parameters) on Fashion-MNIST on a single NVIDIA 2070 in less than 2 hours.
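The quadratic effect of input size is simple arithmetic: an n×n image becomes a sequence of n² pixel tokens, and self-attention cost grows with the square of the sequence length:

```python
# Each image becomes one token per pixel.
seq_mnist = 28 * 28    # 784 tokens for 28x28 MNIST images
seq_igpt  = 32 * 32    # 1024 tokens for the paper's 32x32 images

# Self-attention compares every token with every other token,
# so cost scales with the square of the sequence length.
ratio = (seq_igpt ** 2) / (seq_mnist ** 2)
print(seq_mnist, seq_igpt, round(ratio, 2))  # 784 1024 1.71
```

So shrinking the input from 32×32 to 28×28 alone cuts the attention cost by roughly 40%, before any reduction in width or depth.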

Usage

Pre-trained Models

Pre-trained models are located in the models directory.

Prepare Data

To download and prepare the data, run src/prepare_data.py. Omitting the --fashion argument downloads standard MNIST instead of Fashion-MNIST. Images are downloaded and encoded with a 4-bit grayscale palette.

python src/prepare_data.py --fashion

Training

Models can be trained using src/run.py with the train subcommand.

Generative Pre-training

python src/run.py train --name fmnist_gen

The following hyperparameters can also be provided. The smallest model from the paper, iGPT-S, is shown for comparison.

| Argument | Default | iGPT-S (Chen et al.) |
| --- | --- | --- |
| `--embed_dim` | 16 | 512 |
| `--num_heads` | 2 | 8 |
| `--num_layers` | 8 | 24 |
| `--num_pixels` | 28 | 32 |
| `--num_vocab` | 16 | 512 |
| `--batch_size` | 64 | 128 |
| `--learning_rate` | 0.01 | 0.01 |
| `--steps` | 25000 | 1000000 |
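As a rough sanity check on the 26 K figure, here is a back-of-the-envelope count of the transformer block weights under the default hyperparameters, ignoring biases, layer norms, and embeddings (an illustrative estimate, not code from this repo):

```python
def block_params(d):
    """Approximate weight count for one transformer block of width d."""
    attn = 4 * d * d          # Q, K, V, and output projections
    mlp = 2 * d * (4 * d)     # two linear layers with the usual 4x hidden width
    return attn + mlp

embed_dim, num_layers = 16, 8
total = num_layers * block_params(embed_dim)
print(total)  # 24576 -- in the right ballpark of the ~26 K quoted above
```

The remaining few thousand parameters come from the pieces this estimate ignores (embeddings, biases, and layer norms).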

Classification Fine-tuning

Pre-trained models can be fine-tuned by passing the path to the pre-trained checkpoint to --pretrained, along with the --classify argument. I have found that a small reduction in learning rate is necessary.

python src/run.py train \
    --name fmnist_clf  \
    --pretrained models/fmnist_gen.ckpt \
    --classify \
    --learning_rate 3e-3

Sampling

Figures like the one above can be created using random images from the test set:

# outputs to figure.png
python src/sample.py models/fmnist_gen.ckpt
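Under the hood, completion works by feeding the top half of an image as context and sampling the remaining 4-bit tokens one pixel at a time. A minimal sketch of that loop, with a stand-in uniform model in place of the real checkpoint (function names here are illustrative, not taken from src/sample.py):

```python
import numpy as np

rng = np.random.default_rng(0)

def dummy_model(tokens):
    """Stand-in for the trained model: returns uniform logits over the
    16-token (4-bit) vocabulary. The real model would condition on `tokens`."""
    return np.zeros(16)

def complete_half_image(top_half, seq_len=28 * 28):
    """Autoregressively sample the remaining pixels given the top half,
    appending one 4-bit token per step until the image is full."""
    tokens = list(top_half)
    while len(tokens) < seq_len:
        logits = dummy_model(tokens)
        probs = np.exp(logits) / np.exp(logits).sum()  # softmax
        tokens.append(rng.choice(16, p=probs))
    return np.array(tokens).reshape(28, 28)

# Condition on an all-zero top half (14 rows of 28 pixels).
img = complete_half_image(np.zeros(14 * 28, dtype=int))
```

With the real model, the sampled bottom half is coherent with the given top half, which is what produces the completions in the figure above.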

GIFs like the one seen in my tweet can be made like so:

# outputs to out.gif
python src/gif.py models/fmnist_gen.ckpt
