vballoli / nfnets-pytorch

NFNets and Adaptive Gradient Clipping for SGD implemented in PyTorch. Find explanation at tourdeml.github.io/blog/

Home Page: https://nfnets-pytorch.readthedocs.io/en/latest/

AGC without modifying the optimizer

kayuksel opened this issue · comments

Hello,

Is there a way to apply AGC externally without modifying the optimizer code?

I am using optimizers from the torch_optimizer package, so that would be very helpful.

Yeah I think it should be possible by creating a wrapper for any optimizer. I'll try to add this as soon as possible, but for anyone who's interested in a quick implementation:

import torch

# unitwise_norm: unit-wise L2 norm helper used by this repo (a sketch is given further below)


class AGC:
    """Thin wrapper: clip gradients unit-wise (AGC), then step the wrapped optimizer."""

    def __init__(self, optim, clipping=1e-2, eps=1e-3):
        self.optim = optim
        self.clipping = clipping
        self.eps = eps

    @property
    def param_groups(self):
        # Expose the wrapped optimizer's param groups (useful for LR schedulers, logging, ...)
        return self.optim.param_groups

    def zero_grad(self, *args, **kwargs):
        self.optim.zero_grad(*args, **kwargs)

    @torch.no_grad()
    def step(self, closure=None):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                # Unit-wise parameter norm, floored at eps so small weights still get a usable threshold
                param_norm = torch.max(unitwise_norm(p),
                                       torch.tensor(self.eps, device=p.device))
                grad_norm = unitwise_norm(p.grad)
                max_norm = param_norm * self.clipping

                # Rescale only the units whose gradient norm exceeds clipping * parameter norm
                trigger = grad_norm > max_norm
                clipped_grad = p.grad * (max_norm / torch.max(
                    grad_norm, torch.tensor(1e-6, device=grad_norm.device)))
                p.grad.copy_(torch.where(trigger, clipped_grad, p.grad))

        return self.optim.step(closure)

Note that the above code is a raw implementation, but the actual code will be pretty close to this. Hope this helps @kayuksel
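
In case it helps, here is a minimal sketch of the unitwise_norm helper the snippet assumes, plus a hypothetical usage example wrapping an optimizer from torch_optimizer (Lamb picked arbitrarily). The reduction axes are an assumption for illustration and may need adjusting to your weight layouts:

import torch
import torch_optimizer  # any torch.optim-style optimizer can be wrapped the same way


def unitwise_norm(x: torch.Tensor) -> torch.Tensor:
    # Sketch of a unit-wise L2 norm: biases get a single norm, weight tensors get
    # one norm per output unit ([out, in] linear weights, [out, in, kH, kW] conv weights).
    # The reduction axes here are an assumption; check them against your layer layouts.
    if x.ndim <= 1:
        return x.norm(p=2)
    return x.norm(p=2, dim=tuple(range(1, x.ndim)), keepdim=True)


# Hypothetical usage: wrap an optimizer from torch_optimizer with the AGC class above
model = torch.nn.Linear(10, 2)
base_optimizer = torch_optimizer.Lamb(model.parameters(), lr=1e-3)
optimizer = AGC(base_optimizer, clipping=1e-2)

loss = model(torch.randn(4, 10)).sum()
loss.backward()
optimizer.step()       # clips gradients unit-wise, then delegates to Lamb
optimizer.zero_grad()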

Added the generic AGC in 658675f, but it still needs testing. Do let me know how it works out for you.
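
In the meantime, a rough sanity check of the wrapper sketched earlier in this thread (not the committed code) could look like the following. It assumes the AGC and unitwise_norm sketches above, and uses SGD with lr=0.0 so the parameters stay fixed while the clipped gradients are inspected:

import torch

torch.manual_seed(0)
model = torch.nn.Linear(8, 4)
# lr=0.0 keeps the parameters unchanged, so the clipping bound can be checked after step()
optimizer = AGC(torch.optim.SGD(model.parameters(), lr=0.0), clipping=1e-2)

model(torch.randn(16, 8)).pow(2).sum().backward()
optimizer.step()

for p in model.parameters():
    # Defaults used by the wrapper above: eps=1e-3, clipping=1e-2
    max_norm = torch.max(unitwise_norm(p), torch.tensor(1e-3)) * 1e-2
    # After clipping, no unit's gradient norm should exceed clipping * parameter norm
    assert torch.all(unitwise_norm(p.grad) <= max_norm + 1e-6)
print("AGC sanity check passed")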

@vballoli I have tested it. It worked without problems for me.