Code: Make patchify and unpatchify compatible with any number of channels
zhongruiHuangDMRI opened this issue
zhongruiHuang commented
Dear author:
Thanks for your great work. My one suggestion: in some cases (e.g., medical imaging) we use single-channel (grayscale) images instead of three-channel color (RGB) images, but the original patchify and unpatchify functions assume 3-channel input. Here is revised code for both functions that works with any number of image channels (written by @CH2-Carbene):
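```python
import torch

# Intended as methods of the model class that owns self.patch_embed.
# They require self.img_channel to be set in __init__, e.g.
#     self.img_channel = in_chans   # assuming the constructor takes in_chans

def patchify_v2(self, imgs):
    """
    imgs: (N, self.img_channel, H, W)
    x: (N, L, patch_size**2 * self.img_channel)
    """
    p = self.patch_embed.patch_size[0]
    # Square images whose side is a multiple of the patch size.
    assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
    h = w = imgs.shape[2] // p
    # Split each spatial axis into (num_patches, patch_size).
    x = imgs.reshape(shape=(imgs.shape[0], self.img_channel, h, p, w, p))
    # Move the channel axis last so each patch flattens in (p, q, c) order.
    x = torch.einsum('nchpwq->nhwpqc', x)
    x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * self.img_channel))
    return x

def unpatchify_v2(self, x):
    """
    x: (N, L, patch_size**2 * self.img_channel)
    imgs: (N, self.img_channel, H, W)
    """
    p = self.patch_embed.patch_size[0]
    # L patches form an h x w grid of a square image.
    h = w = int(x.shape[1]**.5)
    assert h * w == x.shape[1]
    x = x.reshape(shape=(x.shape[0], h, w, p, p, self.img_channel))
    # Inverse of the permutation used in patchify_v2.
    x = torch.einsum('nhwpqc->nchpwq', x)
    imgs = x.reshape(shape=(x.shape[0], self.img_channel, h * p, w * p))
    return imgs
```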
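To sanity-check the channel-agnostic reshaping outside the model class, here is a minimal sketch of the same logic as free functions. The names `patchify`/`unpatchify` and the explicit `p` (patch size) and `c` (channel count) parameters are assumptions for illustration; they stand in for `self.patch_embed.patch_size[0]` and `self.img_channel`. Like the code above, it assumes square images.

```python
import torch

def patchify(imgs: torch.Tensor, p: int) -> torch.Tensor:
    """(N, C, H, W) -> (N, L, p*p*C), where L = (H // p) * (W // p)."""
    n, c, H, W = imgs.shape
    assert H == W and H % p == 0, "expects square images divisible by p"
    h = w = H // p
    x = imgs.reshape(n, c, h, p, w, p)
    x = torch.einsum('nchpwq->nhwpqc', x)  # channel axis last within each patch
    return x.reshape(n, h * w, p * p * c)

def unpatchify(x: torch.Tensor, p: int, c: int) -> torch.Tensor:
    """(N, L, p*p*C) -> (N, C, H, W); inverse of patchify."""
    n, L, d = x.shape
    h = w = int(L ** 0.5)
    assert h * w == L and d == p * p * c, "shape does not match p and c"
    x = x.reshape(n, h, w, p, p, c)
    x = torch.einsum('nhwpqc->nchpwq', x)  # undo the permutation
    return x.reshape(n, c, h * p, w * p)

# Round trip on a single-channel (grayscale) batch.
imgs = torch.randn(2, 1, 224, 224)
patches = patchify(imgs, p=16)
print(patches.shape)  # torch.Size([2, 196, 256])
restored = unpatchify(patches, p=16, c=1)
print(torch.equal(restored, imgs))  # True
```

Because the two functions only reshape and permute, the round trip is exact, and the same code handles `c=1` or `c=3` without changes.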