CoinCheung / pytorch-loss

label-smooth, amsoftmax, partial-fc, focal-loss, triplet-loss, lovasz-softmax. Maybe useful

Can am-softmax be used for a multi-class classification task? For example, if the last layer of the network outputs 7 classes, what should the arguments to amsoftmax() be? Should the first argument, in_feats, be 7, and the second, num_class, also be 7? That doesn't seem right. I'm not sure how to use it, so I'd appreciate some guidance.

wbb95 opened this issue · comments

commented
Can am-softmax be used for a multi-class classification task? For example, if the last layer of the network outputs 7 classes, what should the arguments to amsoftmax() be? Should the first argument, in_feats, be 7, and the second, num_class, also be 7? That doesn't seem right. I'm not sure how to use it, so I'd appreciate some guidance.

Hi,

Here is an example you can try:

import torch
import torch.nn as nn
import torchvision.models as models
from pytorch_loss import AMSoftmax

class R50(nn.Module):

    def __init__(self, n_classes):
        super().__init__()
        model = models.resnet50(pretrained=False, num_classes=7)  # num_classes is irrelevant here: the fc head is replaced below
        self.conv1 = model.conv1
        self.bn1 = model.bn1
        self.relu = model.relu
        self.maxpool = model.maxpool
        self.layer1 = model.layer1
        self.layer2 = model.layer2
        self.layer3 = model.layer3
        self.layer4 = model.layer4
        self.fc = nn.Identity()  # drop the classifier head; the model outputs raw 2048-d features

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = torch.mean(x, dim=(2, 3))  # global average pooling -> (N, 2048) feature vectors
        x = self.fc(x)                 # Identity during training; replaced by a real fc for inference
        return x

model = R50(7)
crit = AMSoftmax(in_feats=2048, n_classes=7)  # in_feats = feature dim of the backbone (2048), n_classes = number of classes (7)
model.cuda()
crit.cuda()


### train your model
model.train()
for it in range(10):
    print('iter: ', it)
    inten = torch.randn(32, 3, 224, 224).cuda()
    lbs = torch.randint(0, 7, (32,)).cuda()

    logits = model(inten)  # 2048-d embeddings; AMSoftmax builds the margin logits internally
    loss = crit(logits, lbs)
    loss.backward()
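    # Note (not part of the original snippet): a complete training loop would
    # also create an optimizer over both model.parameters() and
    # crit.parameters() (the AMSoftmax weight matrix W is learnable), and call
    # optimizer.zero_grad() / optimizer.step() around loss.backward().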

### inference
# build the inference fc layer from the learned AMSoftmax weight matrix W
with torch.no_grad():
    model.fc = nn.Linear(2048, 7, bias=False).cuda()
    model.fc.weight.copy_(crit.W.T)
model.eval()

inten = torch.randn(1, 3, 224, 224).cuda()
pred = model(inten).argmax(dim=1)
print(pred)
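
To answer the original question directly: in_feats is the dimension of the feature vector the backbone produces (2048 for ResNet-50 after global pooling), not the number of classes, while n_classes is the number of classes (7 here). One additional note on the inference step above, not part of the original answer: AM-Softmax trains on cosine similarities, i.e. it L2-normalizes both the features and the columns of W, so if you want the copied fc layer to rank classes exactly as the training-time cosine logits do, you can normalize the columns of W before copying them (normalizing the feature itself does not change the argmax, since it rescales every class score by the same positive factor). A minimal sketch, assuming crit.W has shape (in_feats, n_classes) as implied by the transpose above:

import torch.nn.functional as F

with torch.no_grad():
    # unit-normalize each class column of W, then copy into the fc layer
    model.fc.weight.copy_(F.normalize(crit.W, dim=0).T)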

I am closing this, since it is no longer active.