ZJULearning / RMI

This is the code for the NeurIPS 2019 paper Region Mutual Information Loss for Semantic Segmentation.

Sample weights for RMI loss

shkarupa-alex opened this issue · comments

Some augmentations (e.g. random rotation) leave parts of the image and mask invalid.
To deal with such cases I usually use per-pixel weights (0 for holes, 1 for valid parts) and multiply the per-pixel loss by those weights.

But RMI loss works on "high dimension points", so the final loss has a shape incompatible with the original labels.

Could you please suggest the best way to discount such "holes" in the loss (i.e., multiply by per-pixel weights)?

Well, RMI uses Cov(Y, P) and Var(Y)/Var(P), i.e., statistics of the data, to calculate the loss value.
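For readers following along: as I recall the paper, those statistics enter a Gaussian lower bound on the region-wise mutual information. This is my hedged reconstruction of the bound, not a quote from the code:

```latex
% Sketch of the bound RMI maximizes (as I recall the paper): a Gaussian
% maximizes entropy for a given covariance, so H(Y|P) is upper-bounded and
I(Y;P) \;\ge\; H(Y) - \tfrac{1}{2}\log\det\!\left(2\pi e\,\Sigma_{Y\mid P}\right),
\quad\text{with}\quad
\Sigma_{Y\mid P} = \mathrm{Var}(Y) - \mathrm{Cov}(Y,P)\,\mathrm{Var}(P)^{-1}\,\mathrm{Cov}(P,Y).
```

Since the bound is built from covariances over all points in a region, a single invalid point contaminates the statistics, which is why the hole points have to be excluded before these quantities are computed.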

I think the best way is to ignore the hole areas caused by rotation.
For example, if the hole areas are labeled by 255, we can ignore the corresponding points when calculating the normal cross entropy loss, or ignore these points when constructing the "high dimension points". Then these meaningless points will not join the procedure of loss calculation and gradient backprop.
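For the cross-entropy part, the masking is straightforward. A minimal NumPy sketch (the name `masked_ce` is mine, not from the repo; in PyTorch the same effect comes from `F.cross_entropy(..., ignore_index=255)`):

```python
import numpy as np

def masked_ce(logits, labels, ignore_label=255):
    """Per-pixel cross entropy that skips pixels labeled `ignore_label`.

    logits: float array [H, W, C]; labels: int array [H, W].
    """
    # Numerically stable softmax over the class axis.
    z = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)

    valid = labels != ignore_label           # True for real pixels, False for holes
    safe = np.where(valid, labels, 0)        # placeholder class index for holes
    h, w = labels.shape
    nll = -np.log(probs[np.arange(h)[:, None], np.arange(w)[None, :], safe] + 1e-12)
    # Average only over valid pixels, so holes contribute neither loss nor gradient.
    return (nll * valid).sum() / max(valid.sum(), 1)
```

The division by the number of valid pixels (not H*W) keeps the loss scale independent of how large the holes are.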

It is clear how to ignore "holes" in BCE, but I can't see how to

ignore these points when constructing the "high dimension points"

Those high dimension points include almost all pixels (all but a border of roughly radius pixels) from the downsampled labels/logits.
Also, we can slice the sample weights in the same manner as the labels.
As a result we may have sw_vectors here https://github.com/ZJULearning/RMI/blob/master/losses/rmi/rmi.py#L183 corresponding 1-to-1 to pr_vectors.

But I don't understand how to use them next: what should be multiplied by these weights, and when?
@mzhaoshuai , can you suggest any idea?
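For what it's worth, the slicing described above can be sketched in 2-D NumPy (a simplified stand-in for `map_get_pairs`; `to_high_dim_points` and the weight map are illustrative names, not from the repo, and the repo's tensors carry extra N and C axes):

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def to_high_dim_points(map2d, radius=3):
    """Turn an [H, W] map into [r*r, H-r+1, W-r+1] "high dimension points":
    one flattened r*r window per spatial location."""
    win = sliding_window_view(map2d, (radius, radius))   # [H', W', r, r]
    hp, wp = win.shape[:2]
    return win.reshape(hp, wp, radius * radius).transpose(2, 0, 1)

# Hypothetical sample-weight map, sliced exactly like the labels, so each
# weight "point" lines up 1-to-1 with the corresponding label/pr point.
labels = np.random.randint(0, 2, size=(8, 8)).astype(float)
weights = np.ones_like(labels)
weights[:2, :] = 0.0                                     # e.g. a rotation hole
la_vectors = to_high_dim_points(labels)
sw_vectors = to_high_dim_points(weights)
```

Because both maps go through the same windowing, `sw_vectors` tells you, for every high dimension point, which of its r*r coordinates came from hole pixels.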


I think you can keep a mask where the "holes" are 0 and the other areas are 1, then use this mask to select the meaningful points from the output of

def map_get_pairs(labels_4D, probs_4D, radius=3, is_combine=True):

  1. You can construct high dimension mask points from the mask with the same function, then keep only the points whose corresponding high dimension mask point contains no 0.
  2. You can do min-pooling of size radius*radius on the mask and use the pooled mask to select the meaningful points. This means every high dimension point that contains a 0/"hole" is ignored. Take care of the shapes and make the mask's shape (height and width) match the output's shape, i.e. a
    tensor with shape [N, C, radius * radius, H - (radius - 1), W - (radius - 1)]
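Option 2 can be sketched in 2-D NumPy (the function name is mine; in PyTorch, min-pooling is usually written as `-F.max_pool2d(-mask, ...)` since there is no built-in min-pool):

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def valid_point_mask(mask2d, radius=3):
    """Min-pool an [H, W] 0/1 validity mask over r*r windows: a high
    dimension point is kept only if *every* pixel in its window is valid,
    i.e. the window minimum is 1."""
    win = sliding_window_view(mask2d, (radius, radius))  # [H', W', r, r]
    return win.min(axis=(2, 3))                          # [H', W']

mask = np.ones((8, 8))
mask[0, 0] = 0.0                       # a single hole pixel
keep = valid_point_mask(mask, radius=3)
# `keep` now matches the spatial shape of the high dimension points, and
# points can be selected with e.g. pr_vectors[..., keep.astype(bool)].
```

The pooled mask has the same H' = H - (radius - 1), W' = W - (radius - 1) spatial shape as the point tensor above, so it can index the points directly.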

Feel free to re-open this issue if you still have questions.