LTH14 / mage

A PyTorch implementation of MAGE: MAsked Generative Encoder to Unify Representation Learning and Image Synthesis

Reconstruction looks terrible.

ArneNx opened this issue · comments

Hello,

I am trying to use the pre-trained MAGE model for data augmentation.
For this, I first want to test whether I can get reconstructions that are close to the original image.

My current attempt looks like this:

codebook_size = self.mage.codebook_size            # 1024
codebook_emb_dim = self.mage.vqgan.quantize.e_dim  # 256
batch_size = x.shape[0]

# tokenization
with torch.no_grad():
    z_q, _, token_tuple = self.mage.vqgan.encode(x)

_, _, token_indices = token_tuple
token_indices = token_indices.reshape(z_q.size(0), -1)
token_drop_mask = torch.zeros(batch_size, token_indices.shape[1]).to(x.device)
token_all_mask = torch.zeros(batch_size, token_indices.shape[1]).to(x.device)

# concatenate class token
token_indices = torch.cat([torch.zeros(token_indices.size(0), 1).cuda(device=token_indices.device), token_indices], dim=1)
token_indices[:, 0] = self.mage.fake_class_label
token_drop_mask = torch.cat([torch.zeros(token_indices.size(0), 1).cuda(), token_drop_mask], dim=1)
token_all_mask = torch.cat([torch.zeros(token_indices.size(0), 1).cuda(), token_all_mask], dim=1)
token_indices = token_indices.long()

# bert embedding
input_embeddings = self.mage.token_emb(token_indices)

# encoder
x = input_embeddings
for blk in self.mage.blocks:
    x = blk(x)
latent = self.mage.norm(x)

# decoder
logits = self.mage.forward_decoder(latent, token_drop_mask, token_all_mask)
logits = logits[:, 1:, :codebook_size]

# get token prediction
sample_dist = torch.distributions.categorical.Categorical(logits=logits)
sampled_ids = sample_dist.sample()


# vqgan visualization
z_q = self.mage.vqgan.quantize.get_codebook_entry(sampled_ids, shape=(batch_size, 16, 16, codebook_emb_dim))
images = self.mage.vqgan.decode(z_q)

The resulting reconstructions look terrible. Sometimes I can see parts of the original image, but most of the time they are simply unrecognizable:
[reconstructed image]
(original image:)
[original image]

Judging from figure 2 in the paper, I would have assumed that it's possible to get good reconstructions even with a single iteration. Am I doing something wrong?

Also, am I correct that the input data is not standardized beforehand?

For reconstruction without masking, you don't need the MAGE model. You can simply use the VQGAN tokenizer and detokenizer for this.

The input data to the VQGAN is of range [0, 1].
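
For reference, a minimal VQGAN-only round trip could look like this (a sketch using the same self.mage.vqgan handle as in the snippet above; the inputs x are assumed to already be in [0, 1]):

# x: input images in [0, 1], shape (B, 3, 256, 256)
with torch.no_grad():
    # tokenize: encode the images into quantized latents (no masking involved)
    z_q, _, _ = self.mage.vqgan.encode(x)
    # detokenize: decode the quantized latents straight back to pixels
    recon = self.mage.vqgan.decode(z_q)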

Thanks for the quick response.
That's a good point. Encoding and decoding with VQGAN alone gives me excellent reconstructions.
I'm just trying to do this with MAGE as well since I do want to introduce a small amount of masking later (to get variation into the reconstruction).
It's a bit surprising to me that this doesn't work.

One thing to note is that MAGE (and MAE) does not have a reconstruction loss on unmasked tokens, so the outputs at unmasked positions are not forced to be identical to the input. In our experiments, we only predict the masked positions and simply copy the input tokens at the unmasked positions.
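
In token space, that copy step could look roughly like the sketch below (orig_token_indices is a hypothetical name for the VQGAN tokens of the input image before any of them were replaced by the mask token; sampled_ids and token_all_mask are as in the snippet above):

# keep the model's prediction only where the token was actually masked;
# everywhere else fall back to the original VQGAN token of the input image
keep_pred = token_all_mask[:, 1:].bool()   # drop the class-token column
merged_ids = torch.where(keep_pred, sampled_ids, orig_token_indices)

The same effect can also be achieved after detokenization by copying the original quantized latents at the unmasked positions, which is what the longer snippet further down does.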

Ah. That makes sense. With a higher mask ratio + copying over the tokens that were not masked, I get something close to the original image now.

Thanks a lot for your quick help and for the well-written code!

Hi, I'm running into the same problem. Could you share your code? Thanks.

This basically reproduces figure 2 in the paper:

        codebook_size = self.mage.codebook_size            # 1024
        codebook_emb_dim = self.mage.vqgan.quantize.e_dim  # 256
        batch_size = x.shape[0]
        mask_rate = 0.4
        
        # tokenization
        with torch.no_grad():
            z_q, _, token_tuple = self.mage.vqgan.encode(x)

        orig_z_q = z_q.clone()

        _, _, token_indices = token_tuple
        token_indices = token_indices.reshape(z_q.size(0), -1)
        seq_len = token_indices.shape[1]
        token_drop_mask = torch.zeros(batch_size, seq_len).to(x.device)
        num_masked_tokens = int(np.ceil(seq_len * mask_rate))  # np is numpy (import numpy as np)

        # it is possible that two elements of the noise are the same, so do a while loop to avoid it
        while True:
            noise = torch.rand(batch_size, seq_len, device=x.device)  # noise in [0, 1]
            sorted_noise, _ = torch.sort(noise, dim=1)  # ascend: small is remove, large is keep
            cutoff_mask = sorted_noise[:, num_masked_tokens-1:num_masked_tokens]
            token_all_mask = (noise <= cutoff_mask).float()
            if token_all_mask.sum() == batch_size*num_masked_tokens:
                break
            else:
                print("Rerandom the noise!")
        token_indices[token_all_mask.nonzero(as_tuple=True)] = self.mage.mask_token_label

        # concatenate class token
        token_indices = torch.cat([torch.zeros(token_indices.size(0), 1).cuda(device=token_indices.device), token_indices], dim=1)
        token_indices[:, 0] = self.mage.fake_class_label
        token_drop_mask = torch.cat([torch.zeros(token_indices.size(0), 1).cuda(), token_drop_mask], dim=1)
        token_all_mask = torch.cat([torch.zeros(token_indices.size(0), 1).cuda(), token_all_mask], dim=1)
        token_indices = token_indices.long()

        # bert embedding
        input_embeddings = self.mage.token_emb(token_indices)

        # encoder
        x = input_embeddings
        for blk in self.mage.blocks:
            x = blk(x)
        latent = self.mage.norm(x)

        # decoder
        logits = self.mage.forward_decoder(latent, token_drop_mask, token_all_mask)
        logits = logits[:, 1:, :codebook_size]

        # get token prediction
        sample_dist = torch.distributions.categorical.Categorical(logits=logits)
        sampled_ids = sample_dist.sample()


        # vqgan visualization
        z_q = self.mage.vqgan.quantize.get_codebook_entry(sampled_ids, shape=(batch_size, 16, 16, codebook_emb_dim))
        mask = (token_all_mask[:, 1:] == 1).reshape(-1, 1, 16, 16).repeat(1,codebook_emb_dim, 1,1)
        z_q = torch.where(mask, z_q, orig_z_q)
        images = self.mage.vqgan.decode(z_q)
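
If it helps for inspection, the decoded reconstructions can be written to disk with torchvision (a sketch; note that x was reused for the transformer activations above, and the decoder output may need clamping to [0, 1]):

from torchvision.utils import save_image

# clamp to the displayable range and save the reconstructions as an image grid
save_image(images.clamp(0, 1), "mage_reconstructions.png", nrow=4)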