# Capsule Network in PyTorch
This repository implements the following Capsule Networks in PyTorch and reproduces the performance claimed in the papers:
- "Dynamic Routing Between Capsules" by Sara Sabour, Nicholas Frosst, Geoffrey Hinton [paper]
- "Matrix Capsules with EM Routing" by Geoffrey Hinton, Sara Sabour, Nicholas Frosst [paper] (TBD)
Official repository
## Requirements
- PyTorch >= 1.3.0
## Train
Train Capsule Network with a decoder using margin loss:

```shell
python train.py
```
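The margin loss used here is defined in "Dynamic Routing Between Capsules" (Eq. 4), with m⁺ = 0.9, m⁻ = 0.1, and λ = 0.5. Below is a minimal PyTorch sketch of that loss; function and argument names are illustrative, and the repository's own implementation may differ:

```python
import torch

def margin_loss(lengths, targets, m_pos=0.9, m_neg=0.1, lam=0.5):
    """Margin loss from "Dynamic Routing Between Capsules" (Eq. 4).

    lengths: (batch, num_classes) capsule output lengths ||v_k||
    targets: (batch, num_classes) one-hot labels T_k
    """
    # Present-class term: penalize lengths below m_pos for the true class
    pos = targets * torch.clamp(m_pos - lengths, min=0.0) ** 2
    # Absent-class term: penalize lengths above m_neg, down-weighted by lam
    neg = lam * (1.0 - targets) * torch.clamp(lengths - m_neg, min=0.0) ** 2
    # Sum over classes, average over the batch
    return (pos + neg).sum(dim=1).mean()
```

A confidently correct prediction (true-class length ≥ 0.9, others ≤ 0.1) incurs zero loss.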
Train Capsule Network without a decoder using margin loss:

- rewrite `_config` in `__init__.py` under the configuration folder as follows:

  ```python
  _config = {
      ...
      'model': 'CapsNet',
      ...
  }
  ```

- run:

  ```shell
  python train.py
  ```
Train the CNN baseline using cross-entropy loss:

- rewrite `_config` in `__init__.py` under the configuration folder as follows:

  ```python
  _config = {
      ...
      'model': 'BaselineNet',
      ...
      'criterion': 'ce',
      ...
  }
  ```

- run:

  ```shell
  python train.py
  ```
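Both capsule variants above use the squashing nonlinearity from the paper (Eq. 1), which shrinks short capsule vectors toward zero and scales long ones to a length just under 1 so that length can be read as a probability. A minimal sketch (illustrative names; the repository's implementation may differ):

```python
import torch

def squash(s, dim=-1, eps=1e-8):
    """Squashing nonlinearity (Eq. 1): v = (||s||^2 / (1 + ||s||^2)) * s / ||s||."""
    sq_norm = (s ** 2).sum(dim=dim, keepdim=True)
    scale = sq_norm / (1.0 + sq_norm)
    # eps guards against division by zero for all-zero capsule inputs
    return scale * s / torch.sqrt(sq_norm + eps)
```

The output length is ||s||² / (1 + ||s||²), so it approaches 1 only for long input vectors.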
## Experiments
Classification test error on MNIST with a learning rate schedule:

| Model | #iter | batch size | #epoch | test error (%) | criterion |
|---|---|---|---|---|---|
| CNN Baseline | - | 128 | 5000 | 0.32 | Cross Entropy |
| CapsuleNet w/ Decoder | 3 | 128 | 5000 | 0.25 | Margin Loss |
Classification test error on MNIST without a learning rate schedule:

| Model | #iter | batch size | #epoch | test error (%) | criterion |
|---|---|---|---|---|---|
| CNN Baseline | - | 8 | 10 | 0.58 | Cross Entropy |
| CapsuleNet w/o Decoder | 3 | 8 | 10 | 0.86 | Cross Entropy |
| CapsuleNet w/o Decoder | 3 | 8 | 10 | 0.74 | Margin Loss |
| CapsuleNet w/ Decoder | 3 | 8 | 10 | 0.78 | Margin Loss |