KimRass / VQ-VAE

PyTorch implementation of VQ-VAE (van den Oord et al., 2017), trained on Fashion MNIST and CIFAR-10

1. Pre-trained Models

1) On Fashion MNIST

  • vqvae-fashion_mnist.pth
  • Trained VQ-VAE for 47 epochs. (Validation loss: 0.145)
    dataset="fashion_mnist"
    batch_size=128
    lr=0.0002
    n_embeds=128
    hidden_dim=256
    n_pixelcnn_res_blocks=2
    n_pixelcnn_conv_blocks=2
  • Then trained PixelCNN for 14 epochs. (Validation loss: 1.279)
    dataset="fashion_mnist"
    batch_size=128
    lr=0.0002
    n_embeds=128
    hidden_dim=256
    n_pixelcnn_res_blocks=2
    n_pixelcnn_conv_blocks=2

2) On CIFAR-10

  • vqvae-cifar10.pth
  • Trained VQ-VAE for 40 epochs. (Validation loss: 0.139)
    dataset="cifar10"
    batch_size=128
    lr=0.0003
    n_embeds=128
    hidden_dim=64
    n_pixelcnn_res_blocks=2
    n_pixelcnn_conv_blocks=2
  • Then trained PixelCNN for 96 epochs. (Validation loss: 2.226)
    dataset="cifar10"
    batch_size=128
    lr=0.0003
    n_embeds=128
    hidden_dim=64
    n_pixelcnn_res_blocks=2
    n_pixelcnn_conv_blocks=2

2. Samples

Fashion MNIST
CIFAR-10

3. Implementation Details

1) detach()

  • When computing the loss during VQ-VAE training, I found that adding z_q = z_e + (z_q - z_e).detach() makes training noticeably faster, though I did not fully understand what this line does.
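For context, this line is the straight-through estimator from the VQ-VAE paper: the nearest-neighbor codebook lookup has no gradient, so the trick passes the decoder's gradient through the quantization step to the encoder unchanged. A minimal sketch (the function name `vector_quantize` and the toy shapes are illustrative, not from this repo):

```python
import torch

def vector_quantize(z_e, codebook):
    # z_e: (B, D) encoder outputs; codebook: (K, D) embedding vectors.
    # Pick the nearest codebook vector for each encoder output.
    dists = torch.cdist(z_e, codebook)   # (B, K) pairwise distances
    idx = dists.argmin(dim=1)            # (B,) codebook indices
    z_q = codebook[idx]                  # (B, D) quantized vectors
    # Straight-through estimator: the forward pass yields z_q, but
    # (z_q - z_e).detach() contributes no gradient, so in the backward
    # pass d(output)/d(z_e) is the identity and gradients flow to the encoder.
    return z_e + (z_q - z_e).detach(), idx

torch.manual_seed(0)
codebook = torch.randn(8, 4)
z_e = torch.randn(2, 4, requires_grad=True)
z_q_st, idx = vector_quantize(z_e, codebook)
z_q_st.sum().backward()
# The gradient reaching z_e is exactly the gradient of the sum w.r.t. z_q,
# i.e. all ones here, even though argmax/index lookup is non-differentiable.
print(torch.allclose(z_e.grad, torch.ones_like(z_e)))
```

Without this trick the encoder receives no gradient at all through the quantizer, which is consistent with the observation above that training speeds up once the line is added.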
