Warm Restart
PappaBu opened this issue · comments
Thank you for developing AdamW!
I have a question about warm restarts. Is it necessary to force-set `t_cur = 0` at the end of each training epoch? (var 1)
Or is `t_cur` automatically reset to 0 after reaching `total_iterations`? (var 2)
(var 1)

```python
def on_epoch_end(...):
    ...
    K.set_value(self.model.optimizer.t_cur, 0)  # WARM RESTART
    ...
```
(var 2)

```python
trainset_size = 1000
batch_size = 64
optimizer = AdamW(..., total_iterations=15, batch_size=batch_size)
```
And correct me if I'm wrong - is the best practice for setting `total_iterations`:

`total_iterations = int(trainset_size / batch_size)`
Indeed, `t_cur` doesn't automatically reset itself - and you can find how it works exactly here: if training continues without resetting `t_cur`, then `eta_t` will simply continue to vary along the cosine curve (i.e. go back up). While this isn't how Warm Restarts (WR) are intended to work, you're still free to try it.
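To see the "go back up" behavior concretely, here's a minimal plain-Python sketch - assuming `eta_t` follows the standard cosine-annealing multiplier from the SGDR formulation (the exact schedule used internally may differ):

```python
import math

def eta_t(t_cur, total_iterations):
    """Cosine-annealing multiplier: 1 at t_cur=0, 0 at t_cur=total_iterations.
    Assumed form of the schedule; a sketch, not the library's exact code."""
    return 0.5 * (1 + math.cos(math.pi * t_cur / total_iterations))

total = 100
# Without a reset, t_cur keeps growing past total_iterations,
# so eta_t follows the cosine curve back up instead of restarting at 1:
print(eta_t(0, total))    # 1.0  (start of schedule)
print(eta_t(100, total))  # 0.0  (end of the intended schedule)
print(eta_t(200, total))  # 1.0  (eta_t has climbed back up)
```

Resetting `t_cur` to 0 (as in var 1) snaps the multiplier back to 1 immediately, which is the intended WR behavior.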
As for `total_iterations`, it's the number of iterations you expect throughout training; for example, if you expect 5 epochs, and each epoch has 100 batches, then 1 fit / batch --> `total_iterations = 500`.

If you use WR, however, the authors recommend setting it to the number of expected iterations for a given restart - so if a restart is 2 epochs and each epoch has 100 batches, that's `total_iterations = 200`.
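As a quick sanity check on that arithmetic (the helper name below is mine, not part of the library):

```python
def total_iterations_for(epochs, batches_per_epoch):
    # Iterations the optimizer will see over the given span
    # (one warm-restart cycle, or the whole training run).
    return epochs * batches_per_epoch

# Whole training run: 5 epochs x 100 batches
print(total_iterations_for(5, 100))  # 500

# With WR: one restart spanning 2 epochs x 100 batches
print(total_iterations_for(2, 100))  # 200

# Batches per epoch from the question's setup (1000 samples, batch size 64),
# dropping the final partial batch:
print(1000 // 64)  # 15
```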
However, as `eta_t` operates on both `lr` and weight decays, I found that using a separate variable, `total_iterations_wd`, to normalize weight decays can work better. (By default, `total_iterations_wd = total_iterations`.)
Thank you so much for the detailed answer! Now everything is clear.