facebookresearch / mae

PyTorch implementation of MAE https://arxiv.org/abs/2111.06377

Code: make patchify and unpatchify compatible with any number of channels

zhongruiHuangDMRI opened this issue

Dear authors,
Thanks for your great work. My only suggestion: in some cases (e.g. medical imaging) we use 1-channel (grayscale) images instead of 3-channel color (RGB) images, but patchify and unpatchify assume 3 channels. Here is revised code for both functions that works with any number of channels (written by @CH2-Carbene):

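def patchify_v2(self, imgs):
    """
    imgs: (N, self.img_channel, H, W)
    x: (N, L, patch_size**2 * self.img_channel)
    Assumes self.img_channel is set in __init__ (e.g. from in_chans).
    """
    p = self.patch_embed.patch_size[0]
    # Square images whose side is a multiple of the patch size
    assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

    h = w = imgs.shape[2] // p
    # Split H and W into (h, p) and (w, p) blocks
    x = imgs.reshape(shape=(imgs.shape[0], self.img_channel, h, p, w, p))
    # Move channels last so each patch flattens in (p, p, C) order
    x = torch.einsum('nchpwq->nhwpqc', x)
    x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * self.img_channel))
    return x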
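As a minimal sketch (not part of the original issue), the same reordering can be written as a free function that reads the channel count from the input tensor, so no self.img_channel attribute is needed. The name patchify and the parameter p are illustrative assumptions; the reshape and einsum steps are the same as in patchify_v2.

```python
import torch


def patchify(imgs: torch.Tensor, p: int) -> torch.Tensor:
    """(N, C, H, W) -> (N, L, p*p*C); C is read from the input tensor."""
    n, c, hgt, wid = imgs.shape
    assert hgt == wid and hgt % p == 0, "expects square images divisible by p"
    h = w = hgt // p
    # (N, C, h, p, w, p) -> (N, h, w, p, p, C): channels last inside each patch
    x = imgs.reshape(n, c, h, p, w, p)
    x = torch.einsum("nchpwq->nhwpqc", x)
    return x.reshape(n, h * w, p * p * c)


# A batch of two 1-channel (grayscale) 224x224 images with 16x16 patches
gray = torch.randn(2, 1, 224, 224)
patches = patchify(gray, 16)
print(patches.shape)  # torch.Size([2, 196, 256])
```

With 224 / 16 = 14 patches per side, L = 14 * 14 = 196 and each patch holds 16 * 16 * 1 = 256 values.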

def unpatchify_v2(self, x):
    """
    x: (N, L, patch_size**2 * self.img_channel)
    imgs: (N, self.img_channel, H, W)
    """
    p = self.patch_embed.patch_size[0]
    # L must be a perfect square (square grid of patches)
    h = w = int(x.shape[1]**.5)
    assert h * w == x.shape[1]

    # Undo the flattening: (N, L, p*p*C) -> (N, h, w, p, p, C)
    x = x.reshape(shape=(x.shape[0], h, w, p, p, self.img_channel))
    x = torch.einsum('nhwpqc->nchpwq', x)
    imgs = x.reshape(shape=(x.shape[0], self.img_channel, h * p, w * p))
    return imgs
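A hypothetical standalone unpatchify can likewise infer C from the last dimension (C = D / p²). To check it independently of any patchify implementation, the sketch below builds patches with torch.nn.functional.unfold, whose per-patch layout is (C, p, p), permutes them into the (p, p, C) layout that patchify_v2 produces, and confirms that the round trip restores the input exactly. The function name and image sizes are illustrative assumptions.

```python
import torch
import torch.nn.functional as F


def unpatchify(x: torch.Tensor, p: int) -> torch.Tensor:
    """(N, L, p*p*C) -> (N, C, H, W); C is inferred from the last dimension."""
    n, num_patches, dim = x.shape
    c = dim // (p * p)
    assert c * p * p == dim, "last dim must be a multiple of p*p"
    h = w = int(num_patches ** 0.5)
    assert h * w == num_patches, "expects a square grid of patches"
    x = x.reshape(n, h, w, p, p, c)
    x = torch.einsum("nhwpqc->nchpwq", x)
    return x.reshape(n, c, h * p, w * p)


p = 16
imgs = torch.randn(2, 1, 64, 64)                  # grayscale batch
chans = imgs.shape[1]
cols = F.unfold(imgs, kernel_size=p, stride=p)    # (N, C*p*p, L), channel-major
n, _, num_patches = cols.shape
# Reorder each patch from (C, p, p) to (p, p, C) to match patchify's layout
patches = (cols.reshape(n, chans, p, p, num_patches)
               .permute(0, 4, 2, 3, 1)
               .reshape(n, num_patches, p * p * chans))
recon = unpatchify(patches, p)
print(torch.equal(recon, imgs))  # True
```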