BruceW91 / External-Attention-pytorch

Pytorch implementation of various Attention Mechanism

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

Contents


Attention Series


1. External Attention Usage

1.1. Paper

"Beyond Self-attention: External Attention using Two Linear Layers for Visual Tasks"

1.2. Overview

1.3. Code

from attention.ExternalAttention import ExternalAttention
import torch

input=torch.randn(50,49,512)
ea = ExternalAttention(d_model=512,S=8)
output=ea(input)
print(output.shape)

2. Self Attention Usage

2.1. Paper

"Attention Is All You Need"

1.2. Overview

1.3. Code

from attention.SelfAttention import ScaledDotProductAttention
import torch

input=torch.randn(50,49,512)
sa = ScaledDotProductAttention(d_model=512, d_k=512, d_v=512, h=8)
output=sa(input,input,input)
print(output.shape)

3. Simplified Self Attention Usage

3.1. Paper

None

3.2. Overview

3.3. Code

from attention.SimplifiedSelfAttention import SimplifiedScaledDotProductAttention
import torch

input=torch.randn(50,49,512)
ssa = SimplifiedScaledDotProductAttention(d_model=512, h=8)
output=ssa(input,input,input)
print(output.shape)

4. Squeeze-and-Excitation Attention Usage

4.1. Paper

"Squeeze-and-Excitation Networks"

4.2. Overview

4.3. Code

from attention.SEAttention import SEAttention
import torch

input=torch.randn(50,512,7,7)
se = SEAttention(channel=512,reduction=8)
output=se(input)
print(output.shape)

5. SK Attention Usage

5.1. Paper

"Selective Kernel Networks"

5.2. Overview

5.3. Code

from attention.SKAttention import SKAttention
import torch

input=torch.randn(50,512,7,7)
se = SKAttention(channel=512,reduction=8)
output=se(input)
print(output.shape)

6. CBAM Attention Usage

6.1. Paper

"CBAM: Convolutional Block Attention Module"

6.2. Overview

6.3. Code

from attention.CBAM import CBAMBlock
import torch

input=torch.randn(50,512,7,7)
kernel_size=input.shape[2]
cbam = CBAMBlock(channel=512,reduction=16,kernel_size=kernel_size)
output=cbam(input)
print(output.shape)

7. BAM Attention Usage

7.1. Paper

"BAM: Bottleneck Attention Module"

7.2. Overview

7.3. Code

from attention.BAM import BAMBlock
import torch

input=torch.randn(50,512,7,7)
bam = BAMBlock(channel=512,reduction=16,dia_val=2)
output=bam(input)
print(output.shape)

8. ECA Attention Usage

8.1. Paper

"ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks"

8.2. Overview

8.3. Code

from attention.ECAAttention import ECAAttention
import torch

input=torch.randn(50,512,7,7)
eca = ECAAttention(kernel_size=3)
output=eca(input)
print(output.shape)

9. DANet Attention Usage

9.1. Paper

"Dual Attention Network for Scene Segmentation"

9.2. Overview

9.3. Code

from attention.DANet import DAModule
import torch

input=torch.randn(50,512,7,7)
danet=DAModule(d_model=512,kernel_size=3,H=7,W=7)
print(danet(input).shape)

10. Pyramid Split Attention Usage

10.1. Paper

"EPSANet: An Efficient Pyramid Split Attention Block on Convolutional Neural Network"

10.2. Overview

10.3. Code

from attention.PSA import PSA
import torch

input=torch.randn(50,512,7,7)
psa = PSA(channel=512,reduction=8)
output=psa(input)
print(output.shape)

11. Efficient Multi-Head Self-Attention Usage

11.1. Paper

"ResT: An Efficient Transformer for Visual Recognition"

11.2. Overview

11.3. Code

from attention.EMSA import EMSA
import torch
from torch import nn
from torch.nn import functional as F

input=torch.randn(50,64,512)
emsa = EMSA(d_model=512, d_k=512, d_v=512, h=8,H=8,W=8,ratio=2,apply_transform=True)
output=emsa(input,input,input)
print(output.shape)
    

12. Shuffle Attention Usage

12.1. Paper

"SA-NET: SHUFFLE ATTENTION FOR DEEP CONVOLUTIONAL NEURAL NETWORKS"

12.2. Overview

12.3. Code

from attention.ShuffleAttention import ShuffleAttention
import torch
from torch import nn
from torch.nn import functional as F


input=torch.randn(50,512,7,7)
se = ShuffleAttention(channel=512,G=8)
output=se(input)
print(output.shape)

    

MLP Series

1. RepMLP Usage

1.1. Paper

"RepMLP: Re-parameterizing Convolutions into Fully-connected Layers for Image Recognition"

1.2. Overview

1.3. Code

from mlp.repmlp import RepMLP
import torch
from torch import nn

N=4 #batch size
C=512 #input dim
O=1024 #output dim
H=14 #image height
W=14 #image width
h=7 #patch height
w=7 #patch width
fc1_fc2_reduction=1 #reduction ratio
fc3_groups=8 # groups
repconv_kernels=[1,3,5,7] #kernel list
repmlp=RepMLP(C,O,H,W,h,w,fc1_fc2_reduction,fc3_groups,repconv_kernels=repconv_kernels)
x=torch.randn(N,C,H,W)
repmlp.eval()
for module in repmlp.modules():
    if isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm1d):
        nn.init.uniform_(module.running_mean, 0, 0.1)
        nn.init.uniform_(module.running_var, 0, 0.1)
        nn.init.uniform_(module.weight, 0, 0.1)
        nn.init.uniform_(module.bias, 0, 0.1)

#training result
out=repmlp(x)
#inference result
repmlp.switch_to_deploy()
deployout = repmlp(x)

print(((deployout-out)**2).sum())

2. MLP-Mixer Usage

2.1. Paper

"MLP-Mixer: An all-MLP Architecture for Vision"

2.2. Overview

2.3. Code

from mlp.mlp_mixer import MlpMixer
import torch
mlp_mixer=MlpMixer(num_classes=1000,num_blocks=10,patch_size=10,tokens_hidden_dim=32,channels_hidden_dim=1024,tokens_mlp_dim=16,channels_mlp_dim=1024)
input=torch.randn(50,3,40,40)
output=mlp_mixer(input)
print(output.shape)

3. ResMLP Usage

3.1. Paper

"ResMLP: Feedforward networks for image classification with data-efficient training"

3.2. Overview

3.3. Code

from mlp.resmlp import ResMLP
import torch

input=torch.randn(50,3,14,14)
resmlp=ResMLP(dim=128,image_size=14,patch_size=7,class_num=1000)
out=resmlp(input)
print(out.shape) #the last dimention is class_num

4. gMLP Usage

4.1. Paper

"Pay Attention to MLPs"

4.2. Overview

4.3. Code

from mlp.g_mlp import gMLP
import torch

num_tokens=10000
bs=50
len_sen=49
num_layers=6
input=torch.randint(num_tokens,(bs,len_sen)) #bs,len_sen
gmlp = gMLP(num_tokens=num_tokens,len_sen=len_sen,dim=512,d_ff=1024)
output=gmlp(input)
print(output.shape)

Re-Parameter Series


1. RepVGG Usage

1.1. Paper

"RepVGG: Making VGG-style ConvNets Great Again"

1.2. Overview

1.3. Code

from rep.repvgg import RepBlock
import torch


input=torch.randn(50,512,49,49)
repblock=RepBlock(512,512)
repblock.eval()
out=repblock(input)
repblock._switch_to_deploy()
out2=repblock(input)
print('difference between vgg and repvgg')
print(((out2-out)**2).sum())

2. ACNet Usage

2.1. Paper

"ACNet: Strengthening the Kernel Skeletons for Powerful CNN via Asymmetric Convolution Blocks"

2.2. Overview

2.3. Code

from rep.acnet import ACNet
import torch
from torch import nn

input=torch.randn(50,512,49,49)
acnet=ACNet(512,512)
acnet.eval()
out=acnet(input)
acnet._switch_to_deploy()
out2=acnet(input)
print('difference:')
print(((out2-out)**2).sum())

About

Pytorch implementation of various Attention Mechanism


Languages

Language:Python 100.0%