erikreppel / capsulenet

An implementation of Hinton's CapsNet in PyTorch

💊 CapsuleNet 💊

A PyTorch implementation of CapsuleNet as described in "Dynamic Routing Between Capsules" by Hinton et al.

Includes decoder and pretrained weights for MNIST and Fashion MNIST in checkpoints/.


This project uses conda for package management, so install miniconda3 first.

Create an environment with all the dependencies:

conda env create -f environment.yml

Activate with

source activate capsnet

To train:

▶ python --help
usage: CapsNet [-h] [--epochs EPOCHS] [--data_path DATA_PATH]
               [--batch_size BATCH_SIZE] [--use_gpu] [--lr LR]
               [--log_interval LOG_INTERVAL] [--visdom] [--dataset DATASET]
               [--load_checkpoint LOAD_CHECKPOINT]
               [--checkpoint_interval CHECKPOINT_INTERVAL]
               [--checkpoint_dir CHECKPOINT_DIR] [--gen_dir GEN_DIR]

Example of CapsNet

optional arguments:
  -h, --help            show this help message and exit
  --epochs EPOCHS
  --data_path DATA_PATH
  --batch_size BATCH_SIZE
  --lr LR               ADAM learning rate (0.01)
  --log_interval LOG_INTERVAL
                        number of batches between logging
  --visdom              Whether or not to use visdom for plotting progrss
  --dataset DATASET     The dataset to train on, currently supported: MNIST,
                        Fashion MNIST
  --load_checkpoint LOAD_CHECKPOINT
                        path to load a previously trained model from
  --checkpoint_interval CHECKPOINT_INTERVAL
                        path to load a previously trained model from
  --checkpoint_dir CHECKPOINT_DIR
                        dir to store checkpoints in
  --gen_dir GEN_DIR     folder to store generated images in

For example:

python --visdom --checkpoint_interval=1 --epochs=10

Visdom for graphing progress

To start the visdom server:

python -m visdom.server
# Now running on http://localhost:8097/

Then run with the --visdom flag

To run using Fashion MNIST

Download the dataset from here, place the files in a subdirectory named raw inside your data folder, and run with:

python --data_path=<path to download> ...

Your fashion-mnist folder should look like this:

▶ tree ~/data/fashion-mnist
├── processed
│   ├──
│   └──
└── raw
    ├── t10k-images-idx3-ubyte
    ├── t10k-labels-idx1-ubyte
    ├── train-images-idx3-ubyte
    └── train-labels-idx1-ubyte

2 directories, 6 files

The processed folder is created after training starts.

The PyTorch MNIST dataset class handles the pre-processing.
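Before starting a long training run it can help to confirm that the raw files are where the tree listing says they should be. The following is a minimal sketch using only the standard library; the helper name `missing_raw_files` is made up for illustration and is not part of this repository.

```python
from pathlib import Path

# File names taken from the tree listing above.
RAW_FILES = [
    "t10k-images-idx3-ubyte",
    "t10k-labels-idx1-ubyte",
    "train-images-idx3-ubyte",
    "train-labels-idx1-ubyte",
]


def missing_raw_files(data_path):
    """Return the expected raw files that are absent from <data_path>/raw."""
    raw_dir = Path(data_path).expanduser() / "raw"
    return [name for name in RAW_FILES if not (raw_dir / name).is_file()]


if __name__ == "__main__":
    # Hypothetical location, matching the example tree above.
    missing = missing_raw_files("~/data/fashion-mnist")
    print("missing files:", missing or "none")
```

An empty list means the folder passed to --data_path has the expected raw layout.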


| Dataset | Epochs | Test loss | Test accuracy (%) |
|---|---|---|---|
| MNIST | 10 | 0.04356 | 98.803 |
| Fashion MNIST | 10 | 0.19429 | 86.580 |
| MNIST | 50 | 0.03029 | 99.011 |
| Fashion MNIST | 50 | 0.16416 | 88.904 |


Conv layer:
- input channels: 1
- output channels: 256
- stride: 1
- kernel size: 9x9

Capsule layer 1:
- 8 capsules of size 1152
- input channels: 256
- output channels: 32
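The capsule size of 1152 follows from the layer shapes. The sketch below works through the arithmetic for a 28x28 input. The first convolution uses the values listed above; the 9x9 kernel and stride 2 for the capsule convolutions are taken from the paper and are an assumption here, since this README does not state them.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output width/height of a square convolution (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


image = 28  # MNIST and Fashion MNIST images are 28x28

# Conv layer from the list above: 9x9 kernel, stride 1.
conv1 = conv_out(image, kernel=9, stride=1)

# Assumption (from the paper, not this README): the capsule
# convolutions use a 9x9 kernel with stride 2.
primary = conv_out(conv1, kernel=9, stride=2)

# 32 output channels on a primary x primary grid, flattened per capsule.
capsule_size = 32 * primary * primary

print(conv1, primary, capsule_size)  # 20 6 1152
```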

Capsule layer 2:
- 10 capsules of size 16 (10 classes in mnist)
- input channels: 32
- 3 iterations of the routing algorithm
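For readers unfamiliar with the routing algorithm, here is a dependency-free sketch of routing-by-agreement as described in the paper. It is written with plain Python lists so the arithmetic is easy to follow, and it is not the code used in this repository; the names `squash`, `softmax`, `route` and `u_hat` are illustrative.

```python
import math


def squash(s):
    """Squash nonlinearity: keeps the direction, maps the length into [0, 1)."""
    norm_sq = sum(x * x for x in s)
    if norm_sq == 0.0:
        return [0.0 for _ in s]
    scale = norm_sq / (1.0 + norm_sq) / math.sqrt(norm_sq)
    return [scale * x for x in s]


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def route(u_hat, iterations=3):
    """u_hat[i][j] is the prediction of input capsule i for output capsule j."""
    n_in, n_out, dim = len(u_hat), len(u_hat[0]), len(u_hat[0][0])
    b = [[0.0] * n_out for _ in range(n_in)]  # routing logits
    v = []
    for _ in range(iterations):
        # Each input capsule distributes its output over the output capsules.
        c = [softmax(row) for row in b]
        v = []
        for j in range(n_out):
            s_j = [sum(c[i][j] * u_hat[i][j][d] for i in range(n_in))
                   for d in range(dim)]
            v.append(squash(s_j))
        # Raise the logit where a prediction agrees with the output capsule.
        for i in range(n_in):
            for j in range(n_out):
                b[i][j] += sum(u_hat[i][j][d] * v[j][d] for d in range(dim))
    return v
```

In the architecture above, `u_hat` would hold one 16-dimensional prediction per pair of input capsule and class capsule, and the length of each returned vector would be read as the score for that class.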

Generated images

I have not yet been able to reproduce the sharpness of the reconstructed images from the paper. I suspect the reason is that I am decoupling the digit capsule outputs so that the loss from the image generation is not backpropagated into the CapsNet.
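The effect of that decoupling can be seen in a small PyTorch sketch. The layers here are hypothetical stand-ins (two linear layers, not this repository's modules); the point is only that calling `.detach()` on the capsule output stops the reconstruction loss from producing gradients upstream of the decoder.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the capsule network and the decoder.
capsule_layer = nn.Linear(32, 10 * 16)
decoder = nn.Linear(10 * 16, 28 * 28)

x = torch.randn(4, 32)
target = torch.rand(4, 28 * 28)

# Detached: the reconstruction loss only reaches the decoder.
digit_caps = capsule_layer(x)
recon = decoder(digit_caps.detach())
nn.functional.mse_loss(recon, target).backward()

caps_grad_detached = capsule_layer.weight.grad  # None, nothing flowed back
decoder_grad = decoder.weight.grad              # a tensor, decoder is trained

# Attached: the same loss now also produces gradients for the capsule layer.
recon = decoder(capsule_layer(x))
nn.functional.mse_loss(recon, target).backward()
caps_grad_attached = capsule_layer.weight.grad  # a tensor
```

In the paper the reconstruction loss is scaled down and added to the margin loss, so it does flow back into the capsules; the attached variant is the one closer to that setup.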

Results of the decoder:

Comparison of original and generated:


Training Graphs

Results of training for 10 epochs on MNIST:

[Plots: train loss, train accuracy, test loss, test accuracy]

Results of training for 10 epochs on Fashion MNIST:

[Plots: train loss, train accuracy, test loss, test accuracy]


I found these implementations useful when I got stuck

