Mammoth - An Extendible (General) Continual Learning Framework for PyTorch
Setup
- Use `./utils/main.py` to run experiments.
- Use the argument `--load_best_args` to apply the best hyperparameters from the paper.
- New models can be added to the `models/` folder.
- New datasets can be added to the `datasets/` folder.
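
As a rough illustration of the points above, the following Python sketch assembles a command line for `./utils/main.py`. Only the script path and `--load_best_args` come from this README; the `--model`, `--dataset`, and `--buffer_size` flags, as well as the example values passed to them, are assumptions made for illustration and should be checked against the script's argument parser.

```python
import shlex


def build_command(model, dataset, buffer_size=None, load_best_args=True):
    """Assemble an argv list for ./utils/main.py.

    NOTE: --model, --dataset and --buffer_size are assumed flag names;
    only --load_best_args is documented in this README.
    """
    argv = ["python", "./utils/main.py", "--model", model, "--dataset", dataset]
    if buffer_size is not None:
        argv += ["--buffer_size", str(buffer_size)]
    if load_best_args:
        argv.append("--load_best_args")
    return argv


# Hypothetical model and dataset identifiers, used only as placeholders.
command = build_command("der", "seq-cifar10", buffer_size=500)
print(shlex.join(command))
```

The resulting list can be passed to `subprocess.run(command)` from the repository root once the flag names have been confirmed.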
Models
- Gradient Episodic Memory (GEM)
- A-GEM
- A-GEM with Reservoir (A-GEM-R)
- Experience Replay (ER)
- Meta-Experience Replay (MER)
- Function Distance Regularization (FDR)
- Greedy gradient-based Sample Selection (GSS)
- Hindsight Anchor Learning (HAL)
- Incremental Classifier and Representation Learning (iCaRL)
- online Elastic Weight Consolidation (oEWC)
- Synaptic Intelligence (SI)
- Learning without Forgetting (LwF)
- Progressive Neural Networks (PNN)
- Dark Experience Replay (DER)
- Dark Experience Replay++ (DER++)
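
Several of the rehearsal-based methods listed above keep a bounded memory of past examples, and A-GEM-R names reservoir sampling explicitly. The sketch below shows generic reservoir sampling in plain Python, so that every example in the stream has the same probability of being in the buffer. The class and method names are hypothetical and are not taken from this repository's code.

```python
import random


class ReservoirBuffer:
    """Fixed-capacity buffer filled by reservoir sampling (hypothetical helper)."""

    def __init__(self, capacity, seed=None):
        self.capacity = capacity
        self.items = []
        self.num_seen = 0  # number of stream elements observed so far
        self._rng = random.Random(seed)

    def add(self, item):
        if len(self.items) < self.capacity:
            # Buffer not full yet: always store the new item.
            self.items.append(item)
        else:
            # Keep the new item with probability capacity / (num_seen + 1),
            # overwriting a uniformly chosen slot.
            slot = self._rng.randint(0, self.num_seen)  # both ends inclusive
            if slot < self.capacity:
                self.items[slot] = item
        self.num_seen += 1

    def sample(self, batch_size):
        """Draw a replay batch without replacement."""
        k = min(batch_size, len(self.items))
        return self._rng.sample(self.items, k)


buffer = ReservoirBuffer(capacity=200, seed=0)
for example_id in range(1000):
    buffer.add(example_id)
replay_batch = buffer.sample(32)
```

In a real training loop the stored items would be tensors (inputs, labels, and, for some methods, network outputs) rather than integers.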
Datasets
Class-IL / Task-IL settings
- Sequential MNIST
- Sequential CIFAR-10
- Sequential Tiny ImageNet
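
The two settings differ in what is known at test time: in Task-IL the task identity is given, so the prediction can be restricted to that task's classes, while in Class-IL the model must choose among all classes. The sketch below illustrates the difference with plain Python lists. It assumes equally sized tasks with contiguous class indices; the helper names and toy logits are made up for illustration and do not reflect the repository's evaluation code.

```python
def predict_class_il(logits):
    """Pick the highest-scoring class among all classes."""
    return max(range(len(logits)), key=lambda c: logits[c])


def predict_task_il(logits, task_id, classes_per_task):
    """Pick the highest-scoring class among the classes of the given task."""
    start = task_id * classes_per_task
    end = start + classes_per_task
    return max(range(start, end), key=lambda c: logits[c])


# Toy example: 3 tasks with 2 classes each; the true label is class 2 (task 1).
logits = [0.1, 0.3, 1.2, 0.4, 2.5, 0.2]
class_il_prediction = predict_class_il(logits)      # 4: a class from task 2 wins
task_il_prediction = predict_task_il(logits, 1, 2)  # 2: only classes 2 and 3 compete
print(class_il_prediction, task_il_prediction)
```

Because the Task-IL prediction ignores logits from other tasks, it is the easier of the two settings.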
Domain-IL settings
- Permuted MNIST
- Rotated MNIST
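
In Domain-IL the label set stays fixed while the input distribution changes from task to task. For Permuted MNIST this is commonly done by applying one fixed random pixel permutation per task; the sketch below illustrates that idea in plain Python. The function names and the per-task seeding scheme are assumptions for illustration, not the repository's dataset code.

```python
import random


def make_permutation(num_pixels, seed):
    """Return a fixed pixel permutation for one task (seed chosen per task)."""
    rng = random.Random(seed)
    perm = list(range(num_pixels))
    rng.shuffle(perm)
    return perm


def permute_image(flat_image, perm):
    """Reorder the pixels of a flattened image according to perm."""
    return [flat_image[i] for i in perm]


# Toy stand-in for a flattened 28x28 image.
image = list(range(28 * 28))
task_permutations = [make_permutation(len(image), seed=task_id) for task_id in range(3)]
task_views = [permute_image(image, perm) for perm in task_permutations]
```

Every image of a given task is transformed with the same permutation, so the labels are unchanged while the pixel layout differs between tasks.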
General Continual Learning setting
- MNIST-360
Results
Continual Learning Results

| Buffer | Method | S-CIFAR-10 (Class-IL) | S-CIFAR-10 (Task-IL) | S-Tiny-ImageNet (Class-IL) | S-Tiny-ImageNet (Task-IL) | P-MNIST (Domain-IL) | R-MNIST (Domain-IL) | S-MNIST (Class-IL) | S-MNIST (Task-IL) |
|---|---|---|---|---|---|---|---|---|---|
| - | JOINT | 92.20 | 98.31 | 59.99 | 82.04 | 94.33 | 95.76 | 95.57 | 99.51 |
| - | SGD | 19.62 | 61.02 | 7.92 | 18.31 | 40.70 | 67.66 | 19.60 | 94.94 |
| - | oEWC | 19.49 | 68.29 | 7.58 | 19.20 | 75.79 | 77.35 | 20.46 | 98.39 |
| - | SI | 19.48 | 68.05 | 6.58 | 36.32 | 65.86 | 71.91 | 19.27 | 96.00 |
| - | LwF | 19.61 | 63.29 | 8.46 | 15.85 | - | - | 19.62 | 94.11 |
| - | PNN | - | 95.13 | - | 67.84 | - | - | - | 99.23 |
| 200 | ER | 44.79 | 91.19 | 8.49 | 38.17 | 72.37 | 85.01 | 80.43 | 97.86 |
| 200 | MER | - | - | - | - | - | - | 81.47 | 98.05 |
| 200 | GEM | 25.54 | 90.44 | - | - | 66.93 | 80.80 | 80.11 | 97.78 |
| 200 | A-GEM | 20.04 | 83.88 | 8.07 | 22.77 | 66.42 | 81.91 | 45.72 | 98.61 |
| 200 | iCaRL | 49.02 | 88.99 | 7.53 | 28.19 | - | - | 70.51 | 98.28 |
| 200 | FDR | 30.91 | 91.01 | 8.70 | 40.36 | 74.77 | 85.22 | 79.43 | 97.66 |
| 200 | GSS | 39.07 | 88.80 | - | - | 63.72 | 79.50 | 38.90 | 95.02 |
| 200 | HAL | 32.36 | 82.51 | - | - | 74.15 | 84.02 | 84.70 | 97.96 |
| 200 | DER | 61.93 | 91.40 | 11.87 | 40.22 | 81.74 | 90.04 | 84.55 | 98.80 |
| 200 | DER++ | 64.88 | 91.92 | 10.96 | 40.87 | 83.58 | 90.43 | 85.61 | 98.76 |
| 500 | ER | 57.74 | 93.61 | 9.99 | 48.64 | 80.60 | 88.91 | 86.12 | 99.04 |
| 500 | MER | - | - | - | - | - | - | 88.35 | 98.43 |
| 500 | GEM | 26.20 | 92.16 | - | - | 76.88 | 81.15 | 85.99 | 98.71 |
| 500 | A-GEM | 22.67 | 89.48 | 8.06 | 25.33 | 67.56 | 80.31 | 46.66 | 98.93 |
| 500 | iCaRL | 47.55 | 88.22 | 9.38 | 31.55 | - | - | 70.10 | 98.32 |
| 500 | FDR | 28.71 | 93.29 | 10.54 | 49.88 | 83.18 | 89.67 | 85.87 | 97.54 |
| 500 | GSS | 49.73 | 91.02 | - | - | 76.00 | 81.58 | 49.76 | 97.71 |
| 500 | HAL | 41.79 | 84.54 | - | - | 80.13 | 85.00 | 87.21 | 98.03 |
| 500 | DER | 70.51 | 93.40 | 17.75 | 51.78 | 87.29 | 92.24 | 90.54 | 98.84 |
| 500 | DER++ | 72.70 | 93.88 | 19.38 | 51.91 | 88.21 | 92.77 | 91.00 | 98.94 |
| 5120 | ER | 82.47 | 96.98 | 27.40 | 67.29 | 89.90 | 93.45 | 93.40 | 99.33 |
| 5120 | MER | - | - | - | - | - | - | 94.57 | 99.27 |
| 5120 | GEM | 25.26 | 95.55 | - | - | 87.42 | 88.57 | 95.11 | 99.44 |
| 5120 | A-GEM | 21.99 | 90.10 | 7.96 | 26.22 | 73.32 | 80.18 | 54.24 | 98.93 |
| 5120 | iCaRL | 55.07 | 92.23 | 14.08 | 40.83 | - | - | 70.60 | 98.32 |
| 5120 | FDR | 19.70 | 94.32 | 28.97 | 68.01 | 90.87 | 94.19 | 87.47 | 97.79 |
| 5120 | GSS | 67.27 | 94.19 | - | - | 82.22 | 85.24 | 89.39 | 98.33 |
| 5120 | HAL | 59.12 | 88.51 | - | - | 89.20 | 91.17 | 89.52 | 98.35 |
| 5120 | DER | 83.81 | 95.43 | 36.73 | 69.50 | 91.66 | 94.14 | 94.90 | 99.29 |
| 5120 | DER++ | 85.24 | 96.12 | 39.02 | 69.84 | 92.26 | 94.65 | 95.30 | 99.47 |

MNIST-360 - General Continual Learning

| JOINT | SGD | Buffer | ER | MER | A-GEM-R | GSS | DER | DER++ |
|---|---|---|---|---|---|---|---|---|
| 82.98 | 19.09 | 200 | 49.27 | 48.58 | 28.34 | 43.92 | 55.22 | 54.16 |
| 82.98 | 19.09 | 500 | 65.04 | 62.21 | 28.13 | 54.45 | 69.11 | 69.62 |
| 82.98 | 19.09 | 1000 | 75.18 | 70.91 | 29.21 | 63.84 | 75.97 | 76.03 |