zijian-hu / SimPLE

Code for the paper: "SimPLE: Similar Pseudo Label Exploitation for Semi-Supervised Classification"

Home Page: https://arxiv.org/abs/2103.16725

Question about computing pair loss

xxliang99 opened this issue

Dear Zijian,

Thank you for your contribution! I have a question about the part of the code that computes the pair loss.

I am confused about whether the similarity should be computed between the pseudo-labels generated from weakly augmented data and the predictions on strongly augmented data, or among the pseudo-labels generated from weakly augmented versions of different input images.

If the latter is correct, does line 70 mean that if the pseudo-label generated from weakly augmented image A has high confidence and is similar enough to the pseudo-label of another input image B, then the distance between pseudo-label A and the prediction on strongly augmented image B should be minimized?

distance_ij = self.get_distance_loss(loss_input, targets_i, dim=1, reduction='none')

Thank you for your time. I would appreciate a reply.

Best,
Vivian

Dear Vivian,

Thank you for your question.

I am confused about whether the similarity should be computed between the pseudo-labels generated from weakly augmented data and the predictions on strongly augmented data, or among the pseudo-labels generated from weakly augmented versions of different input images.

The similarity should be calculated between two pseudo-labels generated from weakly augmented data. We then apply a similarity threshold to filter out dissimilar pairs; the calculated similarity is used only for this thresholding.
That is why the code uses sim: Tensor = self.get_similarity(targets_i, targets_j, dim=1).
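As a rough illustration of this thresholding step, the sketch below compares every pair of weak-augmentation pseudo-labels and keeps only the similar pairs; no strong-augmentation prediction is involved. It is not the repository's code: it assumes the Bhattacharyya coefficient as the similarity function, and `bhattacharyya`, `similarity_mask` and `sim_threshold` are hypothetical stand-ins for `get_similarity` and the similarity threshold.

```python
import math
from typing import List


def bhattacharyya(p: List[float], q: List[float]) -> float:
    """Bhattacharyya coefficient of two probability vectors (1.0 means identical).

    Assumed similarity function for this sketch; the repository may use another.
    """
    return sum(math.sqrt(pi * qi) for pi, qi in zip(p, q))


def similarity_mask(pseudo_labels: List[List[float]], sim_threshold: float) -> List[List[bool]]:
    """mask[i][j] is True when images i and j (i != j) have similar pseudo-labels.

    Only pseudo-labels from weak augmentation are compared here; the
    similarity value itself is used for nothing but this threshold test.
    """
    n = len(pseudo_labels)
    return [
        [i != j and bhattacharyya(pseudo_labels[i], pseudo_labels[j]) > sim_threshold
         for j in range(n)]
        for i in range(n)
    ]
```

For pseudo-labels [0.9, 0.1], [0.8, 0.2] and [0.1, 0.9] with a threshold of 0.9, only the first two images form a similar pair (in both orders).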

If the latter is correct, does line 70 mean that if the pseudo-label generated from weakly augmented image A has high confidence and is similar enough to the pseudo-label of another input image B, then the distance between pseudo-label A and the prediction on strongly augmented image B should be minimized?

The two thresholds in pair loss serve to select high-quality anchor pseudo-labels. Pair loss is defined as follows: for two images x_l and x_r, if their pseudo-labels q_l and q_r are similar and the anchor pseudo-label q_l has high confidence, we push the prediction on the strongly augmented version of x_r, pred(StrongAugment(x_r)), toward the anchor pseudo-label q_l.

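Following this definition, the loss is computed as distance_ij = self.get_distance_loss(loss_input, targets_i, dim=1, reduction='none'), where loss_input holds the predictions on the strongly augmented images x_r and targets_i holds the anchor pseudo-labels q_l.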
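Putting the confidence threshold, the similarity threshold and the distance together, a minimal pure-Python sketch of the pair loss for one batch might look like the following. It rests on several assumptions not stated in this thread: the Bhattacharyya coefficient as the similarity, one minus that coefficient as the distance, strict `>` comparisons, and averaging over all N(N-1) ordered pairs. `pair_loss`, `conf_threshold` and `sim_threshold` are hypothetical names, and the repository's `get_similarity` / `get_distance_loss` may use different functions. `pseudo_labels[i]` plays the role of `targets_i` (anchor q_l), and `strong_preds[j]` plays the role of `loss_input` (pred(StrongAugment(x_r))).

```python
import math
from typing import List


def bhattacharyya(p: List[float], q: List[float]) -> float:
    """Bhattacharyya coefficient of two probability vectors (assumed similarity)."""
    return sum(math.sqrt(pi * qi) for pi, qi in zip(p, q))


def pair_loss(pseudo_labels: List[List[float]],
              strong_preds: List[List[float]],
              conf_threshold: float,
              sim_threshold: float) -> float:
    """Hypothetical pair-loss sketch over all ordered pairs (i, j), i != j.

    pseudo_labels[i]: pseudo-label from weak augmentation of image i (anchor q_l).
    strong_preds[j]:  prediction on the strongly augmented image j (pred of x_r).
    """
    n = len(pseudo_labels)
    if n < 2:
        return 0.0  # no pairs to compare
    total = 0.0
    for i in range(n):
        q_i = pseudo_labels[i]
        # Confidence threshold applies to the anchor pseudo-label only.
        if max(q_i) <= conf_threshold:
            continue
        for j in range(n):
            if i == j:
                continue
            # Similarity between the two weak-augmentation pseudo-labels,
            # used only to decide whether the pair counts.
            if bhattacharyya(q_i, pseudo_labels[j]) <= sim_threshold:
                continue
            # Pull the strong-augmentation prediction of image j toward anchor q_i
            # (assumed distance: 1 - Bhattacharyya coefficient).
            total += 1.0 - bhattacharyya(strong_preds[j], q_i)
    # Assumed normalization: mean over all ordered pairs.
    return total / (n * (n - 1))
```

Note the asymmetry this produces: for similar images A and B, the pair (A, B) contributes only if q_A is confident, and it pulls the strong-augmentation prediction for B toward q_A; the reverse pair (B, A) is checked separately against the confidence of q_B.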

Let me know if you need more clarifications or if you have additional questions.

Thank you,
Zijian Hu