SHI-Labs / Neighborhood-Attention-Transformer

Neighborhood Attention Transformer, arXiv 2022 / CVPR 2023. Dilated Neighborhood Attention Transformer, arXiv 2022

A question about the rpb in LegacyNeighborhoodAttention2D

lartpang opened this issue

lartpang commented

My question

Why is the same relative position index used for several positions in the middle?

Information

def apply_pb(self, attn, height, width):
    """
    RPB implementation by @qwopqwop200
    https://github.com/qwopqwop200/Neighborhood-Attention-Transformer
    """
    num_repeat_h = torch.ones(self.kernel_size, dtype=torch.long)
    num_repeat_w = torch.ones(self.kernel_size, dtype=torch.long)
    num_repeat_h[self.kernel_size // 2] = height - (self.kernel_size - 1)
    num_repeat_w[self.kernel_size // 2] = width - (self.kernel_size - 1)
    bias_hw = (self.idx_h.repeat_interleave(num_repeat_h).unsqueeze(-1) * (2 * self.kernel_size - 1)) + self.idx_w.repeat_interleave(num_repeat_w)
    bias_idx = bias_hw.unsqueeze(-1) + self.idx_k
    # Index flip
    # Our RPB indexing in the kernel is in a different order, so we flip these indices to ensure weights match.
    bias_idx = torch.flip(bias_idx.reshape(-1, self.kernel_size ** 2), [0])
    return attn + self.rpb.flatten(1, 2)[:, bias_idx].reshape(self.num_heads, height * width, 1, self.kernel_size ** 2).transpose(0, 1)

A simple visualization:

[figure: rpb — one subplot per (h, w) pixel, highlighting that pixel's bias_idx window]

The following visualization script reuses the index-construction code from LegacyNeighborhoodAttention2D:

# %%
import matplotlib.pyplot as plt
import numpy as np
import torch

kernel_size = 3
height = width = 5
rpb_size = 2 * kernel_size - 1

# %%
fig, axes = plt.subplots(nrows=height, ncols=width, figsize=(8, 8))
# note: with kernel_size = 3 and height = width = 5, the feature map and the
# (2 * kernel_size - 1) x (2 * kernel_size - 1) RPB map happen to have the
# same size, so this background doubles as the RPB map
shared_bg = np.zeros((height, width), dtype=np.uint8)

# %%
idx_h = torch.arange(0, kernel_size)
idx_w = torch.arange(0, kernel_size)
idx_k = ((idx_h.unsqueeze(-1) * rpb_size) + idx_w).reshape(-1)
print(idx_k.reshape(kernel_size, kernel_size))

# %%
num_repeat_h = torch.ones(kernel_size, dtype=torch.long)
num_repeat_w = torch.ones(kernel_size, dtype=torch.long)
num_repeat_h[kernel_size // 2] = height - (kernel_size - 1)
num_repeat_w[kernel_size // 2] = width - (kernel_size - 1)
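
# (added illustration, not part of the original snippet) with kernel_size = 3
# and height = 5, num_repeat_h is [1, 3, 1], so repeat_interleave expands idx_h
# to [0, 1, 1, 1, 2]: every interior row reuses the centered index 1, which is
# exactly why the middle rows of bias_idx printed below are identical
print(idx_h.repeat_interleave(num_repeat_h))  # tensor([0, 1, 1, 1, 2])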
bias_hw = (
    idx_h.repeat_interleave(num_repeat_h).unsqueeze(-1) * (2 * kernel_size - 1)
) + idx_w.repeat_interleave(num_repeat_w)
bias_idx = (bias_hw.unsqueeze(-1) + idx_k).reshape(-1, kernel_size ** 2)
print(bias_idx)
'''
tensor([[ 0,  1,  2,  5,  6,  7, 10, 11, 12],
        [ 1,  2,  3,  6,  7,  8, 11, 12, 13],
        [ 1,  2,  3,  6,  7,  8, 11, 12, 13],
        [ 1,  2,  3,  6,  7,  8, 11, 12, 13],
        [ 2,  3,  4,  7,  8,  9, 12, 13, 14],
        [ 5,  6,  7, 10, 11, 12, 15, 16, 17],
        [ 6,  7,  8, 11, 12, 13, 16, 17, 18],
        [ 6,  7,  8, 11, 12, 13, 16, 17, 18],
        [ 6,  7,  8, 11, 12, 13, 16, 17, 18],
        [ 7,  8,  9, 12, 13, 14, 17, 18, 19],
        [ 5,  6,  7, 10, 11, 12, 15, 16, 17],
        [ 6,  7,  8, 11, 12, 13, 16, 17, 18],
        [ 6,  7,  8, 11, 12, 13, 16, 17, 18],
        [ 6,  7,  8, 11, 12, 13, 16, 17, 18],
        [ 7,  8,  9, 12, 13, 14, 17, 18, 19],
        [ 5,  6,  7, 10, 11, 12, 15, 16, 17],
        [ 6,  7,  8, 11, 12, 13, 16, 17, 18],
        [ 6,  7,  8, 11, 12, 13, 16, 17, 18],
        [ 6,  7,  8, 11, 12, 13, 16, 17, 18],
        [ 7,  8,  9, 12, 13, 14, 17, 18, 19],
        [10, 11, 12, 15, 16, 17, 20, 21, 22],
        [11, 12, 13, 16, 17, 18, 21, 22, 23],
        [11, 12, 13, 16, 17, 18, 21, 22, 23],
        [11, 12, 13, 16, 17, 18, 21, 22, 23],
        [12, 13, 14, 17, 18, 19, 22, 23, 24]])
'''
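
# %%
# (added for illustration, not part of the original snippet) decode a flat RPB
# index back into a relative (dh, dw) coordinate, with dh, dw in
# [-(kernel_size - 1), kernel_size - 1]; the repeated middle rows of bias_idx
# decode to the centered 3x3 window
def decode(idx):
    return idx // rpb_size - (kernel_size - 1), idx % rpb_size - (kernel_size - 1)

print([decode(int(i)) for i in bias_idx[7]])
# [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 0), (0, 1), (1, -1), (1, 0), (1, 1)]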

# %%
for h in range(height):
    for w in range(width):
        new_bg = shared_bg.flatten().copy()
        new_bg[bias_idx[h * width + w]] = 255  # row-major index; h * height + w only works here because height == width
        new_bg = new_bg.reshape(height, width)
        axes[h, w].imshow(new_bg)

# %%
plt.show()

alihassanijr commented

Hello and thank you for your interest,

Just to give a bit of background: RPB is in theory a continuous function, at least the way we intended it for NA/DiNA.
Here we're only learning a discrete set of weights because our kernel size is typically fixed.

As for its implementation here, most tokens (non-edge cases) share an identical RPB grid: north, south, east, west, and the positions in between, e.g. northwest.
And of course there's a magnitude: 1 north, 2 north, etc.

As a result, if you look at a visualization of NA, you will see that, edge cases aside, the key-value positions for the rest of the feature map are identical: the query is centered and its neighbors wrap around it, hence those queries share the same RPB.

It becomes different for the edge cases precisely because they are not centered. For instance, the north-west-most (top-left) pixel always attends to an "extended neighborhood", which is explained in the original NAT paper; therefore its relative positional biases with respect to its key-value pairs, or neighborhood, differ from those of the non-edge cases, which are always centered.

To clarify further, you can try plotting much larger inputs, where you will see the RPB indices differ only near the corners and edges while remaining identical in the middle.
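
To make this concrete, here is a minimal 1-D sketch of the edge rule (an illustration of my own, not code from the repo): the window is shifted, never shrunk, at the boundaries, so every interior query sees the same centered set of relative offsets.

def na1d_rel_offsets(length, kernel_size):
    # for each query index, clamp the window start so the full kernel
    # stays inside the sequence (the "extended neighborhood" at edges)
    offsets = []
    for q in range(length):
        start = min(max(q - kernel_size // 2, 0), length - kernel_size)
        offsets.append([start + j - q for j in range(kernel_size)])
    return offsets

for q, off in enumerate(na1d_rel_offsets(7, 3)):
    print(q, off)
# interior queries 1..5 all see [-1, 0, 1]; only the edge queries
# 0 and 6 see the shifted windows [0, 1, 2] and [-2, -1, 0]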
By the way, thank you for taking the time to plot these, I'm sure it'll help other users as well.

I hope this explains the idea, but if that's not the case, please let us know so we can clarify further.

lartpang commented

@alihassanijr

Thanks for your reply!

About the original question

In my original example, some settings were obscuring the behavior (the feature map and the RPB map happened to have the same size). I have revised the code so that it is more intuitive now.
But this also raises another problem; see the discussion in the next section.

import matplotlib.pyplot as plt
import numpy as np
import torch

# specify the height and width of the feature map
height = width = 10

# construct a figure containing height*width subplots, one for each (h, w) pixel
fig, axes = plt.subplots(nrows=height, ncols=width, figsize=(8, 8))
fig.suptitle('All Index Windows of RPE for each position of H-W Plane')

# specify the size of kernel for position bias
kernel_size = 5

# construct a shared relative position bias map
rpb_size = 2 * kernel_size - 1
shared_rpb_bg = np.zeros((rpb_size, rpb_size), dtype=np.uint8)

idx_h = torch.arange(0, kernel_size)
idx_w = torch.arange(0, kernel_size)
# absolute 1D indices of the top-left window in the RPB map of shape (2*kernel_size-1, 2*kernel_size-1);
# the other windows' indices can be obtained by adding a start offset to this `idx_k`
idx_k = ((idx_h.unsqueeze(-1) * rpb_size) + idx_w).reshape(-1)

# construct the indices of the window in the RPB map for each (h, w) pixel
num_repeat_h = torch.ones(kernel_size, dtype=torch.long)
num_repeat_w = torch.ones(kernel_size, dtype=torch.long)
num_repeat_h[kernel_size // 2] = height - (kernel_size - 1)
num_repeat_w[kernel_size // 2] = width - (kernel_size - 1)
# the base h and w for the four edge regions differ from the rest
bias_hw = idx_h.repeat_interleave(num_repeat_h).unsqueeze(-1) * rpb_size + idx_w.repeat_interleave(num_repeat_w)
# each (h, w) in the H-W plane corresponds to a window of kernel_size*kernel_size indices
bias_idx = (bias_hw.unsqueeze(-1) + idx_k).reshape(-1, kernel_size ** 2)  # height*width, kernel_size**2

# traverse all positions to visualize and highlight each one's index window in the shared RPB map
for h in range(height):
    for w in range(width):
        new_rpb_bg = shared_rpb_bg.flatten().copy()

        new_start_idx = h * width + w  # row-major index; h * height + w only works here because height == width
        new_rpb_bg[bias_idx[new_start_idx]] = 255  # highlight this pixel's window in the RPB map
        new_rpb_bg = new_rpb_bg.reshape(rpb_size, rpb_size)
        axes[h, w].imshow(new_rpb_bg)
        axes[h, w].set_title(f"Win {(h,w)}")

plt.show()

[figure: rpb-k5-h10 — all RPB index windows, one subplot per position, kernel_size=5, height=width=10]

About the relative position bias for NAT

Let's consider a simple case: kernel_size=3, so the RPB map has shape [2*3-1, 2*3-1] = [5, 5].

The real (relative) indices of the RPB map are:

(-2, -2), (-2, -1), (-2, 0), (-2, 1), (-2, 2), 
(-1, -2), (-1, -1), (-1, 0), (-1, 1), (-1, 2), 
(0, -2), (0, -1), (0, 0), (0, 1), (0, 2),
(1, -2), (1, -1), (1, 0), (1, 1), (1, 2),
(2, -2), (2, -1), (2, 0), (2, 1), (2, 2),
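
For reference, this table can be generated directly (a small sketch of my own):

import torch

kernel_size = 3
d = torch.arange(-(kernel_size - 1), kernel_size)  # tensor([-2, -1, 0, 1, 2])
dh, dw = torch.meshgrid(d, d, indexing="ij")
print(torch.stack([dh, dw], dim=-1).reshape(-1, 2))
# rows are (dh, dw): (-2, -2), (-2, -1), ..., (2, 2), matching the table above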

In traditional attention, the RPB is simple: there are no edge regions that require special consideration, so there is only one index pattern. (Fixed: in fact, the traditional RPB is more like a special form of NAT with only edge regions, so the implementation here is generic.)

In the convolution-like NAT operation, the RPB becomes more complicated.

The index window in the non-edge region always follows this pattern:

(-1, -1), (-1, 0), (-1, 1), 
(0, -1), (0, 0), (0, 1), 
(1, -1), (1, 0), (1, 1),

The index windows in the edge regions are:

# start at the pixel (h=0, w=0), we denote the matrix as $W_{0,0}$
(0, 0), (0, 1), (0, 2);
(1, 0), (1, 1), (1, 2);
(2, 0), (2, 1), (2, 2);

# start at the pixel (h=0, w=1), we denote the matrix as $W_{0,1}$
(0, -1), (0, 0), (0, 1);
(1, -1), (1, 0), (1, 1);
(2, -1), (2, 0), (2, 1);

# start at the pixel (h=1, w=0), we denote the matrix as $W_{1,0}$
(-1, 0), (-1, 1), (-1, 2);
(0, 0), (0, 1), (0, 2);
(1, 0), (1, 1), (1, 2);

# ....

# start at the pixel (h=height-1, w=width-2), we denote the matrix as $W_{height-1,width-2}$
(-2, -1), (-2, 0), (-2, 1),
(-1, -1), (-1, 0), (-1, 1),
(0, -1), (0, 0), (0, 1),

# start at the pixel (h=height-1, w=width-1), we denote the matrix as $W_{height-1,width-1}$
(-2, -2), (-2, -1), (-2, 0), 
(-1, -2), (-1, -1), (-1, 0), 
(0, -2), (0, -1), (0, 0),
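
These windows can be reproduced with a small helper (my own sketch, assuming the shift-at-edges rule discussed above):

import torch

def na2d_rel_window(h, w, height, width, kernel_size):
    # clamp the window start along each axis so the kernel stays inside the map
    sh = min(max(h - kernel_size // 2, 0), height - kernel_size)
    sw = min(max(w - kernel_size // 2, 0), width - kernel_size)
    dh = torch.arange(sh, sh + kernel_size) - h
    dw = torch.arange(sw, sw + kernel_size) - w
    gh, gw = torch.meshgrid(dh, dw, indexing="ij")
    return torch.stack([gh, gw], dim=-1).reshape(-1, 2)

print(na2d_rel_window(0, 0, 5, 5, 3))  # W_{0,0}: (0, 0) ... (2, 2)
print(na2d_rel_window(4, 4, 5, 5, 3))  # W_{height-1,width-1}: (-2, -2) ... (0, 0)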

In the current implementation of the RPB in LegacyNeighborhoodAttention2D, the index pattern does not correspond to the real indices of the RPB map listed above.

Maybe it's the flip?

# Index flip
# Our RPB indexing in the kernel is in a different order, so we flip these indices to ensure weights match.
bias_idx = torch.flip(bias_idx.reshape(-1, self.kernel_size**2), [0])

alihassanijr commented

We have this additional flip to make sure the behavior is identical to the behavior programmed in NATTEN.
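
Mechanically, the flip reverses the query order of bias_idx. Continuing the kernel_size=3, height=width=5 example (a quick check of my own, reusing bias_idx from the first script above), after the flip the first row decodes exactly to the W_{0,0} window listed above:

import torch

flipped = torch.flip(bias_idx, [0])  # reverse the query order
print(flipped[0])  # tensor([12, 13, 14, 17, 18, 19, 22, 23, 24])
# decoded via (idx // 5 - 2, idx % 5 - 2) these are
# (0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2),
# i.e. exactly W_{0,0}; likewise flipped[-1] decodes to W_{height-1,width-1}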

lartpang commented

@alihassanijr

Oh... I understand it now. Thank you so much for your patient reply.

alihassanijr commented

I'm closing this issue now because we're moving our extension to its own separate repository, and due to inactivity.

Please feel free to reopen it if you still have questions, or open an issue in NATTEN if it's related to that.