qubvel / segmentation_models.pytorch

Segmentation models with pretrained backbones. PyTorch.

Class weights specification on IoU metric

alepistola opened this issue

I am trying to calculate a weighted IoU, but I do not understand how to pass a list of class weights to the function. This is my code:

```python
# mode="binary": tp, fp, fn, tn each have shape (batch_size, 1), i.e. one class column
tp, fp, fn, tn = smp.metrics.get_stats(pred_mask.long(), mask.long(), mode="binary")

if stage == "train":
    self.log_dict(
        {
            "train/batch-IOU-img": smp.metrics.iou_score(tp, fp, fn, tn, reduction="macro-imagewise"),
            "train/batch-IOU": smp.metrics.iou_score(tp, fp, fn, tn, reduction="macro"),
        },
        prog_bar=True,
        batch_size=config["batch_size"],
    )
elif stage == "test":
    self.log_dict(
        {
            "test/IoU-weighted": smp.metrics.iou_score(tp, fp, fn, tn, reduction="weighted", class_weights=[0.52, 16.67]),
            "test/mcc": self.mcc(pred_mask, mask),
        },
        on_step=False, on_epoch=True, prog_bar=True, batch_size=config["batch_size"],
    )
```
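For reference, here is a minimal pure-Python sketch of the two reductions used in the training branch, following how the smp documentation describes them: "macro" pools tp/fp/fn over all images before computing IoU, while "macro-imagewise" computes IoU per image and then averages. The helper names (`binary_stats`, `iou`, `iou_macro`, `iou_macro_imagewise`) and the convention that 0/0 returns 1.0 are assumptions for illustration, not the library's code. With a single binary class column there is no averaging over classes.

```python
from typing import List, Sequence, Tuple

Stats = Tuple[int, int, int, int]  # (tp, fp, fn, tn) for one image


def binary_stats(pred: Sequence[int], target: Sequence[int]) -> Stats:
    """Count tp/fp/fn/tn for one flattened binary mask (values 0 or 1)."""
    tp = sum(1 for p, t in zip(pred, target) if p == 1 and t == 1)
    fp = sum(1 for p, t in zip(pred, target) if p == 1 and t == 0)
    fn = sum(1 for p, t in zip(pred, target) if p == 0 and t == 1)
    tn = sum(1 for p, t in zip(pred, target) if p == 0 and t == 0)
    return tp, fp, fn, tn


def iou(tp: int, fp: int, fn: int, zero_division: float = 1.0) -> float:
    """IoU = tp / (tp + fp + fn); the 0/0 case returns `zero_division` (assumed 1.0)."""
    denom = tp + fp + fn
    return tp / denom if denom else zero_division


def iou_macro(stats: List[Stats]) -> float:
    """Pool the counts over all images first, then compute a single IoU."""
    tp = sum(s[0] for s in stats)
    fp = sum(s[1] for s in stats)
    fn = sum(s[2] for s in stats)
    return iou(tp, fp, fn)


def iou_macro_imagewise(stats: List[Stats]) -> float:
    """Compute IoU for each image separately, then average over images."""
    return sum(iou(tp, fp, fn) for tp, fp, fn, _ in stats) / len(stats)


# Toy batch of two flattened 4-pixel masks (hypothetical data).
preds = [[1, 1, 0, 0], [1, 0, 0, 0]]
targets = [[1, 0, 0, 0], [1, 1, 1, 1]]
stats = [binary_stats(p, t) for p, t in zip(preds, targets)]
print(iou_macro(stats), iou_macro_imagewise(stats))
```

On this toy batch the pooled IoU is 2/6 = 1/3 while the per-image average is (0.5 + 0.25) / 2 = 0.375, which is why the two training metrics can disagree.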

The problem is this call:

```python
smp.metrics.iou_score(tp, fp, fn, tn, reduction="weighted", class_weights=[0.52, 16.67])
```
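One thing worth checking: with `mode="binary"`, `get_stats` returns a single (foreground) column per image, so a two-element `class_weights` list has no second class column to attach to. The sketch below is a hypothetical pure-Python reimplementation, not smp's code. It shows what a weighted IoU over background (class 0) and foreground (class 1) would compute: per-class tp/fp/fn pooled over all images, weights normalized to sum to 1, then a weighted sum of the per-class IoUs. In smp, getting two class columns presumably means calling `get_stats` with `mode="multiclass"` and `num_classes=2` on 0/1 label maps; treat that as an assumption to verify against the docs.

```python
from typing import List, Sequence, Tuple


def per_class_counts(pred: Sequence[int], target: Sequence[int], num_classes: int) -> List[Tuple[int, int, int]]:
    """(tp, fp, fn) for each class of one flattened label map of class indices."""
    counts = []
    for c in range(num_classes):
        tp = sum(1 for p, t in zip(pred, target) if p == c and t == c)
        fp = sum(1 for p, t in zip(pred, target) if p == c and t != c)
        fn = sum(1 for p, t in zip(pred, target) if p != c and t == c)
        counts.append((tp, fp, fn))
    return counts


def weighted_iou(
    preds: Sequence[Sequence[int]],
    targets: Sequence[Sequence[int]],
    class_weights: Sequence[float],
    zero_division: float = 1.0,
) -> float:
    """Hypothetical weighted IoU: one weight per class, class index = position in the list."""
    num_classes = len(class_weights)
    # Pool tp/fp/fn over all images, separately for each class.
    totals = [[0, 0, 0] for _ in range(num_classes)]
    for pred, target in zip(preds, targets):
        for c, (tp, fp, fn) in enumerate(per_class_counts(pred, target, num_classes)):
            totals[c][0] += tp
            totals[c][1] += fp
            totals[c][2] += fn
    # Normalize the weights to sum to 1, then take the weighted sum of per-class IoUs.
    weight_sum = sum(class_weights)
    score = 0.0
    for (tp, fp, fn), w in zip(totals, class_weights):
        denom = tp + fp + fn
        class_iou = tp / denom if denom else zero_division
        score += (w / weight_sum) * class_iou
    return score


# Toy batch: 0 = background, 1 = foreground (hypothetical data).
preds = [[1, 1, 0, 0], [1, 0, 0, 0]]
targets = [[1, 0, 0, 0], [1, 0, 0, 1]]
print(weighted_iou(preds, targets, class_weights=[0.52, 16.67]))
```

With equal weights this reduces to the plain mean of the two class IoUs; with `[0.52, 16.67]` the foreground IoU carries roughly 97% of the weight.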

This is the error:
error