OverLordGoldDragon / keras-adamw

Keras/TF implementation of AdamW, SGDW, NadamW, Warm Restarts, and Learning Rate multipliers

Warm Restart

PappaBu opened this issue

Thank you for developing AdamW!

I have a question about Warm Restarts. Is it necessary to force-set t_cur = 0 at the end of each training epoch? (var 1)
Or is t_cur automatically reset to 0 once it reaches total_iterations? (var 2)

(var 1)
import tensorflow.keras.backend as K
from tensorflow.keras.callbacks import Callback

class WarmRestart(Callback):
    def on_epoch_end(self, epoch, logs=None):
        # ...
        K.set_value(self.model.optimizer.t_cur, 0)  # WARM RESTART

(var 2)
from keras_adamw import AdamW

trainset_size = 1000
batch_size = 64
# one epoch = int(trainset_size / batch_size) = 15 iterations
optimizer = AdamW(..., total_iterations=15, batch_size=batch_size)

And correct me if I'm wrong - is the following the best practice for setting 'total_iterations'?
total_iterations = int(trainset_size / batch_size)

Indeed, t_cur doesn't reset itself automatically - you can see exactly how it works here. If training continues without resetting t_cur, eta_t will simply continue to vary along the cosine curve (i.e. go back up). While this isn't how Warm Restarts (WR) are intended to work, you're still free to try it.
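For intuition, here is a minimal sketch of the cosine-annealing schedule from the SGDR paper (Loshchilov & Hutter) that eta_t follows; the eta() helper and the numbers below are illustrative, not the library's internal code:

import numpy as np

def eta(t_cur, total_iterations, eta_min=0.0, eta_max=1.0):
    # SGDR cosine annealing:
    # eta_t = eta_min + 0.5 * (eta_max - eta_min) * (1 + cos(pi * t_cur / total_iterations))
    return eta_min + 0.5 * (eta_max - eta_min) * (1 + np.cos(np.pi * t_cur / total_iterations))

total_iterations = 100
print(eta(0, total_iterations))    # ~1.0 - start of the cycle
print(eta(100, total_iterations))  # ~0.0 - fully annealed at total_iterations
# Without resetting t_cur, the cosine argument keeps growing and eta_t climbs back up:
print(eta(150, total_iterations))  # ~0.5
print(eta(200, total_iterations))  # ~1.0 - back at the peak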

As for total_iterations, it's the number of iterations you expect throughout training; for example, if you expect 5 epochs and each epoch has 100 batches, with one fit call per batch, then total_iterations = 500.
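In code, that arithmetic is simply (variable names here are just for illustration):

epochs = 5
trainset_size = 6400
batch_size = 64

batches_per_epoch = trainset_size // batch_size  # 100 batches per epoch
total_iterations = epochs * batches_per_epoch    # 5 * 100 = 500 optimizer steps over all of training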

If you use WR, however, the authors recommend setting it to the number of expected iterations for a given restart - so if a restart spans 2 epochs and each epoch has 100 batches, that's total_iterations = 200. Note, though, that eta_t operates on both lr and weight decays; I found that using a separate variable, total_iterations_wd, to normalize weight decays can work better. (By default, total_iterations_wd = total_iterations.)
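Putting it together, a restart-oriented setup might look like the sketch below; I'm assuming the keras_adamw import and the keyword names used in this thread (total_iterations, total_iterations_wd, batch_size), and lr plus the remaining arguments may differ by version:

from keras_adamw import AdamW

batch_size = 64
batches_per_epoch = 100
restart_epochs = 2   # length of one restart cycle
total_epochs = 5     # length of the whole training run

optimizer = AdamW(lr=1e-3, batch_size=batch_size,
                  total_iterations=restart_epochs * batches_per_epoch,    # 200: eta_t anneals over one restart
                  total_iterations_wd=total_epochs * batches_per_epoch)   # 500: weight decays normalized over full training

At the end of each restart, t_cur would then be reset to 0, e.g. via the on_epoch_end callback from (var 1) above.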

Thank you so much for the detailed answer! Now everything is clear.