JingyunLiang / SwinIR

SwinIR: Image Restoration Using Swin Transformer (official repository)

Home Page: https://arxiv.org/abs/2108.10257

Inquiry about patch embedding

zihao19 opened this issue

I'm confused by the patch embedding code in network_swinir.py.
If I read the code correctly, there seems to be no patch embedding in the forward method, only flatten and transpose. Can someone help me understand the code?

class PatchEmbed(nn.Module):

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        x = x.flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

    def flops(self):
        flops = 0
        H, W = self.img_size
        if self.norm is not None:
            flops += H * W * self.embed_dim
        return flops

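Because patch_size is 1 here. Each pixel is treated as its own patch, so flatten and transpose are enough to turn the feature map into a sequence of tokens.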
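Here is a minimal sketch of why flatten and transpose suffice when the patch size is 1. It assumes the input to `forward` is already a feature map with `embed_dim` channels (for example, produced by an earlier convolution). The tensor sizes and variable names below are made up for illustration and are not taken from the repository.

```python
import torch

# Hypothetical feature map of shape (B, C, H, W); C plays the role of embed_dim.
B, C, H, W = 2, 96, 8, 8
x = torch.randn(B, C, H, W)

# The same two operations as in PatchEmbed.forward (with norm disabled).
tokens = x.flatten(2).transpose(1, 2)  # (B, H*W, C)

# With 1x1 patches, the token at index h*W + w is just the channel
# vector of pixel (h, w); nothing is mixed or projected.
h, w = 3, 5
same = torch.equal(tokens[:, h * W + w, :], x[:, :, h, w])

print(tokens.shape)  # torch.Size([2, 64, 96])
print(same)          # True
```

So under this assumption the module only reshapes the data into the (batch, tokens, channels) layout that the transformer blocks expect. Any learned projection would have to live elsewhere in the network.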

I have the same question, and the reply is missing.