
Scheduled(Stable)-Weight-Decay-Regularization

The PyTorch Implementation of Scheduled (Stable) Weight Decay.

The algorithms were first proposed in our arXiv paper.

A formal version, substantially revised and including the theoretical mechanism, "On the Overlooked Pitfalls of Weight Decay and How to Mitigate Them: A Gradient-Norm Perspective", was accepted at NeurIPS 2023.

Why Scheduled (Stable) Weight Decay?

We proposed the Scheduled (Stable) Weight Decay (SWD) method to mitigate the overlooked large-gradient-norm pitfalls of weight decay in modern deep learning libraries.

  • SWD penalizes large gradient norms in the final phase of training.

  • SWD usually yields significant improvements over both L2 regularization and decoupled weight decay (see the sketch after this list).

  • Simply fixing weight decay in Adam via SWD, which introduces no extra hyperparameter, usually outperforms complex Adam variants that have more hyperparameters.
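
As a rough, illustrative sketch (not the implementation in this repo), the snippet below contrasts how weight decay enters a plain SGD update under L2 regularization versus decoupled weight decay; the scheduled/stable variant additionally rescales the decoupled decay term by a gradient-norm statistic, whose exact form is defined in the paper and in the swd_optim optimizers.

import torch

def sgd_step(param, lr=0.1, weight_decay=5e-4, mode="decoupled", scale=1.0):
    # Illustrative only. `scale` is a placeholder for the gradient-norm-based
    # factor that scheduled/stable weight decay would apply; see the paper and
    # swd_optim for the actual rule.
    grad = param.grad
    if mode == "l2":
        # L2 regularization: the decay term is folded into the gradient.
        grad = grad + weight_decay * param.data
        param.data.add_(grad, alpha=-lr)
    else:
        # Decoupled weight decay: the weights are shrunk directly,
        # independently of the gradient-based step.
        param.data.mul_(1.0 - lr * weight_decay * scale)
        param.data.add_(grad, alpha=-lr)

w = torch.randn(8, requires_grad=True)
(w ** 2).sum().backward()
sgd_step(w, mode="l2")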

The required environment is as below:

  • Python 3.7.3

  • PyTorch >= 1.4.0

Usage

You may use AdamS as a standard PyTorch optimizer:

import swd_optim

optimizer = swd_optim.AdamS(net.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=5e-4, amsgrad=False)
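
Below is a minimal end-to-end sketch of AdamS in an ordinary training loop; the model, dummy data, and loss are placeholders for illustration and are not part of this repo.

import torch
import torch.nn as nn
import swd_optim

net = nn.Linear(32, 10)                      # placeholder model
criterion = nn.CrossEntropyLoss()
optimizer = swd_optim.AdamS(net.parameters(), lr=1e-3, betas=(0.9, 0.999),
                            eps=1e-08, weight_decay=5e-4, amsgrad=False)

for step in range(100):
    inputs = torch.randn(64, 32)             # dummy batch of features
    targets = torch.randint(0, 10, (64,))    # dummy labels
    optimizer.zero_grad()
    loss = criterion(net(inputs), targets)
    loss.backward()
    optimizer.step()                         # AdamS is used like any built-in optimizer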

Test performance

Test errors (%), reported as mean±std:

Dataset     Model        AdamS       SGD M       Adam        AMSGrad     AdamW       AdaBound    Padam       Yogi        RAdam
CIFAR-10    ResNet18     4.91±0.04   5.01±0.03   6.53±0.03   6.16±0.18   5.08±0.07   5.65±0.08   5.12±0.04   5.87±0.12   6.01±0.10
CIFAR-10    VGG16        6.09±0.11   6.42±0.02   7.31±0.25   7.14±0.14   6.48±0.13   6.76±0.12   6.15±0.06   6.90±0.22   6.56±0.04
CIFAR-100   DenseNet121  20.52±0.26  19.81±0.33  25.11±0.15  24.43±0.09  21.55±0.14  22.69±0.15  21.10±0.23  22.15±0.36  22.27±0.22
CIFAR-100   GoogLeNet    21.05±0.18  21.21±0.29  26.12±0.33  25.53±0.17  21.29±0.17  23.18±0.31  21.82±0.17  24.24±0.16  22.23±0.15

Citing

If you use Scheduled (Stable) Weight Decay in your work, please cite "On the Overlooked Pitfalls of Weight Decay and How to Mitigate Them: A Gradient-Norm Perspective".

@inproceedings{xie2023onwd,
    title={On the Overlooked Pitfalls of Weight Decay and How to Mitigate Them: A Gradient-Norm Perspective},
    author={Xie, Zeke and Xu, Zhiqiang and Zhang, Jingzhao and Sato, Issei and Sugiyama, Masashi},
    booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
    year={2023}
}

License

MIT License.

