rosinality / swapping-autoencoder-pytorch

Unofficial implementation of Swapping Autoencoder for Deep Image Manipulation (https://arxiv.org/abs/2007.00653) in PyTorch

RuntimeError: No grad accumulator for a saved leaf!

greeneggsandyaml opened this issue.

Hello, thank you for your work on this implementation (and your StyleGAN work); it looks great!

I'm getting a strange error:

Traceback (most recent call last):
  File "train.py", line 486, in <module>
    device,
  File "train.py", line 251, in train
    (recon_loss + g_loss + g_cooccur_loss).backward()
  File "/home/luke/.miniconda3/envs/new/lib/python3.7/site-packages/torch/tensor.py", line 227, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/luke/.miniconda3/envs/new/lib/python3.7/site-packages/torch/autograd/__init__.py", line 138, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: No grad accumulator for a saved leaf!

For context, I have not modified any of the code in the repo. I get this when I run CUDA_VISIBLE_DEVICES=4 python train.py --size 256 MY_DATA_DIR, with or without distributed training.

I have tried calling .backward() on each of the 3 losses independently, and they all give the same error.

My environment:

Python 3.7.6
PyTorch 1.8.0a0

Do you know how I might go about fixing this?

Much appreciated,
greeneggsandyaml

It seems that a recent version of PyTorch changed some autograd mechanisms. It's hard to tell why for now, because the error occurs after the discriminator backward pass; that is, the discriminator update runs fine and only the generator update fails.

Hmm, strange. I've also tried PyTorch 1.7.0, and it doesn't work there either.

I tried and found that this problem is related to real_img.requires_grad = True and real_patch.requires_grad = True.
Commenting those lines out leads to another error...
In the end, I downgraded PyTorch to 1.5.0 or 1.4.0; both work fine.
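If you'd rather not downgrade, one possible alternative is to give the R1 regularization step its own leaf tensor instead of flipping requires_grad on the shared batch tensor that the generator step reuses. This is only a sketch and has not been checked against this repo's train.py: `r1_penalty` is a hypothetical helper modeled on the usual R1 gradient penalty, and the toy `nn.Linear` discriminator and random batch stand in for the real ones.

```python
import torch
from torch import autograd, nn


def r1_penalty(real_pred, real_input):
    # Hypothetical helper: squared L2 norm of d(sum of D outputs)/d(input),
    # summed per sample and averaged over the batch.
    (grad,) = autograd.grad(outputs=real_pred.sum(), inputs=real_input, create_graph=True)
    return grad.pow(2).reshape(grad.shape[0], -1).sum(1).mean()


torch.manual_seed(0)
discriminator = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1))  # toy stand-in
real_img = torch.randn(4, 3, 8, 8)  # batch as it would come from the data loader

# Fresh leaf for the regularization step; real_img itself is never modified,
# so later code (e.g. the generator's reconstruction loss) sees a plain tensor.
real_for_r1 = real_img.detach().requires_grad_(True)
real_pred = discriminator(real_for_r1)
r1 = r1_penalty(real_pred, real_for_r1)
r1.backward()  # gradients reach the discriminator weights through create_graph=True

print(real_img.requires_grad)  # False: the original batch tensor is untouched
```

The same pattern would apply to real_patch for the co-occurrence discriminator's R1 term.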

I solved this problem on PyTorch 1.8.0 by adding detach() to the reconstruction target:
recon_loss = F.l1_loss(fake_img1, real_img1.detach())
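To illustrate what the detach() change does (not why the original error appears), here is a tiny stand-alone example. The `nn.Linear` "generator" and random tensors are made up for illustration; only the `F.l1_loss(fake_img1, real_img1.detach())` line mirrors the fix above. real_img1 is created with requires_grad=True to imitate a target left over from an R1 step.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
generator = nn.Linear(16, 16)  # toy stand-in for the encoder + generator

# Target that still requires grad, as if left over from a regularization step.
real_img1 = torch.randn(2, 16, requires_grad=True)

fake_img1 = generator(torch.randn(2, 16))
# .detach() makes the target a constant for this loss: gradients flow only
# through fake_img1 into the generator, never into real_img1.
recon_loss = F.l1_loss(fake_img1, real_img1.detach())
recon_loss.backward()

print(real_img1.grad)  # None: the detached target is not part of the graph
```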