SHI-Labs / Neighborhood-Attention-Transformer

Neighborhood Attention Transformer, arXiv 2022 / CVPR 2023. Dilated Neighborhood Attention Transformer, arXiv 2022

Is the DiNAT code runnable?

chhkang opened this issue · comments

Hi, thanks for your contribution, I really appreciate your work.
I'm trying to use the DiNAT_s code, and there seems to be a problem in the forward function below.

def forward(self, x):
        x = self.patch_embed(x)
        x = self.pos_drop(x)
        outs = []
        for i in range(self.num_layers):
            layer = self.layers[i]
            x_out, x = layer(x) ###### x_out is 4D tensor, x is 3D tensor
            if i in self.out_indices:
                norm_layer = getattr(self, f"norm{i}")
                x_out = norm_layer(x_out)

                out = x_out.permute(0, 3, 1, 2).contiguous()
                outs.append(out)

        return tuple(outs)

All weights load properly from the provided 'dinat_s.pth' into my model (with some revisions), but the tensor shapes look odd. After the tensor passes through the first layer (i=0), x becomes a 3D tensor, but it should be 4D so that it can be used in the next layer (i=1). I debugged this: the end of layer #0 contains a PatchMerging, which produces an output tensor of shape [B, H/2 * W/2, 4C].

I think I'm stuck and need help. Has the model design changed somehow, or am I doing something wrong?

Hello and thank you for your interest.

Thank you for bringing this to our attention.

The issue

The reason for this issue is that DiNAT_s's implementation is slightly different from that of DiNAT and NAT, because it was intended to match the Swin Transformer architecture exactly and only have its attention modules replaced.
Swin's original implementation uses the 3D shape ([B, HW, C]) and reshapes back to 4D ([B, H, W, C]) when required. We avoid this in NAT/DiNAT, since every operation and layer in the model either expects the 4D format or can handle it without being reshaped, with the exception of convolutions; and those are rare compared to everything else (only 3 convolutions, plus the initial layer, need the reshape and permute).
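
For illustration, here is a minimal sketch of the two layouts and the reshapes between them (variable names are hypothetical, not from the repo):

import torch

# Illustrative only: the two layouts discussed above.
B, H, W, C = 2, 56, 56, 96

x_4d = torch.randn(B, H, W, C)      # [B, H, W, C]  -- what NAT/DiNAT layers consume directly
x_3d = x_4d.reshape(B, H * W, C)    # [B, HW, C]    -- Swin's flattened token layout

# Folding back to 4D requires knowing H and W, which is why Swin-style code
# has to carry the spatial dimensions alongside the flattened tensor.
x_back = x_3d.reshape(B, H, W, C)
assert torch.equal(x_back, x_4d)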

It turns out the PatchMerging implementation in classification was correct in reshaping back to 4D, but the one in detection and segmentation was not, which is what you're referencing, since your snippet has out_indices.
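
As a rough sketch of the shape issue (not the repository's actual code), a Swin-style PatchMerging that stays in the 4D layout would look something like this, so that the next level's attention blocks can consume its output directly:

import torch
import torch.nn as nn

class PatchMerging4D(nn.Module):
    # Hypothetical downsampling layer that keeps the [B, H, W, C] layout;
    # for illustration of the shape issue only.
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(4 * dim)
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

    def forward(self, x):
        # x: [B, H, W, C] (assumes H and W are even)
        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, :]
        x3 = x[:, 1::2, 1::2, :]
        x = torch.cat([x0, x1, x2, x3], dim=-1)  # [B, H/2, W/2, 4C]
        x = self.norm(x)
        x = self.reduction(x)                    # [B, H/2, W/2, 2C]
        # Returning the 4D tensor here, rather than a flattened token sequence,
        # avoids the 3D/4D mismatch described in the issue.
        return x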

Solution

We're working on a solution that should be tested and merged soon.

The PR has been merged. Please let us know if the issue persists.