YuanGongND / ssast

Code for the AAAI 2022 paper "SSAST: Self-Supervised Audio Spectrogram Transformer".


Potential bug/sub-optimal implementation of final output activation and losses

harrygcoppock opened this issue

Dear Yuan and authors,

Love the paper + repo. I am currently implementing SSAST as an ablation study in an audio classification paper I am working on for the Alan Turing Institute and UK Gov. I have come across a small potential issue with the combination of the final activation function applied at inference time and the loss used at training time - please correct me if I am wrong.

Here we can see support for cross-entropy loss (softmax + negative log-likelihood) and binary cross-entropy loss (sigmoid + binary cross-entropy):

ssast/src/traintest.py

Lines 102 to 105 in 888bb6a

if args.loss == 'BCE':
    loss_fn = nn.BCEWithLogitsLoss()
elif args.loss == 'CE':
    loss_fn = nn.CrossEntropyLoss()
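
For reference, both of these losses expect raw logits, since each applies its own activation internally (CrossEntropyLoss = log-softmax + NLL; BCEWithLogitsLoss = sigmoid + BCE). A minimal standalone sketch with made-up logits:

import torch
import torch.nn as nn

logits = torch.tensor([[2.0, -1.0, 0.5]])      # raw model outputs, no activation
target_ce = torch.tensor([0])                  # class index for CrossEntropyLoss
target_bce = torch.tensor([[1.0, 0.0, 0.0]])   # multi-hot target for BCEWithLogitsLoss

# Each loss applies its own activation internally, so no sigmoid/softmax
# should be applied to the logits beforehand.
print(nn.CrossEntropyLoss()(logits, target_ce))
print(nn.BCEWithLogitsLoss()(logits, target_bce))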

This is all fine and the training scheme is not an issue. However, at inference time a sigmoid activation is applied over the output dimensions (#classes) irrespective of whether cross-entropy or binary cross-entropy was the chosen training loss:

ssast/src/traintest.py

Lines 310 to 323 in 888bb6a

# compute output
audio_output = audio_model(audio_input, args.task)
audio_output = torch.sigmoid(audio_output)
predictions = audio_output.to('cpu').detach()
A_predictions.append(predictions)
A_targets.append(labels)

# compute the loss
labels = labels.to(device)
if isinstance(args.loss_fn, torch.nn.CrossEntropyLoss):
    loss = args.loss_fn(audio_output, torch.argmax(labels.long(), axis=1))
else:
    loss = args.loss_fn(audio_output, labels)

I have 2 (minor) issues with this.

  1. Cross-entropy loss is for multiclass classification and therefore normalises the output (with softmax) to a probability mass of 1 over all the classes, of which one is chosen as the prediction. However, at inference time the outputs of the network are fed through a sigmoid activation, assuming a Bernoulli distribution over each class independently. This means the probabilities over all classes at inference time do not sum to 1 (see the sketch after this list). In my eyes this is training as if the problem is multiclass but evaluating as if it were multilabel, and I believe this will limit performance. This scheme is of course fine when the loss is binary cross-entropy.
  2. At inference time, when the loss is calculated, a sigmoid activation is applied to the outputs before they are passed to the loss function. This effectively applies sigmoid then softmax for the cross-entropy loss, and sigmoid twice for the binary cross-entropy loss. As this happens only at inference time it will not affect the training of the model, but it will affect the reported inference-time loss.
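
A quick self-contained sketch of both points, with made-up logits:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Made-up logits for a 3-class, single-label problem.
logits = torch.tensor([[3.0, 1.0, -2.0]])
labels = torch.tensor([0])

# Point 1: sigmoid does not normalise across classes.
print(torch.sigmoid(logits).sum())       # ~1.80, does not sum to 1
print(F.softmax(logits, dim=1).sum())    # exactly 1

# Point 2: CrossEntropyLoss applies log-softmax internally, so passing it
# sigmoid(logits) instead of the raw logits distorts the reported loss.
loss_fn = nn.CrossEntropyLoss()
print(loss_fn(logits, labels))                 # intended loss (~0.13)
print(loss_fn(torch.sigmoid(logits), labels))  # distorted loss (~0.80)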

Do you agree with these points? If so, a simple change along these lines should suffice:

# compute output
audio_output = audio_model(audio_input, args.task)
audio_output_for_loss = audio_output
if isinstance(args.loss_fn, torch.nn.CrossEntropyLoss):
    audio_output = torch.nn.functional.softmax(audio_output, dim=1)
else:
    audio_output = torch.sigmoid(audio_output)
predictions = audio_output.to('cpu').detach()
A_predictions.append(predictions)
A_targets.append(labels)

# compute the loss
labels = labels.to(device)
if isinstance(args.loss_fn, torch.nn.CrossEntropyLoss):
    loss = args.loss_fn(audio_output_for_loss, torch.argmax(labels.long(), axis=1))
else:
    loss = args.loss_fn(audio_output_for_loss, labels)

Many thanks in advance for your time!

Hi Harry,

Thanks so much for pointing this out. And the solution looks good.

I agree with you except for "I believe this will limit performance." Since the sigmoid function is monotonically increasing, it does not impact the mAP or accuracy calculation at

stats = calculate_stats(audio_output, target)

so I think the only impact is the loss calculation. Do you agree with me on this?
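
To illustrate the monotonicity point, a minimal sketch:

import torch

# Sigmoid is applied elementwise and is strictly increasing, so it preserves
# the ordering of scores both across classes (argmax -> accuracy) and across
# examples within a class (ranking -> mAP/AUC).
logits = torch.randn(8, 5)
scores = torch.sigmoid(logits)
assert torch.equal(logits.argmax(dim=1), scores.argmax(dim=1))
assert torch.equal(logits[:, 0].argsort(), scores[:, 0].argsort())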

I will find time to change this, but the current version should report everything correctly except the validation loss.

Thanks!

Thanks for the speedy response. I agree with your point that it will not affect accuracy, as the max prediction is taken. For our paper the main metrics are ROC_AUC and PR_AUC, which can to some degree be viewed as the extent to which classes are separable in the classification space. Allowing the model to pick more than one class (through the sigmoid function), as opposed to picking just one (through softmax), might limit the degree to which the classes are separable, especially if only one can occur. In the current setting the model can give 100% probability to 2 mutually exclusive events occurring, no? When really their probabilities should add up to 1.

E.g. if we look at the example below, the outputs are more separable after the softmax function than after the sigmoid function.
[Screenshot (2022-03-07): model outputs after softmax vs. sigmoid]
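
A toy version of the same comparison, assuming two mutually exclusive classes that both receive large logits:

import torch
import torch.nn.functional as F

# Two mutually exclusive classes, both with large positive logits.
logits = torch.tensor([[4.0, 4.0]])

print(torch.sigmoid(logits))     # ~[0.98, 0.98]: both "certain", sums to ~1.96
print(F.softmax(logits, dim=1))  # [0.50, 0.50]: mass is shared, sums to 1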

What are your thoughts on this, or am I barking up the wrong tree? aha