davda54 / sam

SAM: Sharpness-Aware Minimization (PyTorch)


RuntimeError: stack expects a non-empty TensorList

pyl3000 opened this issue · comments

def _grad_norm(self):
    shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
    norm = torch.norm(
                torch.stack([
                    # torch.stack raises "stack expects a non-empty TensorList" here
                    # when the list comprehension below yields nothing, i.e. no
                    # parameter has a non-None .grad at this point
                    ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
                    for group in self.param_groups for p in group["params"]
                    if p.grad is not None
                ]),
                p=2
           )
    return norm
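
The error comes from the torch.stack call above: the list comprehension skips every parameter whose .grad is None, so if no backward pass has run before first_step is called, the list is empty and torch.stack fails with "stack expects a non-empty TensorList". A minimal sketch of the intended two-pass training step, assuming placeholder names model, loader, criterion and a SAM-wrapped optimizer, with first_step/second_step as documented in this repository:

    for inputs, targets in loader:
        # first forward/backward pass populates p.grad, so _grad_norm sees a
        # non-empty list when first_step computes the perturbation
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.first_step(zero_grad=True)

        # second forward/backward pass at the perturbed weights, then the actual update
        criterion(model(inputs), targets).backward()
        optimizer.second_step(zero_grad=True)

If every parameter is frozen (requires_grad=False) or the loss does not depend on the model's parameters, the same empty-list crash occurs even with this loop; in that case a defensive check that the gathered gradient list is non-empty inside _grad_norm is another option.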
commented

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.