StochBN
Experiments and a StochBN implementation for PyTorch.
Experiments
validation_exp.py
-- main experiment: comparison of test accuracy across different BN strategies
batch_avg.py
-- averaging test predictions over many batches with BN in training mode (mean and variance computed from the batch)
train_collected_stats.py
-- training the network while switching BN layers to test mode during training (using the collected mean and variance)
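The BN strategies above differ mainly in which mode the BatchNorm layers are in. Below is a minimal PyTorch sketch, not this repository's code: `bn_batch_stats_copy`, `batch_avg_predict` and `freeze_bn_stats` are hypothetical helpers. It assumes that in `batch_avg.py` the batch statistics come from randomly composed test mini-batches, and that `train_collected_stats.py` amounts to calling `eval()` on the BN layers while the rest of the network trains.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

# Public BatchNorm layer types in torch.nn.
BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def bn_batch_stats_copy(model: nn.Module) -> nn.Module:
    """Copy `model` and put it in eval mode except for the BN layers, which
    stay in train mode and normalize with the current batch's mean/variance.
    Working on a copy leaves the original running_mean/running_var untouched."""
    m = copy.deepcopy(model)
    m.eval()
    for module in m.modules():
        if isinstance(module, BN_TYPES):
            module.train()
    return m


@torch.no_grad()
def batch_avg_predict(model, x, batch_size=32, n_passes=10, seed=0):
    """Assumed batch_avg.py-style inference: split the test set into random
    mini-batches `n_passes` times and average each sample's softmax output."""
    gen = torch.Generator().manual_seed(seed)
    m = bn_batch_stats_copy(model)
    n = x.shape[0]
    total = None
    for _ in range(n_passes):
        perm = torch.randperm(n, generator=gen)
        for start in range(0, n, batch_size):
            idx = perm[start:start + batch_size]
            probs = F.softmax(m(x[idx]), dim=1)
            if total is None:
                total = torch.zeros(n, probs.shape[1])
            total[idx] += probs  # indices are unique within a batch
    return total / n_passes


def freeze_bn_stats(model: nn.Module) -> None:
    """Assumed train_collected_stats.py-style switch: BN layers normalize with
    their collected running statistics and stop updating them, while their
    affine weight and bias still receive gradients."""
    for module in model.modules():
        if isinstance(module, BN_TYPES):
            module.eval()
```

`model.train()` resets every submodule to train mode, so `freeze_bn_stats` has to be called again after each call to it. With `BatchNorm1d`, a trailing mini-batch of size 1 fails in train mode, so `batch_size` should be chosen to avoid it.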
Results
Models:
HBN-T
-- our model, in which an approximation of the BatchNorm statistics is tuned after network training (with the network parameters fixed)
DE
-- Deep Ensembles https://arxiv.org/abs/1612.01474
DO
-- binary Dropout
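The DE and DO baselines both form a predictive distribution by averaging softmax outputs. A minimal sketch of that averaging, assuming DO means Monte Carlo dropout (dropout kept active at test time); `ensemble_predict` and `mc_dropout_predict` are hypothetical names, not functions from this repository.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


@torch.no_grad()
def ensemble_predict(models, x):
    """DE: average the softmax outputs of independently trained networks."""
    probs = [F.softmax(m.eval()(x), dim=1) for m in models]
    return torch.stack(probs).mean(dim=0)


@torch.no_grad()
def mc_dropout_predict(model, x, n_samples=20):
    """DO (assumed MC dropout): keep nn.Dropout layers in train mode at test
    time and average the softmax over `n_samples` stochastic forward passes.
    The model is left in this mixed mode afterwards."""
    model.eval()
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.train()
    probs = [F.softmax(model(x), dim=1) for _ in range(n_samples)]
    return torch.stack(probs).mean(dim=0)
```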
CIFAR5
Models are trained on the first 5 classes of CIFAR10. The entropy of the predictive distribution is estimated on the remaining five classes (solid) and on the original five (dashed).
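The quantity plotted is the entropy H(p) = -Σ_k p_k log p_k of each predicted class distribution. A minimal sketch of computing it and of splitting CIFAR10 test labels into seen and unseen classes, assuming "first 5 classes" means label indices 0 to 4 and that `probs` holds averaged softmax outputs whose rows sum to 1; `predictive_entropy` and `split_cifar5` are hypothetical helpers.

```python
import torch


def predictive_entropy(probs: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Entropy in nats of each row of `probs`; clamping avoids log(0)."""
    return -(probs * torch.log(probs.clamp_min(eps))).sum(dim=1)


def split_cifar5(labels: torch.Tensor):
    """Assumed CIFAR5 split: labels 0-4 were seen in training, 5-9 were not."""
    seen = labels < 5
    return seen, ~seen
```

The entropy is largest (log 5 for a 5-class model) when the prediction is uniform and 0 when it is one-hot, so higher values on the unseen classes indicate the model is less confident there.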
ResNet18
VGG11
MNIST + notMNIST
LeNet5
Trained on MNIST; evaluated on MNIST (dashed) and notMNIST (solid).
See exps for further details.
Acknowledgement
- Thanks to https://github.com/kuangliu for the models from https://github.com/kuangliu/pytorch-cifar