davda54 / sam

SAM: Sharpness-Aware Minimization (PyTorch)


Using the step function with closure

mathuryash5 opened this issue

Hello,

I am trying to use the step function (with the transformers and accelerate libraries) while passing a closure.

The step function has a @torch.no_grad() decorator, so the closure is wrapped with torch.enable_grad() before it is called, in order to compute gradients. How does the second call to closure() work? I have tried it and get the following error, which sort of makes sense, since gradients will not be computed:
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
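For context, this is roughly how step() handles the closure in sam.py, as far as I can tell (a paraphrased sketch, not the exact source):

@torch.no_grad()
def step(self, closure=None):
    assert closure is not None, "SAM requires a closure, but none was provided"
    # re-enable gradients for the closure, since step() itself runs under no_grad
    closure = torch.enable_grad()(closure)

    self.first_step(zero_grad=True)  # perturb the weights towards the local worst case
    closure()                        # second forward-backward pass at the perturbed weights
    self.second_step()               # restore the weights and apply the base optimizer update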

Here is the closure function I use:

def closure():
    tmp_output = model(**batch)
    tmp_loss = tmp_output.loss
    tmp_loss = tmp_loss / args.gradient_accumulation_steps
    accelerator.backward(tmp_loss)
    return tmp_loss  # return the loss rather than the accelerator
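And here is a sketch of the surrounding training loop (variable names such as train_dataloader are from my setup; optimizer is the SAM instance wrapping my base optimizer, following the README's closure pattern):

# sketch of the surrounding loop; `train_dataloader` comes from accelerator.prepare(...)
for batch in train_dataloader:
    # `closure` is defined per batch, as above
    loss = closure()          # first forward-backward pass
    optimizer.step(closure)   # step() perturbs the weights and calls closure() again
    optimizer.zero_grad()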

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.