EoinKenny / AAAI-2021

Code for our paper

How to find input vector z for an image

ziweizhao1993 opened this issue · comments

Hi,

Thank you for sharing the official implementation! I see in the provided code that the input vector z is loaded from "misclassify_XX.pt"; I wonder if the code to find z (Equation 1 in the paper) is also available?

Ziwei

Hi Ziwei,

Thanks for the message. I've been looking, and it appears my code for that function was lost when my computer was wiped a few months ago.

It is one of the easiest functions to implement, though. I haven't tested this, but something like the following should work:

import torch
import torch.optim as optim

def find_gan_approx(G, original_query_image, original_query_label, cnn, latent_size, lr=0.01):
	"""Optimise a latent vector z so that G(z) reconstructs the query image
	while the CNN still predicts the original label."""

	# Optimise the latent vector directly, starting from a random initialisation.
	z = torch.randn((1, latent_size), requires_grad=True)
	optimizer = optim.Adam([z], lr=lr)
	mse_loss = torch.nn.MSELoss()
	cce_loss = torch.nn.CrossEntropyLoss()
	target_label = torch.tensor([original_query_label])  # class index for the cross-entropy loss

	lambda1 = 1.  # weight on the classification (logit) loss
	lambda2 = 1.  # weight on the pixel-space reconstruction loss

	for _ in range(10000):

		optimizer.zero_grad()

		current_img = G(z)
		pred_logits = cnn(current_img)

		loss1 = cce_loss(pred_logits, target_label) * lambda1
		loss2 = mse_loss(current_img, original_query_image) * lambda2
		loss = loss1 + loss2
		loss.backward()

		optimizer.step()

	return z

You'll have to tweak lambda1 and lambda2.

Thank you!

Hi Eoin,

So sorry for bothering you again...

I just tested this function and it worked very well for most test images. However, I encountered a few test cases where it failed to find the latent vector z. Is it sensitive to the randn initialization? Should I tweak lambda1 and lambda2 for each image individually?

Thank you!

Best,
Ziwei

Hi Ziwei,

No problem at all, happy to help.

Unfortunately, how best to do this is an open research problem. I remember I set up a for loop to do all 100+ digits I used in the paper and let it optimize for a while. In the end I manually inspected them, and I think I had to redo a few (maybe 5-20). Some images do well with this function; others I found did better using just the MSE loss in pixel space and ignoring the logits. I think the logit loss does help avoid local minima, but it can also cause problems unfortunately. You could also try an L1 loss instead of MSE.

I would try using just the pixel-space loss if you're struggling (so set lambda1 = 0).
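For what it's worth, here is a rough, untested sketch of that variant: it drops the logit term, uses an L1 pixel loss, and keeps the best result over a few random restarts (which may also help with the initialization sensitivity you mentioned). The function name, restart count, and learning rate are just placeholders.

import torch
import torch.optim as optim

def find_z_pixel_only(G, original_query_image, latent_size, lr=0.01,
                      n_restarts=5, n_steps=10000):
	"""Pixel-space-only inversion: optimise z with an L1 loss and keep the
	best of several random restarts (no logit / classification term)."""
	l1_loss = torch.nn.L1Loss()
	best_z, best_loss = None, float('inf')

	for _ in range(n_restarts):
		z = torch.randn((1, latent_size), requires_grad=True)
		optimizer = optim.Adam([z], lr=lr)

		for _ in range(n_steps):
			optimizer.zero_grad()
			loss = l1_loss(G(z), original_query_image)
			loss.backward()
			optimizer.step()

		# Keep the restart whose reconstruction is closest to the query image.
		if loss.item() < best_loss:
			best_loss = loss.item()
			best_z = z.detach().clone()

	return best_z

Keeping the restart with the lowest final reconstruction loss is just one heuristic; manually inspecting the candidates works too.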

Sorry I can't help much more than that. If you figure out how to reliably and perfectly recover the latent z in a GAN, you'll have a research paper waiting to be published :-)

I remember I tried to do this for ImageNet using BigGAN, and it really struggles, which is partly why I stuck to MNIST and CIFAR-10 (even the latter struggles). There are smarter ways, for example training an encoder-decoder to help (e.g., see the AttGAN paper), but overall no one has a foolproof way to get z unfortunately.
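To give an idea of the encoder approach (this is my own rough sketch of the general idea, not the AttGAN implementation): you train an encoder E on top of a frozen generator G so that G(E(x)) ≈ x, and then use E(x) as the latent code, or as a warm start for the per-image optimization above.

import torch
import torch.optim as optim

def train_inversion_encoder(G, encoder, dataloader, epochs=10, lr=1e-4):
	"""Rough sketch of encoder-based GAN inversion: train an encoder E so that
	G(E(x)) reconstructs x, then use E(x) as the latent code (or as a warm
	start for the per-image optimisation above)."""
	# Freeze the generator; only the encoder's weights are updated.
	for p in G.parameters():
		p.requires_grad_(False)

	mse_loss = torch.nn.MSELoss()
	optimizer = optim.Adam(encoder.parameters(), lr=lr)

	for _ in range(epochs):
		for images, _ in dataloader:  # assumes (image, label) batches
			optimizer.zero_grad()
			z = encoder(images)       # predicted latent codes
			recon = G(z)              # reconstruction through the frozen generator
			loss = mse_loss(recon, images)
			loss.backward()
			optimizer.step()

	return encoder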

Hi Eoin,

I followed your suggestions and found some papers on GAN inversion; it looks like a very interesting area of study.
Thank you for your help!

Best,
Ziwei