Sample weights for RMI loss
shkarupa-alex opened this issue.
Some augmentations (e.g. rotation by a random angle) leave parts of the image and mask without meaningful content: "holes" where nothing from the original image remains.
To deal with such cases, I usually use per-pixel weights (0. for holes, 1. for valid parts) and multiply the per-pixel loss by these weights.
But RMI loss works on "high dimension points", so the final loss has a shape that is incompatible with the original labels.
Could you please suggest the best way to exclude the loss coming from such "holes" (i.e., to multiply it by the pixel weights)?
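A minimal NumPy sketch of the per-pixel weighting described in the question, for binary cross-entropy. The function name `weighted_bce` and the choice to average over valid pixels only are illustrative assumptions, not part of the RMI repo.

```python
import numpy as np

def weighted_bce(logits, labels, weights, eps=1e-7):
    """Binary cross-entropy per pixel, multiplied by per-pixel weights.

    logits, labels, weights: arrays of the same shape (H, W).
    weights: 0.0 for rotation "holes", 1.0 for valid pixels.
    """
    probs = np.clip(1.0 / (1.0 + np.exp(-logits)), eps, 1.0 - eps)
    per_pixel = -(labels * np.log(probs) + (1.0 - labels) * np.log(1.0 - probs))
    # Zero out hole pixels, then average over the valid pixels only.
    return float((per_pixel * weights).sum() / max(weights.sum(), eps))

labels = np.array([[1.0, 0.0], [1.0, 0.0]])
weights = np.array([[1.0, 1.0], [0.0, 0.0]])  # bottom row is a "hole"
logits = np.zeros((2, 2))
print(weighted_bce(logits, labels, weights))  # log(2) ~ 0.6931
```

This works because BCE yields one loss value per pixel, so each weight has a matching term to multiply. The question is what to do when the loss has no such per-pixel term.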
Well, RMI uses Cov(Y, P) and Var(Y)/Var(P), i.e., statistics of the data, to calculate the loss value.
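To see why there is no per-pixel term that a weight could multiply, here is a NumPy sketch of a loss built only from these statistics: half the log-determinant of the conditional covariance Var(Y) − Cov(Y, P) Var(P)⁻¹ Cov(Y, P)ᵀ of the high-dimensional points. This form is a simplifying assumption (single class, no batch, no downsampling), `rmi_like_loss` is a hypothetical name, and the repo's rmi.py remains the reference for the actual computation.

```python
import numpy as np

def rmi_like_loss(y_points, p_points, eps=1e-6):
    """Loss computed from second-order statistics of paired high-dim points.

    y_points, p_points: arrays of shape (d, n), one d-dim point per column
    (d = radius * radius). Each column contributes only through the sums
    inside the covariance estimates below.
    """
    d, n = y_points.shape
    y = y_points - y_points.mean(axis=1, keepdims=True)
    p = p_points - p_points.mean(axis=1, keepdims=True)
    var_y = y @ y.T / n                      # Var(Y), shape (d, d)
    var_p = p @ p.T / n + eps * np.eye(d)    # Var(P), regularized for solving
    cov_yp = y @ p.T / n                     # Cov(Y, P), shape (d, d)
    # Conditional covariance of Y given P (a Schur complement).
    m = var_y - cov_yp @ np.linalg.solve(var_p, cov_yp.T)
    _, logdet = np.linalg.slogdet(m + eps * np.eye(d))
    return 0.5 * logdet

rng = np.random.default_rng(0)
y = rng.standard_normal((9, 2000))
print(rmi_like_loss(y, y + 0.1 * rng.standard_normal((9, 2000))))  # strongly negative
print(rmi_like_loss(y, rng.standard_normal((9, 2000))))            # close to 0
```

Because each point enters only through these sums, excluding a hole point amounts to deleting its column before the statistics are computed.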
I think the best way is to ignore the hole areas caused by rotation.
For example, if the hole areas are labeled 255, we can ignore the corresponding points when calculating the normal cross-entropy loss, and ignore them when constructing the "high dimension points". Then these meaningless points take no part in the loss computation or in gradient backpropagation.
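For the cross-entropy part, a minimal NumPy sketch of skipping pixels labeled 255. In PyTorch the same effect comes from the `ignore_index` argument of `nn.CrossEntropyLoss`; the function below is a hypothetical stand-alone version for illustration.

```python
import numpy as np

IGNORE_INDEX = 255  # label value assumed to mark rotation holes

def cross_entropy_ignore(logits, labels, ignore_index=IGNORE_INDEX):
    """Mean cross-entropy over pixels whose label is not ignore_index.

    logits: (C, H, W) float array; labels: (H, W) int array.
    """
    # Numerically stable log-softmax over the class axis.
    shifted = logits - logits.max(axis=0, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=0, keepdims=True))
    valid = labels != ignore_index
    if not valid.any():
        return 0.0
    # Replace ignored labels with a dummy class so indexing stays in range.
    safe = np.where(valid, labels, 0)
    picked = np.take_along_axis(log_probs, safe[None], axis=0)[0]
    # Only valid pixels contribute to the loss, and hence to the gradient.
    return float(-picked[valid].mean())

logits = np.zeros((3, 2, 2))
labels = np.array([[0, 1], [255, 255]])
print(cross_entropy_ignore(logits, labels))  # log(3) ~ 1.0986
```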
It is clear how to ignore "holes" in BCE, but I don't see how to "ignore these points when constructing the high dimension points".
Those high-dimensional points include almost all (~ -r^2) pixels of the downsampled labels/logits.
Also, we can slice the sample weights in the same manner as the labels.
As a result, we would get sw_vectors here https://github.com/ZJULearning/RMI/blob/master/losses/rmi/rmi.py#L183 corresponding one-to-one to pr_vectors.
But I don't understand how to use them next: what should be multiplied by these weights, and at which step?
@mzhaoshuai, could you suggest an approach?
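A sketch of the slicing idea: the same shifted-window slicing that turns a label map into high-dimensional points can be applied to a weight map, giving weight vectors aligned one-to-one with the label and prediction vectors. `to_high_dim_points` is a hypothetical single-image NumPy helper written for illustration (no batch or class axes); it is not the repo's function.

```python
import numpy as np

def to_high_dim_points(x, radius=3):
    """Stack every radius*radius neighbourhood of a 2-D map into a vector.

    x: (H, W) array. Returns (radius*radius, H - radius + 1, W - radius + 1):
    entry [:, i, j] is the flattened window whose top-left corner is (i, j).
    """
    h, w = x.shape
    new_h, new_w = h - (radius - 1), w - (radius - 1)
    slices = [x[dy:dy + new_h, dx:dx + new_w]
              for dy in range(radius) for dx in range(radius)]
    return np.stack(slices, axis=0)

labels = np.arange(16, dtype=float).reshape(4, 4)
weights = np.ones((4, 4))
weights[:, 0] = 0.0  # left column is a rotation "hole"

la_vectors = to_high_dim_points(labels)   # shape (9, 2, 2)
sw_vectors = to_high_dim_points(weights)  # same shape, same point order
print(la_vectors.shape, sw_vectors.shape)
```

Each point now carries radius*radius weights rather than a single one, which is why it is not obvious what to multiply them by.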
I think you can keep a mask in which the "holes" are 0 and all other areas are 1, then use this mask to select the meaningful points from the output of the function at Line 17 in d71897d:
- You can construct high-dimensional mask points from the mask in the same way as the function above, and keep a point only when its corresponding high-dimensional mask point contains no 0.
- Alternatively, you can apply min pooling of size radius*radius to the mask and use the pooled mask to select the meaningful points. This means every high-dimensional point that contains a 0 ("hole") is ignored. Take care of the shapes: the pooled mask's height and width must match those of the output. See Line 24 in d71897d.
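Both suggestions can be sketched in NumPy for a single 2-D map (no batch or class axes). `to_high_dim_points` is a hypothetical helper mirroring the shifted-window slicing, and `min_pool_valid` is a hypothetical stride-1, no-padding min pooling; in PyTorch the latter is commonly written as `-F.max_pool2d(-mask, kernel_size=radius, stride=1)`. The sketch checks that the two options select the same points.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def to_high_dim_points(x, radius=3):
    """(H, W) -> (radius*radius, H - radius + 1, W - radius + 1) window stack."""
    h, w = x.shape
    new_h, new_w = h - (radius - 1), w - (radius - 1)
    return np.stack([x[dy:dy + new_h, dx:dx + new_w]
                     for dy in range(radius) for dx in range(radius)], axis=0)

def min_pool_valid(mask, radius=3):
    """Min pooling with a radius x radius window, stride 1, no padding."""
    return sliding_window_view(mask, (radius, radius)).min(axis=(-2, -1))

radius = 3
rng = np.random.default_rng(0)
probs = rng.random((6, 6))      # stand-in for one class of predicted probabilities
mask = np.ones((6, 6))
mask[:2, :2] = 0.0              # top-left corner is a rotation "hole"

pr_points = to_high_dim_points(probs, radius)   # (9, 4, 4)

# Option 1: high-dimensional mask points; keep a point only if it has no 0.
keep_1 = to_high_dim_points(mask, radius).min(axis=0) > 0   # (4, 4)

# Option 2: min pooling the mask gives the same (4, 4) keep map directly.
keep_2 = min_pool_valid(mask, radius) > 0

assert np.array_equal(keep_1, keep_2)
selected = pr_points[:, keep_1]   # (9, number_of_kept_points)
print(keep_1.sum(), selected.shape)  # 12 (9, 12)
```

The selected columns (and the label points indexed with the same boolean map) can then replace the full set of points, so hole pixels never enter Cov(Y, P) or Var(Y)/Var(P).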
Feel free to re-open this issue if you still have questions.