Implementation of the LAMB optimizer from the paper *Reducing BERT Pre-Training Time from 3 Days to 76 Minutes*.
Supports large-batch training at batch sizes up to 64k while tuning only the learning rate as a hyperparameter. Smaller batch sizes are also supported without changing any of the other hyperparameters.
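The reason only the learning rate needs tuning is LAMB's layer-wise trust ratio: each layer's Adam-style update is rescaled by the ratio of the weight norm to the update norm, so the effective step size adapts per layer. A minimal NumPy sketch of a single-layer update is below (the function name and default hyperparameters are illustrative, not part of this package's API):

```python
import numpy as np

def lamb_step(param, grad, m, v, lr=0.001, beta1=0.9, beta2=0.999,
              eps=1e-6, weight_decay=0.01, step=1):
    """One LAMB update for a single layer's weights (illustrative sketch)."""
    # Adam-style biased moment estimates, then bias correction
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad ** 2
    m_hat = m / (1 - beta1 ** step)
    v_hat = v / (1 - beta2 ** step)
    # Adam update direction plus decoupled weight decay
    update = m_hat / (np.sqrt(v_hat) + eps) + weight_decay * param
    # Layer-wise trust ratio: ||w|| / ||update|| (1.0 if either norm is zero)
    w_norm = np.linalg.norm(param)
    u_norm = np.linalg.norm(update)
    trust_ratio = w_norm / u_norm if w_norm > 0 and u_norm > 0 else 1.0
    new_param = param - lr * trust_ratio * update
    return new_param, m, v
```

Because the trust ratio normalizes the update magnitude per layer, scaling the batch size (and with it the learning rate) does not require retuning the per-layer behavior.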
```python
from keras_lamb import LAMBOptimizer

optimizer = LAMBOptimizer(0.001, weight_decay=0.01)
model.compile(optimizer, ...)
```
- Keras 2.2.4+
- TensorFlow 1.13+