samet-akcay / skip-ganomaly

Source code for Skip-GANomaly paper

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [128, 3, 4, 4]] is at version 2; expected vers

pankSM opened this issue

In skip-ganomaly, `backward_d` includes a latent loss term (`self.err_g_lat` in the code below). If I delete `self.err_g_lat`, the program runs normally; if I leave it unchanged, it fails with the error above. Please help me!

I think `self.err_g_lat` should only appear in the generator's loss. Please help me check whether this idea is correct.

```python
def backward_g(self):
    """ Backpropagate netg
    """
    self.err_g_adv = self.opt.w_adv * self.l_adv(self.pred_fake, self.real_label)
    self.err_g_con = self.opt.w_con * self.l_con(self.fake, self.input)
    # self.err_g_lat = self.opt.w_lat * self.l_lat(self.feat_fake, self.feat_real)

    self.err_g = self.err_g_adv + self.err_g_con  # + self.err_g_lat
    self.err_g.backward(retain_graph=True)

def backward_d(self):
    # Fake
    pred_fake, _ = self.netd(self.fake.detach())
    self.err_d_fake = self.l_adv(pred_fake, self.fake_label)

    # Real
    # pred_real, feat_real = self.netd(self.input)
    self.err_d_real = self.l_adv(self.pred_real, self.real_label)

    # Combine losses.
    self.err_g_lat = self.opt.w_lat * self.l_lat(self.feat_fake, self.feat_real)
    self.err_d = self.err_d_real + self.err_d_fake + self.err_g_lat
    self.err_d.backward(retain_graph=True)
```
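
A minimal sketch of why the latent term in `backward_d` can trigger this error, assuming the generator's optimizer step runs before `backward_d` (as in a G-then-D training step). `self.feat_fake` was computed from `self.fake` without detaching, so backpropagating the latent loss also walks back through the generator's graph, which saved generator weights that the generator step has since modified in place. The toy `netg`, `netd`, `train_step` and losses below are hypothetical stand-ins, not the repo's classes; the fix shown (detaching only the fake features inside the latent term) is one option, not necessarily what the authors intended.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy models. netg has two linear layers so that the second one
# saves its weight for backward (its input requires grad).
netg = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4))
netd = nn.Linear(4, 2)
opt_g = torch.optim.SGD(netg.parameters(), lr=0.1)
opt_d = torch.optim.SGD(netd.parameters(), lr=0.1)
x = torch.randn(8, 4)


def train_step(detach_fake_feat: bool) -> None:
    # Shared forward pass (like forward_g / forward_d).
    fake = netg(x)
    feat_fake = netd(fake)
    feat_real = netd(x)

    # Generator update: opt_g.step() changes netg's weights in place,
    # which bumps their autograd version counters.
    opt_g.zero_grad()
    err_g = (fake - x).abs().mean()
    err_g.backward(retain_graph=True)
    opt_g.step()

    # Discriminator update with a latent term, as in the modified backward_d.
    opt_d.zero_grad()
    ff = feat_fake.detach() if detach_fake_feat else feat_fake
    err_d_lat = ((ff - feat_real) ** 2).mean()
    # Without detach, this backward reaches netg's saved (now modified) weights.
    err_d_lat.backward()
    opt_d.step()


for detach in (False, True):
    try:
        train_step(detach)
        print(f"detach_fake_feat={detach}: ok")
    except RuntimeError:
        print(f"detach_fake_feat={detach}: RuntimeError (in-place modification)")
```

With `detach_fake_feat=False` the backward pass raises the same "modified by an inplace operation" error; with the fake features detached, the discriminator still learns from the latent term through `feat_real`, and no gradient is pushed back into the already-updated generator.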

It seems that your installed torch or torchvision version is not the one this code expects.
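
The "is at version 2; expected version 1" part of the error refers to autograd's per-tensor version counter, which every in-place operation increments; backward refuses to use a saved tensor whose counter has changed since it was saved. The short sketch below shows that an ordinary optimizer step bumps a parameter's counter. `_version` is an internal PyTorch attribute, used here only for illustration. This is likely why the torch version matters: to the best of our understanding, older PyTorch optimizers updated parameters through `.data`, which this check did not see, while newer releases (reportedly from around 1.5) do trigger it; treat the exact release as an assumption.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(3, 3)
opt = torch.optim.SGD(layer.parameters(), lr=0.1)

layer(torch.randn(2, 3)).sum().backward()
before = layer.weight._version  # internal autograd version counter
opt.step()                      # in-place update of layer.weight
after = layer.weight._version
print(f"weight version: {before} -> {after}")
```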

I added `.detach()` to `self.err_g_lat`, and it runs normally.
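
Where `.detach()` goes matters, and the comment does not say which variant was used. Detaching only the fake features inside the latent loss keeps that loss's gradient for the discriminator (through `feat_real`) while cutting the path back into the generator. Detaching the whole `self.err_g_lat` turns it into a constant: the program runs, but the term no longer trains anything. The sketch below uses a hypothetical toy `netd` and a leaf tensor `fake` standing in for the generator output.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
netd = nn.Linear(4, 2)                        # hypothetical discriminator feature extractor
real = torch.randn(8, 4)
fake = torch.randn(8, 4, requires_grad=True)  # stands in for the generator output

feat_real = netd(real)
feat_fake = netd(fake)

# Option 1: detach only the fake features. netd still gets a gradient via
# feat_real, but nothing flows back toward the generator through feat_fake.
lat_partial = ((feat_fake.detach() - feat_real) ** 2).mean()
lat_partial.backward()
print(netd.weight.grad is not None, fake.grad is None)  # True True

# Option 2: detach the whole loss. It becomes a constant, so adding it to
# err_d changes the logged value but no gradient.
lat_whole = ((feat_fake - feat_real) ** 2).mean().detach()
print(lat_whole.requires_grad)  # False
```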

It works:

```python
with torch.no_grad():
    self.pred_fake, self.feat_fake = self.netd(self.fake)
```
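
The `torch.no_grad()` workaround avoids the error because the discriminator pass on `self.fake` is not recorded, so no later backward walks through it. A side effect worth checking: if `backward_g` builds `err_g_adv` from this same `self.pred_fake`, the adversarial term no longer sends any gradient to the generator, and only the contextual (reconstruction) loss trains it. The sketch below compares generator gradients with and without `no_grad`, using hypothetical toy models; BCE-with-logits and L1 are stand-ins for `l_adv` and `l_con`, not necessarily the repo's exact losses.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
netg = nn.Linear(4, 4)  # hypothetical generator stand-in
netd = nn.Linear(4, 1)  # hypothetical discriminator stand-in (outputs logits)
x = torch.randn(8, 4)


def generator_grad(use_no_grad: bool, with_adv: bool = True) -> torch.Tensor:
    """Return netg's weight gradient for err_g = err_g_adv + err_g_con."""
    netg.zero_grad()
    fake = netg(x)
    if use_no_grad:
        with torch.no_grad():
            pred_fake = netd(fake)  # not recorded: requires_grad is False
    else:
        pred_fake = netd(fake)
    err_g = F.l1_loss(fake, x)  # contextual (reconstruction) term
    if with_adv:
        target = torch.ones_like(pred_fake)
        err_g = err_g + F.binary_cross_entropy_with_logits(pred_fake, target)
    err_g.backward()
    return netg.weight.grad.clone()


con_only = generator_grad(use_no_grad=False, with_adv=False)
# True: under no_grad the adversarial term adds nothing to netg's gradient.
print(torch.allclose(generator_grad(use_no_grad=True), con_only))
# False: with the graph recorded, the adversarial term changes the gradient.
print(torch.allclose(generator_grad(use_no_grad=False), con_only))
```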