Paper ---- NIPS 2017 poster ---- NIPS 2017 spotlight slides
By Antti Tarvainen, Harri Valpola (The Curious AI Company)
Mean Teacher is a simple method for semi-supervised learning. It consists of the following steps:
- Take a supervised architecture and make a copy of it. Let's call the original model the student and the new one the teacher.
- At each training step, use the same minibatch as inputs to both the student and the teacher but add random augmentation or noise to the inputs separately.
- Add an additional consistency cost between the student and teacher outputs (after softmax).
- Let the optimizer update the student weights normally.
- Let the teacher weights be an exponential moving average (EMA) of the student weights. That is, after each training step, update the teacher weights a little bit toward the student weights (a minimal sketch of one training step follows this list).
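The whole recipe fits in a few lines of PyTorch. Below is a minimal sketch of one training step, assuming `student` and `teacher` are two copies of the same network. The function names (`training_step`, `update_ema_teacher`), the Gaussian input noise standing in for real augmentation, and the hyperparameter values (`alpha=0.999`, `cons_weight`) are illustrative assumptions, not the repo's actual API.

```python
import torch
import torch.nn.functional as F

def update_ema_teacher(student, teacher, alpha=0.999):
    """Move each teacher weight a little bit toward the corresponding student weight."""
    with torch.no_grad():
        for t_param, s_param in zip(teacher.parameters(), student.parameters()):
            t_param.mul_(alpha).add_(s_param, alpha=1 - alpha)

def training_step(student, teacher, optimizer, x, labels, labeled_mask, cons_weight=1.0):
    # Same minibatch for both models, but noise/augmentation applied separately.
    # Plain Gaussian noise here is a stand-in for real data augmentation.
    student_in = x + torch.randn_like(x) * 0.15
    teacher_in = x + torch.randn_like(x) * 0.15

    student_logits = student(student_in)
    with torch.no_grad():  # no gradients flow through the teacher
        teacher_logits = teacher(teacher_in)

    # Classification cost on the labeled examples only
    # (labels at unlabeled positions are masked out and can hold any value).
    class_cost = F.cross_entropy(student_logits[labeled_mask], labels[labeled_mask])

    # Consistency cost between student and teacher predictions (after softmax).
    consistency_cost = F.mse_loss(
        F.softmax(student_logits, dim=1),
        F.softmax(teacher_logits, dim=1),
    )

    # The optimizer updates only the student weights.
    loss = class_cost + cons_weight * consistency_cost
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # After each step, the teacher follows the student as an EMA.
    update_ema_teacher(student, teacher)
    return loss.item()
```

In the paper, the consistency weight is ramped up from zero over the first epochs rather than held fixed, and proper data augmentation (e.g. translations and flips) replaces the plain input noise used in this sketch.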
Our contribution is the last step. Laine and Aila [paper] either shared the parameters between the student and the teacher or used a temporal ensemble of teacher predictions (the Π model and Temporal Ensembling, respectively). In comparison, Mean Teacher is more accurate and applicable to large datasets.
Mean Teacher works well with modern architectures. Combining Mean Teacher with ResNets, we improved the state of the art in semi-supervised learning on the ImageNet and CIFAR-10 datasets.
| ImageNet using 10% of the labels | top-5 validation error |
| --- | --- |
| Variational Auto-Encoder [paper] | 35.42 ± 0.90 |
| Mean Teacher ResNet-152 | 9.11 ± 0.12 |
| All labels, state of the art [paper] | 3.79 |
| CIFAR-10 using 4000 labels | test error |
| --- | --- |
| CT-GAN [paper] | 9.98 ± 0.21 |
| Mean Teacher ResNet-26 | 6.28 ± 0.15 |
| All labels, state of the art [paper] | 2.86 |
There are two implementations, one for TensorFlow and one for PyTorch. The PyTorch version is probably easier to adapt to your needs, since it follows typical PyTorch idioms and has a natural place to plug in your own model and dataset. Let me know if anything needs clarification.
Regarding the results in the paper, the experiments using a traditional CNN architecture were run with the TensorFlow version. The experiments using residual networks were run with the PyTorch version.