hustvl / TopFormer

TopFormer: Token Pyramid Transformer for Mobile Semantic Segmentation, CVPR2022

About building the ImageNet-pretrained classification model

itisianlee opened this issue

import torch
import torch.nn as nn


class CLSHead(nn.Module):
    def __init__(self, in_ch=384, out_ch=1000, dropout=0.2):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.bn = nn.BatchNorm2d(in_ch)
        self.dropout = nn.Dropout(p=dropout, inplace=True)
        self.l = nn.Linear(in_ch, out_ch)

    def forward(self, x):
        # x: (N, in_ch, H, W) feature map
        x = self.bn(x)
        x = self.pool(x)         # global average pool -> (N, in_ch, 1, 1)
        x = torch.flatten(x, 1)  # -> (N, in_ch)
        x = self.dropout(x)
        x = self.l(x)            # -> (N, out_ch) class logits
        return x

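Is this how the classification head is written? Is it attached directly to the output of self.trans = BasicLayer?
When I load the checkpoint from Google Drive and evaluate it on ImageNet val, the accuracy is about 59.5.
Looking forward to your reply, thanks.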
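
For reference, here is a minimal shape check of the head above. It is only a sketch: the 7x7 spatial size of the input feature map is an assumption for illustration, not something taken from the TopFormer code, and the dummy tensor merely stands in for the output of `self.trans`. The head is put in `eval()` mode so that BatchNorm uses its running statistics and dropout is disabled, which is how it would normally run during validation.

```python
import torch
import torch.nn as nn


class CLSHead(nn.Module):
    def __init__(self, in_ch=384, out_ch=1000, dropout=0.2):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.bn = nn.BatchNorm2d(in_ch)
        self.dropout = nn.Dropout(p=dropout, inplace=True)
        self.l = nn.Linear(in_ch, out_ch)

    def forward(self, x):
        x = self.bn(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.l(x)
        return x


# eval() disables dropout and makes BatchNorm use running statistics.
head = CLSHead().eval()

# Dummy feature map standing in for the backbone output.
# The 7x7 spatial size is an assumption for this sketch.
feat = torch.randn(2, 384, 7, 7)

with torch.no_grad():
    logits = head(feat)

print(logits.shape)  # torch.Size([2, 1000])
```

This only confirms that the tensor shapes line up. It does not show whether the released checkpoint contains weights for a head with these layer names.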