bshall / VectorQuantizedVAE


Vector Quantized VAE

A PyTorch implementation of "Continuous Relaxation Training of Discrete Latent Variable Image Models".

Ensure you have Python 3.7 and PyTorch 1.2 or greater installed. To train the VQVAE model with 8 categorical dimensions and 128 codes per dimension, run the following command:

python train.py --model=VQVAE --latent-dim=8 --num-embeddings=128

To train the GS-Soft model, use --model=GSSOFT. Pretrained weights for the VQVAE and GS-Soft models can be found here.
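
For intuition, the two --model options differ only in the quantization bottleneck. The following is a minimal sketch, not the repository's implementation (tensor shapes, the temperature, and function names are illustrative), contrasting hard vector quantization with a straight-through gradient against a Gumbel-Softmax relaxation over the same codebook:

import torch
import torch.nn.functional as F

def vq_straight_through(z_e, codebook):
    # z_e: (batch, D) encoder outputs; codebook: (K, D) embedding vectors.
    distances = torch.cdist(z_e, codebook)            # (batch, K) pairwise distances
    indices = distances.argmin(dim=1)                 # index of the nearest code
    z_q = codebook[indices]                           # hard quantization
    # Straight-through estimator: the backward pass treats quantization as identity.
    return z_e + (z_q - z_e).detach(), indices

def gs_soft(z_e, codebook, temperature=0.5):
    # Relaxed assignment: logits from negative distances, sampled with Gumbel-Softmax.
    logits = -torch.cdist(z_e, codebook)              # (batch, K)
    soft_one_hot = F.gumbel_softmax(logits, tau=temperature, hard=False)
    return soft_one_hot @ codebook                    # convex combination of codes

codebook = torch.randn(128, 32, requires_grad=True)   # 128 codes; 32-D embeddings are illustrative
z_e = torch.randn(16, 32)
z_hard, _ = vq_straight_through(z_e, codebook)
z_soft = gs_soft(z_e, codebook)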

VQVAE Reconstructions

The VQVAE model achieves ~4.82 bits per dimension (bpd), while the GS-Soft model achieves ~4.6 bpd.
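
For reference, bits per dimension is the reconstruction negative log-likelihood averaged over all pixels and channels, converted from nats to bits. The sketch below illustrates that conversion assuming a categorical output over 256 pixel intensities; this is an assumption made for illustration, not the repository's exact evaluation code:

import math

import torch
import torch.nn.functional as F

def bits_per_dim(logits, targets):
    # logits: (batch, 256, C, H, W) scores over pixel intensities; targets: integer pixels in [0, 255].
    nll_nats = F.cross_entropy(logits, targets, reduction="mean")  # mean NLL per pixel/channel, in nats
    return nll_nats.item() / math.log(2)                           # nats -> bits

logits = torch.randn(8, 256, 3, 32, 32)                # e.g. CIFAR-10-sized images (illustrative)
targets = torch.randint(0, 256, (8, 3, 32, 32))
print(f"{bits_per_dim(logits, targets):.2f} bpd")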

Analysis of the Codebooks

As demonstrated in the paper, the learned codebook matrices are effectively low-dimensional, with most of their variance concentrated in only a few dimensions:

Explained Variance Ratio
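
A curve like this can be reproduced with a standard PCA over the codebook matrix. The sketch below uses a placeholder codebook tensor; how the real codebook is extracted from a trained model is left as an assumption:

import torch
from sklearn.decomposition import PCA

codebook = torch.randn(128, 32)   # placeholder for a trained (num_embeddings, embedding_dim) codebook

pca = PCA()
pca.fit(codebook.numpy())
print(pca.explained_variance_ratio_[:5])   # the leading components account for most of the variance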

Projecting the codes onto the first 3 principal components shows that the codes typically tile continuous 1- or 2-D manifolds:

Codebook principal components
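
A corresponding sketch of the 3-D projection, again with a placeholder codebook tensor standing in for the trained one:

import matplotlib.pyplot as plt
import torch
from sklearn.decomposition import PCA

codebook = torch.randn(128, 32)   # placeholder for a trained codebook matrix
projected = PCA(n_components=3).fit_transform(codebook.numpy())

fig = plt.figure()
ax = fig.add_subplot(projection="3d")
ax.scatter(projected[:, 0], projected[:, 1], projected[:, 2], s=10)
ax.set_xlabel("PC 1")
ax.set_ylabel("PC 2")
ax.set_zlabel("PC 3")
plt.show()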

About

A PyTorch implementation of "Continuous Relaxation Training of Discrete Latent Variable Image Models"

License: MIT License


Languages

Jupyter Notebook: 96.0%
Python: 4.0%