open-mmlab / mmdetection

OpenMMLab Detection Toolbox and Benchmark

Home Page:https://mmdetection.readthedocs.io

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

Reimplement a custom model with new modules implemented by yourself

BruceLxw opened this issue · comments

I want to add an attention module to ResNet50. How should I adapt my attention module's source code and then add it to ResNet50 in mmdetection? Also, my previous models all used the pretrained weights "open-mmlab://detectron2/resnet50_caffe". After adding the attention module, how should I set its initialization so that the other ResNet50 parts are still initialized from that pretrained checkpoint? Thank you very much for your answer!

Specifically, I have added an attention module after the second convolution of BasicBlock and after the third convolution of Bottleneck; everything else is mmdetection's original ResNet50. I want the original ResNet50 parts to keep using the pretrained model mentioned above, while the newly added attention module uses a simple initialization method such as Kaiming. How should I implement this? Thank you.

#mmdet\models\backbones\mynet.py

class MyA(BaseModule):
    """Grouped attention module inserted into the ResNet blocks below.

    The input ``x`` of shape ``(b, c, h, w)`` is split into ``factor``
    channel groups, giving ``group_x`` of shape
    ``(b * factor, c // factor, h, w)``. Per group, two branches are built:

    * branch 1: average-pool along width (``pool_h``) and height
      (``pool_w``), concatenate, apply a shared 1x1 conv, split back, gate
      ``group_x`` with the two sigmoids, then normalize (``norm1``);
    * branch 2: a 3x3 conv on ``group_x``.

    Each branch's globally pooled channel descriptor is softmax-normalized
    over channels and matrix-multiplied with the *other* branch's flattened
    spatial map. The two resulting ``(h, w)`` maps are summed, passed
    through a sigmoid and used to rescale ``group_x``. The output has the
    same shape as the input.

    Args:
        channels (int): Number of input channels ``c``. Must be a positive
            multiple of ``factor``.
        factor (int): Number of channel groups. Default: 32.
        norm_cfg (dict, optional): Config of the norm layer applied to
            branch 1 (built on ``channels // factor`` channels). When None,
            GroupNorm with ``channels // factor`` groups (one channel per
            group) is used.
        conv_cfg (dict, optional): Config for the 1x1 / 3x3 conv layers,
            forwarded to ``build_conv_layer``.
        init_cfg (dict | list[dict]): Initialization config passed to
            ``BaseModule``. Defaults to Kaiming for ``Conv2d`` and constant 1
            for BatchNorm/GroupNorm.

    NOTE(review): when this module sits inside a backbone whose own
    ``init_cfg`` is ``Pretrained``, confirm that this module's ``init_cfg``
    is actually applied. The checkpoint has no ``mya`` weights, and the
    parent's init_weights may not recurse into children after loading a
    pretrained checkpoint.
    """

    def __init__(self, channels, factor=32,
                 norm_cfg=None,
                 conv_cfg=None,
                 init_cfg=[
                     dict(type='Kaiming', layer='Conv2d'),
                     dict(
                         type='Constant',
                         val=1,
                         layer=['_BatchNorm', 'GroupNorm'])
                 ],):
        super(MyA, self).__init__(init_cfg)
        self.groups = factor
        # forward() reshapes (b, c, h, w) -> (b * groups, c // groups, h, w),
        # which only works when c is a positive multiple of groups.
        assert channels // self.groups > 0 and channels % self.groups == 0, (
            'channels ({}) must be a positive multiple of factor ({})'.format(
                channels, factor))
        group_channels = channels // self.groups

        self.softmax = nn.Softmax(-1)
        self.agp = nn.AdaptiveAvgPool2d((1, 1))
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))  # -> (.., h, 1)
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))  # -> (.., 1, w)

        if norm_cfg is None:
            # A plain dict (no trailing comma, which would make a tuple that
            # build_norm_layer rejects). num_groups must divide the
            # group_channels channels this norm sees, so use one channel per
            # group. A fixed 32 groups fails for e.g. 2 or 8 channels.
            norm_cfg = dict(type='GN', num_groups=group_channels)

        self.conv1x1 = build_conv_layer(
            conv_cfg,
            group_channels,
            group_channels,
            1,
            stride=1,
            padding=0,
            dilation=1,
            bias=False)

        self.conv3x3 = build_conv_layer(
            conv_cfg,
            group_channels,
            group_channels,
            3,
            stride=1,
            padding=1,
            dilation=1,
            bias=False)

        self.norm1_name, norm1 = build_norm_layer(
            norm_cfg, group_channels, postfix=1)
        self.add_module(self.norm1_name, norm1)

    @property
    def norm1(self):
        """nn.Module: normalization layer applied to the gated branch."""
        return getattr(self, self.norm1_name)

    def forward(self, x):
        """Apply grouped attention; returns a tensor shaped like ``x``."""
        b, c, h, w = x.size()

        # (b * groups, c // groups, h, w)
        group_x = x.reshape(b * self.groups, -1, h, w)
        x_h = self.pool_h(group_x)                        # (.., h, 1)
        x_w = self.pool_w(group_x).permute(0, 1, 3, 2)    # (.., w, 1)
        # Shared 1x1 conv over the concatenated (h + w) descriptors.
        hw = self.conv1x1(torch.cat([x_h, x_w], dim=2))
        x_h, x_w = torch.split(hw, [h, w], dim=2)

        # Branch 1: gate along height and width, then normalize.
        x1 = self.norm1(
            group_x * x_h.sigmoid() * x_w.permute(0, 1, 3, 2).sigmoid())
        # Branch 2: 3x3 conv.
        x2 = self.conv3x3(group_x)

        # Softmax over channels of branch-1's pooled descriptor, applied to
        # branch-2's flattened spatial map -> (b * groups, 1, h * w).
        x11 = self.softmax(
            self.agp(x1).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
        x12 = x2.reshape(b * self.groups, c // self.groups, -1)
        y1 = torch.matmul(x11, x12)

        # Same cross-weighting with the branches swapped.
        x21 = self.softmax(
            self.agp(x2).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
        x22 = x1.reshape(b * self.groups, c // self.groups, -1)
        y2 = torch.matmul(x21, x22)

        weights = (y1 + y2).reshape(b * self.groups, 1, h, w)
        weights_ = weights.sigmoid()
        out = (group_x * weights_).reshape(b, c, h, w)
        return out

#...

class MyABasicBlock(BaseModule):
    """ResNet BasicBlock with a MyA attention module after conv2/norm2.

    Apart from ``self.mya`` and its call in ``forward``, the block is meant
    to follow mmdetection's original ResNet BasicBlock. The rest of the
    layer construction is elided in this snippet.
    """
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None,
                 plugins=None,
                 init_cfg=None):
        # Must name this class: a different (undefined) class name here
        # raises NameError as soon as the block is constructed.
        super(MyABasicBlock, self).__init__(init_cfg)
        # ... (conv1/norm1/conv2/norm2/relu/downsample setup elided)
        # Attention over the ``planes`` channels produced by conv2/norm2
        # (expansion == 1, so this is also the block's output width).
        self.mya = MyA(channels=planes)

    # ... (remaining members elided)

    def forward(self, x):
        """Forward function."""

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = self.norm2(out)
            # Apply MyA attention before the residual addition.
            out = self.mya(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out

#...
class MyABottleneck(BaseModule):
    """ResNet Bottleneck block with a MyA attention module after conv3/norm3.

    MyA operates on the block's output width, ``planes * expansion``.
    The rest of the layer construction is elided in this snippet.
    """
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None,
                 plugins=None,
                 init_cfg=None):
        """Build the bottleneck block.

        With ``style="pytorch"`` the stride-two layer is the 3x3 conv; with
        ``style="caffe"`` it is the first 1x1 conv.
        """
        super(MyABottleneck, self).__init__(init_cfg)
        # ... (conv/norm/plugin/downsample setup elided)
        self.mya = MyA(channels=planes * self.expansion)

    # ... (remaining members elided)

    def forward(self, x):
        """Run conv1 -> conv2 -> conv3 -> MyA, add the shortcut, then ReLU.

        When plugins are configured they run after each conv stage; the
        conv3-stage plugins run after the MyA attention module.
        """

        def _inner_forward(inp):
            feat = self.relu(self.norm1(self.conv1(inp)))
            if self.with_plugins:
                feat = self.forward_plugin(feat, self.after_conv1_plugin_names)

            feat = self.relu(self.norm2(self.conv2(feat)))
            if self.with_plugins:
                feat = self.forward_plugin(feat, self.after_conv2_plugin_names)

            # Attention is applied to the normalized conv3 output, before the
            # after-conv3 plugins and before the residual addition.
            feat = self.mya(self.norm3(self.conv3(feat)))
            if self.with_plugins:
                feat = self.forward_plugin(feat, self.after_conv3_plugin_names)

            shortcut = inp if self.downsample is None else self.downsample(inp)
            feat += shortcut
            return feat

        use_checkpoint = self.with_cp and x.requires_grad
        out = cp.checkpoint(_inner_forward, x) if use_checkpoint else _inner_forward(x)
        return self.relu(out)

I did not modify the rest of the ResNet class. I have already passed a Kaiming init_cfg to MyA, and I still want to use my previous pretraining configuration:

# Backbone config overrides. This is an excerpt: the trailing comma suggests
# it sits inside an enclosing call (presumably `model=dict(...)`) that is not
# shown here.
backbone=dict(
    # Presumably freezes the norm layers' parameters; confirm in mmdet ResNet.
    norm_cfg=dict(requires_grad=False),
    # Presumably keeps norm layers in eval mode during training; confirm.
    norm_eval=True,
    # Caffe style: the stride-two layer is the first 1x1 conv of Bottleneck.
    style="caffe",
    # Load the backbone from the detectron2 caffe ResNet-50 checkpoint.
    # NOTE(review): this checkpoint has no weights for the added `mya`
    # modules; confirm they still receive their own (Kaiming/Constant)
    # init_cfg rather than being left at their default initialization.
    init_cfg=dict(
        type="Pretrained", checkpoint="open-mmlab://detectron2/resnet50_caffe"
    ),
),

Is it okay to do this? If not, how can I use this pretrained model for the original ResNet parts while initializing the MyA module with Kaiming and GN?