bdsaglam / torch-scae

PyTorch implementation of Stacked Capsule Auto-Encoders

torch-scae

PyTorch implementation of Stacked Capsule Auto-Encoders [1].

Ported from the official TensorFlow v1 implementation. The model architecture and hyperparameters are kept the same as the original; however, some parts are refactored for ease of use.

Please open an issue for bugs and inconsistencies with the original implementation.

⚠️: The performance of this implementation is inferior to that of the original due to an unknown bug. There is already an open issue for this, but it has not been resolved yet.


Installation

# clone project   
git clone https://github.com/bdsaglam/torch-scae   

# install project   
cd torch-scae
pip install -e .
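
To confirm that the editable install is visible to the current interpreter, a small check like the one below may help. It is only a sketch: the module name `torch_scae_experiments` is taken from the training commands in this README, `torch` is assumed to be a dependency, and `is_installed` is a hypothetical helper, not part of this project.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if a top-level module can be located (without importing it)."""
    return importlib.util.find_spec(module_name) is not None


if __name__ == "__main__":
    # Module names are assumptions based on the commands in this README.
    for name in ("torch", "torch_scae_experiments"):
        status = "found" if is_installed(name) else "MISSING"
        print(f"{name}: {status}")
```

If a module is reported as missing, re-running `pip install -e .` from the repository root in the same environment is the first thing to try.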

Train with MNIST

Training uses PyTorch Lightning, and configuration is managed with Hydra.

# CPU
python -m torch_scae_experiments.mnist.train

# GPU
python -m torch_scae_experiments.mnist.train +trainer.gpus=1

You can customize model hyperparameters and training settings with Hydra's command-line override syntax.

python -m torch_scae_experiments.mnist.train \
    data_loader.batch_size=32 \
    optimizer.learning_rate=1e-4 \
    model.n_part_caps=16 \
    trainer.max_epochs=100 
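
In Hydra, a plain `key=value` argument overrides an entry that already exists in the configuration, while a `+` prefix (as in `+trainer.gpus=1` above) adds a key that is not yet present. The sketch below imitates that behavior on a plain nested dictionary to show how dotted overrides map onto the config tree. It is a simplified illustration, not Hydra's implementation: `apply_overrides` and `parse_value` are hypothetical helpers, and the base values are placeholders rather than this project's actual defaults.

```python
import ast
import copy


def parse_value(text):
    """Turn '32' into 32 and '1e-4' into 0.0001; leave anything else as a string."""
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError):
        return text


def apply_overrides(config, overrides):
    """Apply dotted 'a.b=value' overrides to a copy of a nested dict.

    Plain keys must already exist; keys prefixed with '+' must not.
    """
    result = copy.deepcopy(config)
    for item in overrides:
        key, _, raw = item.partition("=")
        append = key.startswith("+")
        path = (key[1:] if append else key).split(".")
        node = result
        for part in path[:-1]:
            if part not in node:
                if not append:
                    raise KeyError(f"cannot override missing key: {key}")
                node[part] = {}
            node = node[part]
        leaf = path[-1]
        if append and leaf in node:
            raise KeyError(f"cannot append, key already exists: {key}")
        if not append and leaf not in node:
            raise KeyError(f"cannot override missing key: {key}")
        node[leaf] = parse_value(raw)
    return result


if __name__ == "__main__":
    # Placeholder values, NOT the project's real defaults.
    base = {
        "data_loader": {"batch_size": 0},
        "optimizer": {"learning_rate": 0.0},
        "model": {"n_part_caps": 0},
        "trainer": {"max_epochs": 0},
    }
    cfg = apply_overrides(base, [
        "data_loader.batch_size=32",
        "optimizer.learning_rate=1e-4",
        "model.n_part_caps=16",
        "trainer.max_epochs=100",
        "+trainer.gpus=1",
    ])
    print(cfg)
```

The GPU command presumably uses `+` because `gpus` is not part of the default trainer configuration; that is an inference from the syntax, not something this README states.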

Results

Image reconstructions

After training for 5 epochs

Fig. 1. Rows: original images, bottom-up reconstructions, and top-down reconstructions.

References

  1. Kosiorek, A. R., Sabour, S., Teh, Y. W., & Hinton, G. E. (2019). Stacked Capsule Autoencoders. NeurIPS. http://arxiv.org/abs/1906.06818

License: Apache License 2.0