CauchyFood / image-gpt

PyTorch Implementation of OpenAI's Image GPT

Image GPT

PyTorch implementation of Image GPT, based on the paper "Generative Pretraining from Pixels" (Chen et al.) and its accompanying code.

Model-generated completions of half-images from the test set. The first column is the input; the last column is the original image.

iGPT-S pretrained on CIFAR10. Completions are fairly poor as the model was only trained on CIFAR10, not all of ImageNet.


  • Batched k-means on GPU for quantization of larger datasets (currently using sklearn.cluster.MiniBatchKMeans).
  • BERT-style pretraining (currently only generative pretraining is supported).
  • Load pretrained models from OpenAI.
  • Reproduce at least iGPT-S results.

According to OpenAI's blog post, the largest model, iGPT-L (1.4 B parameters), was trained for 2500 V100-days. By greatly reducing the number of attention heads, the number of layers, and the input size (which affects model size quadratically), we can train our own model (26 K parameters) on Fashion-MNIST on a single NVIDIA 2070 in less than 2 hours.


Pre-trained Models

Some pre-trained models are located in the models directory. Run ./ to download the CIFAR10 pre-trained iGPT-S model.

Compute Centroids

Images are downloaded, and centroids are computed using k-means with num_clusters clusters. These centroids are used to quantize the images before they are fed into the model.

# options: mnist, fmnist, cifar10
python src/ --dataset mnist --num_clusters=8

# creates data/<dataset>_centroids.npy

Note: Use the same num_clusters as num_vocab in your model.
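The quantization step described above can be sketched with NumPy alone. This is a minimal illustration, not the repository's code: the function names `quantize` and `dequantize` are hypothetical, the assumed centroid layout is one row per cluster with shape (num_clusters, channels), and evenly spaced grey levels stand in for the centroids that would normally be loaded from data/<dataset>_centroids.npy.

```python
import numpy as np

def quantize(images, centroids):
    """Map each pixel to the index of its nearest centroid.

    images:    float array, shape (N, H, W, C)
    centroids: float array, shape (K, C)  (assumed layout)
    returns:   int array, shape (N, H, W), values in [0, K)
    """
    n, h, w, c = images.shape
    pixels = images.reshape(-1, c)  # (N*H*W, C)
    # Squared Euclidean distance from every pixel to every centroid.
    dists = ((pixels[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
    return dists.argmin(axis=1).reshape(n, h, w)

def dequantize(tokens, centroids):
    """Replace each token with its centroid value, shape (N, H, W, C)."""
    return centroids[tokens]

# Toy stand-in for learned centroids: 8 evenly spaced grey levels.
centroids = np.linspace(0.0, 1.0, 8).reshape(8, 1)
images = np.random.default_rng(0).random((2, 28, 28, 1))

tokens = quantize(images, centroids)          # integer tokens fed to the model
restored = dequantize(tokens, centroids)      # lossy reconstruction
```

Each token is an index into the centroid table, which is why the model's num_vocab has to match num_clusters.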


Models can be trained using src/ with the train subcommand.

Generative Pre-training

Models can be pre-trained by specifying a dataset and a model config. configs/s_gen.yml corresponds to iGPT-S from the paper; configs/xxs_gen.yml is an extra-small model for experimenting on toy datasets with limited compute.

python src/ --dataset mnist train configs/xxs_gen.yml
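For intuition, the generative objective from the paper is next-token prediction over the quantized pixel sequence. The sketch below computes that loss with NumPy under the assumption that `logits[:, t]` scores the token at position t+1; the function name `next_token_loss` is hypothetical, and the repository's training code may differ (for example, in how the first token is handled).

```python
import numpy as np

def next_token_loss(logits, tokens):
    """Mean cross-entropy for predicting token t+1 from positions <= t.

    logits: float array, shape (B, T, V); logits[:, t] scores token t+1
    tokens: int array, shape (B, T)
    """
    pred = logits[:, :-1, :]   # predictions for positions 1..T-1
    target = tokens[:, 1:]     # the tokens actually observed there
    # Numerically stable log-softmax over the vocabulary axis.
    shifted = pred - pred.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    picked = np.take_along_axis(log_probs, target[..., None], axis=-1)
    return -picked.mean()

rng = np.random.default_rng(0)
tokens = rng.integers(0, 8, size=(4, 784))   # 4 quantized 28x28 images, 8 clusters
uniform_logits = np.zeros((4, 784, 8))       # a model with no preference
loss = next_token_loss(uniform_logits, tokens)
```

With uniform logits the loss equals log(num_vocab), a useful baseline to compare against early in training.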

Classification Fine-tuning

Pre-trained models can be fine-tuned by passing the path to the pre-trained checkpoint to --pretrained, along with the config file and dataset.

python src/ --dataset mnist train configs/xxs_clf.yml --pretrained=models/mnist_gen.ckpt


Figures like those seen above can be created using random images from the test set:

# outputs to figure.png
python src/ models/mnist_gen.ckpt
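Conceptually, a half-image completion keeps the tokens of the top half and samples the rest one token at a time. The following is a minimal sketch under stated assumptions: pixels are flattened in raster (row-major) order, and `logits_fn` is a hypothetical stand-in for a trained model that returns next-token logits. It does not reproduce the repository's sampling script.

```python
import numpy as np

def complete(logits_fn, context, total_len, rng):
    """Autoregressively extend `context` (1-D int array) to `total_len` tokens.

    logits_fn(seq) is a hypothetical model stand-in: given the tokens so far,
    it returns a 1-D array of logits over the vocabulary for the next token.
    """
    seq = list(context)
    while len(seq) < total_len:
        logits = logits_fn(np.array(seq))
        probs = np.exp(logits - logits.max())   # softmax, numerically stable
        probs /= probs.sum()
        seq.append(int(rng.choice(len(probs), p=probs)))
    return np.array(seq)

rng = np.random.default_rng(0)
h = w = 28
vocab = 8
image = rng.integers(0, vocab, size=(h, w))     # a quantized test image
top_half = image[: h // 2].reshape(-1)          # first half of the raster sequence

dummy_model = lambda seq: np.zeros(vocab)       # uniform predictions, illustration only
completed = complete(dummy_model, top_half, h * w, rng).reshape(h, w)
```

The completed token grid can then be mapped back to pixel values through the centroid table before plotting.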

GIFs like the one seen in my tweet can be made like so:

# outputs to out.gif
python src/ models/mnist_gen.ckpt


License: Apache License 2.0


Languages: Python 96.9%, Shell 3.1%