Reconstruction looks terrible.
ArneNx opened this issue
Hello,
I am trying to use the pre-trained MAGE model for data augmentation.
For this, I first want to test whether I can get reconstructions that are close to the original image.
My current attempt looks like this:
import torch

codebook_size = self.mage.codebook_size            # 1024 for the released models
codebook_emb_dim = self.mage.vqgan.quantize.e_dim  # 256 for the released models
batch_size = x.shape[0]

with torch.no_grad():
    # tokenization: encode the image into quantized latents and token indices
    z_q, _, token_tuple = self.mage.vqgan.encode(x)
    _, _, token_indices = token_tuple
    token_indices = token_indices.reshape(z_q.size(0), -1)
    token_drop_mask = torch.zeros(batch_size, token_indices.shape[1], device=x.device)
    token_all_mask = torch.zeros(batch_size, token_indices.shape[1], device=x.device)

    # concatenate the (fake) class token
    token_indices = torch.cat(
        [torch.zeros(token_indices.size(0), 1, device=token_indices.device), token_indices], dim=1)
    token_indices[:, 0] = self.mage.fake_class_label
    token_drop_mask = torch.cat(
        [torch.zeros(token_indices.size(0), 1, device=x.device), token_drop_mask], dim=1)
    token_all_mask = torch.cat(
        [torch.zeros(token_indices.size(0), 1, device=x.device), token_all_mask], dim=1)
    token_indices = token_indices.long()

    # BERT-style token embedding
    input_embeddings = self.mage.token_emb(token_indices)

    # encoder
    x = input_embeddings
    for blk in self.mage.blocks:
        x = blk(x)
    latent = self.mage.norm(x)

    # decoder
    logits = self.mage.forward_decoder(latent, token_drop_mask, token_all_mask)
    logits = logits[:, 1:, :codebook_size]

    # sample token predictions from the decoder logits
    sample_dist = torch.distributions.categorical.Categorical(logits=logits)
    sampled_ids = sample_dist.sample()

    # detokenize with the VQGAN decoder for visualization
    z_q = self.mage.vqgan.quantize.get_codebook_entry(
        sampled_ids, shape=(batch_size, 16, 16, codebook_emb_dim))
    images = self.mage.vqgan.decode(z_q)
The resulting reconstructions look terrible. Sometimes I can see parts of the original image, but most of the time they are simply unrecognizable:
(reconstruction examples and the original image were attached to the issue)
Judging from figure 2 in the paper, I would have assumed that it's possible to get good reconstructions even with a single iteration. Am I doing something wrong?
Also, do I see it correctly that the input data is not standardized beforehand?
For reconstruction without masking, you don't need the MAGE model. You can simply use the VQGAN tokenizer and detokenizer for this.
The input data to the VQGAN is in the range [0, 1].
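As a minimal sketch of that suggestion (assuming mage is the loaded model, i.e. self.mage in the snippets above, and x is a batch of images in [0, 1]), the VQGAN-only round trip would be roughly:

import torch

# Encode to quantized latents and decode straight back,
# without touching the MAGE transformer at all.
with torch.no_grad():
    z_q, _, _ = mage.vqgan.encode(x)         # quantized latents
    reconstruction = mage.vqgan.decode(z_q)  # should closely match x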
Thanks for the quick response.
That's a good point. Encoding and decoding with VQGAN alone gives me excellent reconstructions.
I'm just trying to do this with MAGE as well since I do want to introduce a small amount of masking later (to get variation into the reconstruction).
It's a bit surprising to me that this doesn't work.
One thing to note is that MAGE (and MAE) does not have a reconstruction loss on unmasked tokens, so the predictions at unmasked positions are not forced to match the input. In our experiments, we only predict the masked positions; for the unmasked positions we simply copy over the input tokens.
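In code, that copy-over can be done at the token-index level; a minimal sketch (variable names follow the snippets in this thread; orig_token_indices is an assumed copy of the indices returned by vqgan.encode before the mask label was written into them, and token_all_mask[:, 1:] drops the class-token position so it aligns with sampled_ids):

# Keep the original VQGAN token indices wherever nothing was masked,
# and take the model's sampled ids only at masked positions.
sampled_ids = torch.where(token_all_mask[:, 1:].bool(), sampled_ids, orig_token_indices)

The full script posted below does the equivalent at the latent level, with torch.where applied to z_q.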
Ah. That makes sense. With a higher mask ratio + copying over the tokens that were not masked, I get something close to the original image now.
Thanks a lot for your quick help and for the well-written code!
Hi, I'm running into the same problem. Could you share your code? Thanks.
This basically reproduces figure 2 in the paper:
import numpy as np
import torch

codebook_size = self.mage.codebook_size            # 1024 for the released models
codebook_emb_dim = self.mage.vqgan.quantize.e_dim  # 256 for the released models
batch_size = x.shape[0]
mask_rate = 0.4

with torch.no_grad():
    # tokenization: encode the image into quantized latents and token indices
    z_q, _, token_tuple = self.mage.vqgan.encode(x)
    orig_z_q = z_q.clone()
    _, _, token_indices = token_tuple
    token_indices = token_indices.reshape(z_q.size(0), -1)
    seq_len = token_indices.shape[1]
    token_drop_mask = torch.zeros(batch_size, seq_len, device=x.device)

    # randomly mask a fixed number of tokens per image
    num_masked_tokens = int(np.ceil(seq_len * mask_rate))
    # two elements of the noise can be equal, so re-sample until exactly
    # num_masked_tokens positions fall at or below the cutoff
    while True:
        noise = torch.rand(batch_size, seq_len, device=x.device)  # noise in [0, 1]
        sorted_noise, _ = torch.sort(noise, dim=1)  # ascending: small is masked, large is kept
        cutoff_mask = sorted_noise[:, num_masked_tokens - 1:num_masked_tokens]
        token_all_mask = (noise <= cutoff_mask).float()
        if token_all_mask.sum() == batch_size * num_masked_tokens:
            break
        else:
            print("Re-randomize the noise!")
    token_indices[token_all_mask.nonzero(as_tuple=True)] = self.mage.mask_token_label

    # concatenate the (fake) class token
    token_indices = torch.cat(
        [torch.zeros(token_indices.size(0), 1, device=token_indices.device), token_indices], dim=1)
    token_indices[:, 0] = self.mage.fake_class_label
    token_drop_mask = torch.cat(
        [torch.zeros(token_indices.size(0), 1, device=x.device), token_drop_mask], dim=1)
    token_all_mask = torch.cat(
        [torch.zeros(token_indices.size(0), 1, device=x.device), token_all_mask], dim=1)
    token_indices = token_indices.long()

    # BERT-style token embedding
    input_embeddings = self.mage.token_emb(token_indices)

    # encoder
    x = input_embeddings
    for blk in self.mage.blocks:
        x = blk(x)
    latent = self.mage.norm(x)

    # decoder
    logits = self.mage.forward_decoder(latent, token_drop_mask, token_all_mask)
    logits = logits[:, 1:, :codebook_size]

    # sample token predictions from the decoder logits
    sample_dist = torch.distributions.categorical.Categorical(logits=logits)
    sampled_ids = sample_dist.sample()

    # vqgan visualization: use the sampled codes only at masked positions
    # and copy the original quantized latents everywhere else
    z_q = self.mage.vqgan.quantize.get_codebook_entry(
        sampled_ids, shape=(batch_size, 16, 16, codebook_emb_dim))
    mask = (token_all_mask[:, 1:] == 1).reshape(-1, 1, 16, 16).repeat(1, codebook_emb_dim, 1, 1)
    z_q = torch.where(mask, z_q, orig_z_q)
    images = self.mage.vqgan.decode(z_q)
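For saving the result, a small hedged follow-up (assuming the decoded images come back roughly in the same [0, 1] range as the input, and 256x256 as implied by the 16x16 token grid):

import torch
from torchvision.utils import save_image

# Clamp to [0, 1] before saving so out-of-range values don't distort the grid.
save_image(torch.clamp(images, 0.0, 1.0), "mage_masked_reconstruction.png", nrow=4)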