dandelin / ViLT

Code for the ICML 2021 (long talk) paper: "ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision"

Question about ITM pretraining

EagleW opened this issue · comments

Hi, @dandelin

I have some questions about ITM pre-training. During pretraining, how do you combine the ITM loss with the WPA loss? It seems that you track and log them separately:

value = getattr(pl_module, f"{phase}_{loss_name}_accuracy").compute()
pl_module.log(f"{loss_name}/{phase}/accuracy_epoch", value)
getattr(pl_module, f"{phase}_{loss_name}_accuracy").reset()
pl_module.log(
    f"{loss_name}/{phase}/loss_epoch",
    getattr(pl_module, f"{phase}_{loss_name}_loss").compute(),
)
getattr(pl_module, f"{phase}_{loss_name}_loss").reset()
pl_module.log(
    f"{loss_name}/{phase}/wpa_loss_epoch",
    getattr(pl_module, f"{phase}_{loss_name}_wpa_loss").compute(),
)
getattr(pl_module, f"{phase}_{loss_name}_wpa_loss").reset()
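As I understand it, the snippet above is only epoch-level metric bookkeeping: each metric accumulates per-step values, `compute()` averages them at the end of the epoch, and `reset()` clears the state. A minimal pure-Python sketch of that pattern, loosely mimicking `torchmetrics.MeanMetric` (the class and variable names here are illustrative, not from the ViLT codebase):

```python
# Sketch of the accumulate/compute/reset pattern used for epoch logging.
# Loosely mimics torchmetrics.MeanMetric; not the actual ViLT/torchmetrics code.
class MeanMetric:
    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value):
        # Called once per training/validation step.
        self.total += value
        self.count += 1
        return value

    def compute(self):
        # Epoch-level mean over everything accumulated since the last reset.
        return self.total / self.count if self.count else 0.0

    def reset(self):
        self.total, self.count = 0.0, 0


metric = MeanMetric()
for step_loss in [0.9, 0.7, 0.5]:
    metric.update(step_loss)
epoch_loss = metric.compute()  # mean of the three step losses: 0.7
metric.reset()                 # state cleared for the next epoch
```

So the separate logging calls need not imply separate backpropagation; they just report the two terms independently.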

Why not simply add up those two losses and backpropagate them together?

ot_loss = (dist_pos.sum() - dist_neg.sum()) / (dist_pos.size(0) + dist_neg.size(0))
itm_logits = pl_module.itm_score(infer["cls_feats"])
itm_loss = F.cross_entropy(itm_logits, itm_labels.long())
ret = {
    "itm_loss": itm_loss,
    "itm_wpa_loss": 0.1 * ot_loss,
    "itm_logits": itm_logits,
    "itm_labels": itm_labels,
}
phase = "train" if pl_module.training else "val"
loss = getattr(pl_module, f"{phase}_itm_loss")(ret["itm_loss"])
wpa_loss = getattr(pl_module, f"{phase}_itm_wpa_loss")(ret["itm_wpa_loss"])
acc = getattr(pl_module, f"{phase}_itm_accuracy")(
    ret["itm_logits"], ret["itm_labels"]
)
pl_module.log(f"itm/{phase}/loss", loss)
pl_module.log(f"itm/{phase}/wpa_loss", wpa_loss)
pl_module.log(f"itm/{phase}/accuracy", acc)
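For what it's worth, one common Lightning-style pattern is to sum every `*_loss` entry of the returned dict in `training_step` and backpropagate through that single scalar, so the two terms would in fact be optimized jointly even though they are logged separately. A hedged sketch of that idea, using plain floats in place of tensors (I am not certain this is exactly what ViLT's `training_step` does):

```python
# Hedged sketch: combine all "*_loss" entries of the returned dict into one
# scalar for a single backward pass. The 0.1 weight on the WPA/OT term is
# assumed to be applied already when building `ret`, as in the snippet above.
def combine_losses(ret):
    return sum(v for k, v in ret.items() if k.endswith("_loss"))


ret = {
    "itm_loss": 0.65,           # illustrative cross-entropy value
    "itm_wpa_loss": 0.1 * 0.4,  # illustrative OT value, pre-weighted by 0.1
    "itm_logits": None,         # non-loss entries are skipped by the filter
    "itm_labels": None,
}
total_loss = combine_losses(ret)  # 0.65 + 0.04 = 0.69
```

With real tensors, calling `total_loss.backward()` would then propagate gradients through both terms at once.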

I also have the same question as in #48.

Thank you!