OverLordGoldDragon / keras-adamw

Keras/TF implementation of AdamW, SGDW, NadamW, Warm Restarts, and Learning Rate multipliers


WeightDecay is incorrectly normalized

MHStadler opened this issue

Hey,

first of all, thank you for this library - it's great and works well in general

I just wanted to point out that I think the weight decay is wrongly normalized with respect to the batch size. From the original paper, the normalized weight decay formula is as follows:

λ = λ_norm * sqrt(b / (B*T)), where b is the batch size, T is the number of epochs, and, most importantly, B is the total number of training samples.
The code assumes that total_iterations is set equal to B*T; however, iterations are counted as the number of batch updates, which equals step_size * epochs.

This is missing the number of samples in each batch, b: to get back to the original B*T we need b * total_iterations, so sqrt(b / (b * total_iterations)) = sqrt(1 / total_iterations).

Therefore, batch_size should be set to 1 when setting total_iterations or total_iterations_wd as described in the examples here.
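A rough sketch of the point in plain Python (made-up numbers; not the library's actual code):

```python
import math

lambda_norm = 0.025                            # normalized weight decay (arbitrary)
n_samples, batch_size, epochs = 1000, 50, 10   # made-up example numbers

# Paper: lambda = lambda_norm * sqrt(b / (B*T)), with B = samples, T = epochs
lam_paper = lambda_norm * math.sqrt(batch_size / (n_samples * epochs))

# total_iterations as counted by the code: batch updates = steps/epoch * epochs
total_iterations = (n_samples // batch_size) * epochs   # 20 * 10 = 200

# Since b * total_iterations == B * T, the paper's formula reduces to
# sqrt(1 / total_iterations) -- i.e. batch_size should effectively be 1:
lam_from_iters = lambda_norm * math.sqrt(1 / total_iterations)

print(lam_paper, lam_from_iters)               # both ~0.00177
```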

Glad you found it useful.

Normalization works correctly. Quoting, "B is the total number of training points and T is the total number of epochs"; "training points" is a less common phrasing for what appears to be train iterations, i.e. the number of batch fits. The wording on B can be misleading, however: "number of training points" is the same as "number of train iterations per epoch", so multiplying by the number of epochs yields the total. Further, the user is given more flexibility, as there may be reasons to set total_iterations to something other than an integer multiple of iterations per epoch.

If B were the number of samples per batch (i.e. b), that would defeat the purpose of b and the justification for it: "Li et al. (2017) demonstrated that a smaller batch size (for the same total number of epochs) leads to the shrinking effect of weight decay being more pronounced". I.e., smaller b -> stronger effective weight decay. To counteract this, we make λ smaller for smaller b and vice versa, which is what the sqrt(b) factor does.

I'm also unclear on your equation for "number of batch updates"; step_size (the learning rate) is unrelated.

Hi,

I specifically reached out to the original authors about this:

"Let's say I have 512 samples, and a batch size of 32 - so for one pass through the whole dataset I would need 16 steps (16 batches of 32 each)
In this case, would I set B to 16, or 512? Given that b is 32 (the batch size), and T is the number of epochs (e.g.: 20) [Quoted from me to them]"

[Their reply:] "In your case B=512 as the overall number of samples in your dataset, b=32 samples per batch and T=20 epochs. Thus,
λ = λ_norm * sqrt(32 / (512*20)) = λ_norm * 0.056"

In the conversation above, you can see that B*T = 512 * 20, which is not the same as the value of total_iterations, which, at least according to the examples at the start, is set to T * number_of_steps (e.g., the warm restarts example runs 3 warm-restart epochs of batch updates, with total_iterations set to 24).

Based on the above example, following the examples introduced, total_iterations would be set to 16 * 20 (steps per epoch * epochs).
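A quick numeric check of the mismatch, using the 512-sample / batch-size-32 / 20-epoch example above (sketch only, not the library's code):

```python
import math

batch_size, n_samples, epochs = 32, 512, 20
total_iterations = 16 * 20                            # steps per epoch * epochs

# Authors' intended normalization factor:
print(math.sqrt(batch_size / (n_samples * epochs)))   # ~0.056

# Factor if batch_size is also plugged into sqrt(b / total_iterations):
print(math.sqrt(batch_size / total_iterations))       # ~0.316, ~5.7x too large

# Setting batch_size = 1, as suggested above, recovers the intended value:
print(math.sqrt(1 / total_iterations))                # ~0.056
```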

From the original paper: "Thus, λ_norm can be interpreted as the weight decay used if only one batch pass is allowed" - if only one batch pass were allowed, the full decay would have to be applied. Instead, a fraction of this decay is applied per batch pass - or that's my understanding, at least.

Sorry in case this was confusing - step_size does not mean the learning_rate; I meant the number of steps (batch passes) per epoch.

Hmm... you appear to be correct.

I'd overlooked the fact that iterations are themselves computed in terms of batch size: batch_size / total_iterations is really batch_size / ((total_samples / batch_size) * epochs), so if b doubles then total_iterations halves, and the net effect is a square which under the root becomes a linear scaling: batch_size * sqrt(1 / (total_samples_per_epoch * epochs)). I agree this differs substantially from a sqrt(b) scheme. Further, one less parameter to set (batch_size) is a nice bonus.
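For illustration, a sketch of the scaling difference (made-up numbers, not the library's code): the old scheme grows linearly with batch size, while the paper's scheme grows with its square root.

```python
import math

n_samples, epochs = 512, 20                          # made-up example numbers
for b in (16, 32, 64):
    total_iterations = (n_samples // b) * epochs
    old = math.sqrt(b / total_iterations)            # old sqrt(b / total_iterations)
    paper = math.sqrt(b / (n_samples * epochs))      # paper's sqrt(b / (B*T))
    print(b, round(old, 4), round(paper, 4))

# old:   0.1581, 0.3162, 0.6325  -- doubles as b doubles (linear in b)
# paper: 0.0395, 0.0559, 0.0791  -- grows by sqrt(2) as b doubles
```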

The authors should've used better wording than "training points", and there's another point of ambiguity: B refers to the number of samples fit per epoch, not the number of samples in the dataset (e.g. we can have 60 samples with batch_size 32, so we oversample 4; then B should be 64, not 60). This should've been clarified in the paper, and the simplified equivalent presented.

The good news is, anyone who didn't have to change batch_size will be unaffected.

I'll make appropriate changes sometime soon. Thanks for raising this Issue.

Fixed, available on PyPI. Let me know if there are any further concerns.

Thank you for the quick fix - I totally agree, the notation is very confusing. I also found it strange that such a big part of the implementation was relegated to the appendix.

Anyway, thanks again. Also great job on the new updates - not having to manually update the decays per layer is a huge QoL improvement

Their wording of this in the paper has been the most confusing thing in my life thus far, and I finally found the answer in this thread. Thanks @MHStadler for reaching out to the authors and @OverLordGoldDragon for your clear and concise implementations.

@evanatyourservice Glad you were helped. One thing: I did not have fun implementing this in TF/Keras. If you ever plan on implementing custom functionality, I suggest ditching them for PyTorch; this'd take 1/10th the work there.

Looking at PyTorch's AdamW, it naively decays all weights, so I might implement layer-wise decay if it hasn't been done yet - but it's not even hard to just hand-code the logic each time.
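For instance, a minimal sketch of hand-coding layer-wise decay in PyTorch via parameter groups (standard torch.optim usage, not this repo's Keras implementation; the model and values are just illustrative):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 10))

# Split parameters into groups; a common convention is to not decay biases
decay, no_decay = [], []
for name, param in model.named_parameters():
    (no_decay if name.endswith("bias") else decay).append(param)

optimizer = torch.optim.AdamW(
    [{"params": decay, "weight_decay": 1e-2},
     {"params": no_decay, "weight_decay": 0.0}],
    lr=1e-3,
)
```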

Yeah, I ditched TF for PyTorch a while ago and my productivity instantly went up lol, it's much easier to work with. I'm actually customizing NovoGrad a bit further, which already has the layer-wise functionality, but I wanted to make sure I was regularizing and decaying the weight decay correctly, which the NovoGrad authors didn't experiment with.