lucidrains / vit-pytorch

Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch


MAE image recon

chokyungjin opened this issue

Thank you for your efforts, but I have a question about MAE code.

recon_loss = F.mse_loss(pred_pixel_values, masked_patches)

The MSE loss is calculated between patch vectors rather than the original image, so what code should I add to check whether the reconstruction works well by looking at the output image?

@chokyungjin I don't totally understand your question, but to clarify, pred_pixel_values and masked_patches are both in pixel space from the original image. They have just been rearranged im2col-style, one flattened pixel vector per patch
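
For reference, this is roughly how the per-patch pixel vectors are produced; a minimal sketch with einops (not copied from the repo), assuming a (batch, channels, height, width) input and a square patch size:

    # Minimal sketch (assumed, not the repo's exact code): flatten an image into
    # per-patch pixel vectors, the same layout pred_pixel_values / masked_patches use.
    import torch
    from einops import rearrange

    image = torch.randn(1, 1, 512, 512)   # (batch, channels, height, width)
    patches = rearrange(image, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=16, p2=16)
    print(patches.shape)                   # torch.Size([1, 1024, 256])
    # with a 0.75 masking ratio, 768 of those 1024 patches are masked, which is why
    # pred_pixel_values comes out with shape (1, 768, 256)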

I tried to reshape pred_pixel_values back to the original image size, but the shapes don't match.
The original image shape is (512, 512, 1) but the pred_pixel_values shape is (1, 768, 256). Wasn't the whole original image given to the ViT as input?

It only contains the masked patches. Getting back the whole image will take some more code: you'd need to unsort the masked and unmasked patches back into their original order, and then do your reshaping

I'm sorry, but could I request some pseudo-code?
Thanks.

I'm uncertain whether you're asking for code to compute loss over the full image, or to reconstruct the full image for viewing. If the latter, this is the code I ended up using to reconstruct the input image from the patches. @lucidrains feel free to fold this in if this is something you feel would be useful. The function ingests the raw patches and can return:

  1. The original image, prior to any masking or prediction
  2. The masked image (model input) by setting masked_indices
  3. The predicted image, where the masked patches are replaced by the predicted patch pixels, by setting both masked_indices and pred_pixel_values

    # Requires: import numpy as np; from einops.layers.torch import Rearrange
    def reconstruct_image(self, patches, model_input, mean, std, masked_indices=None, pred_pixel_values=None, patch_size=8):
        """
        Reconstructs the image from its patches. Can also reconstruct the masked image as well as the predicted image.
        To reconstruct the raw image from the patches, leave masked_indices=None and pred_pixel_values=None. To reconstruct
        the masked image, pass the masked_indices tensor created in the `forward` call. To reconstruct the predicted image,
        pass both the masked_indices and pred_pixel_values tensors created in the `forward` call.

        ARGS:
            patches (torch.Tensor): The raw patches (pre-patch embedding) generated for the given model input.
                Shape is (batch_size x num_patches x patch_size^2 * channels)
            model_input (torch.Tensor): The input images to the given model (batch_size x channels x height x width)
            mean (list[float]): The per-channel mean of the dataset, used to denormalize the input and
                predicted pixels. (1 x channels)
            std (list[float]): The per-channel std of the dataset, used to denormalize the input and
                predicted pixels. (1 x channels)
            masked_indices (torch.Tensor): The patch indices that are masked (batch_size x masking_ratio * num_patches)
            pred_pixel_values (torch.Tensor): The predicted pixel values for the masked patches
                (batch_size x masking_ratio * num_patches x patch_size^2 * channels)

        RETURN:
            reconstructed_image (np.ndarray): Array containing the reconstructed image (batch_size x channels x height x width)
        """
        patches = patches.cpu()

        masked_indices_in = masked_indices is not None
        predicted_pixels_in = pred_pixel_values is not None

        if masked_indices_in:
            masked_indices = masked_indices.cpu()

        if predicted_pixels_in:
            pred_pixel_values = pred_pixel_values.cpu()

        patch_width = patch_height = patch_size
        reconstructed_image = patches.clone()

        # overwrite the masked patches: with the decoder's predictions (predicted image),
        # or with zeros (the masked image the model actually saw)
        if masked_indices_in or predicted_pixels_in:
            for i in range(reconstructed_image.shape[0]):
                if masked_indices_in and predicted_pixels_in:
                    reconstructed_image[i, masked_indices[i]] = pred_pixel_values[i].float()
                elif masked_indices_in:
                    reconstructed_image[i, masked_indices[i]] = 0

        # fold the per-patch pixel vectors back into a (batch, channels, height, width) image
        invert_patch = Rearrange('b (h w) (p1 p2 c) -> b c (h p1) (w p2)', w=model_input.shape[3] // patch_width,
                                 h=model_input.shape[2] // patch_height, c=model_input.shape[1],
                                 p1=patch_height, p2=patch_width)

        reconstructed_image = invert_patch(reconstructed_image)

        # denormalize per channel, then return in (batch, channels, height, width) order
        reconstructed_image = reconstructed_image.detach().numpy().transpose(0, 2, 3, 1)
        reconstructed_image *= np.array(std)
        reconstructed_image += np.array(mean)

        return reconstructed_image.transpose(0, 3, 1, 2)
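
A rough usage sketch, assuming the method is added to the MAE wrapper class and that you've modified `forward` to also return `patches`, `masked_indices`, and `pred_pixel_values` (the stock MAE forward only returns the loss, so this wiring is an assumption, not the repo's API):

    # Hypothetical usage: forward modified to expose the intermediate tensors
    loss, patches, masked_indices, pred_pixel_values = mae(images)

    mean, std = [0.5], [0.5]   # whatever normalization your dataset used
    original  = mae.reconstruct_image(patches, images, mean, std, patch_size=16)
    masked    = mae.reconstruct_image(patches, images, mean, std,
                                      masked_indices=masked_indices, patch_size=16)
    predicted = mae.reconstruct_image(patches, images, mean, std,
                                      masked_indices=masked_indices,
                                      pred_pixel_values=pred_pixel_values, patch_size=16)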

reconstructed_image = reconstructed_image.detach().numpy().transpose(0, 2, 3, 1) #bgr
reconstructed_image *= std
reconstructed_image += mean

Can these lines of code be omitted?

If you omit those lines, then I believe reconstructed_image will be close to zero mean and unit standard deviation (since the network is trained on normalized targets), which isn't great for visualization. However, if you didn't train your model on normalized data, then you should be fine to remove those lines.
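
For what it's worth, when displaying the denormalized output with something like matplotlib, clipping to the valid range helps avoid washed-out results; a minimal sketch, assuming float images in [0, 1] and a hypothetical `reconstructed` array returned by the function above:

    import numpy as np
    import matplotlib.pyplot as plt

    img = reconstructed[0]                       # (channels, height, width) for one sample
    img = np.clip(img.transpose(1, 2, 0), 0, 1)  # to (height, width, channels), clipped for display
    plt.imshow(img.squeeze(), cmap='gray')       # squeeze handles single-channel images
    plt.show()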

OK, thank you very much. If I add these lines, the predicted image and the masked image both come out all white. By the way, should std and mean have shape (1, 3)? Also, in the predicted image from a model I trained for a while, only the masked patches seem to have any color.

@kcetskcaz did you add the function to the MAE class? If so, how did you call it? If not, how did you implement it? I'm a bit confused over the whole thing.