xiuqhou / Salience-DETR

[CVPR 2024] Official implementation of the paper "Salience DETR: Enhancing Detection Transformer with Hierarchical Salience Filtering Refinement"

Home Page:https://arxiv.org/abs/2403.16131

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

backbone shape

yjdzyr opened this issue · comments

感谢作者提供的优秀项目,我有一个疑问,我想输出resnet 第一个卷积的shape,打印结果是Proxy(getattr_1),无法显示(B,C,H,W)的形状,请问这需要如何查看呢,非常期待作者的回复
9e4274436820f87af00cc39288d0fea

这个问题的原因是这里resnet并不是常规的nn.Module模块,而是经过转换后的torch.fx.Graph模块。

resnet原本是分类网络,输出的是分类结果而不是中间的特征图,我是使用了torchvision.models.feature_extraction.create_feature_extractor方法,在不更改代码的情况下获取resnet中间的四层特征图,这样它才能作为检测网络的backbone。但转换后的resnet是torch.fx.Graph模块,它每一层的x其实是torch.fx.proxy.Proxy示例而不是torch.Tensor,所以不能直接查看形状。

如果您想研究resnet每一层的输出,可以把resnet.py第420-425行注释掉,这样ResnetBackbone返回的就是转换前的nn.Module模型,您可以去查看每一层的形状。

        # # create feature extractor
        # return_layers = [f"layer{idx + 1}" for idx in return_indices]
        # resnet = create_feature_extractor(resnet, return_layers)
        # resnet.num_channels = [
        #     64 * model_config.block.expansion * 2**idx for idx in return_indices
        # ]
        # resnet.return_layers = return_layers

记得调试完后再把注释去掉,否则输出就和后面DETR检测头对应不上了

我这里的resnet代码是从torchvision复制过来的,只是稍微进行了修改,增加了对DCN和ResNeXt101_32x4d的支持。所以您也直接调试torchvision的resnet模型,输出都是一样的。

from torchvision.models.resnet import resnet50