davzha / multiset-equivariance



Confused about set element presence masks

stablum opened this issue

Hello,

I have been experimenting with DSPN models for some time. Since they account for variable set sizes, they employ a set element presence mask.

I'm a bit confused by iDSPN's code, as the mask seems to be missing from many components. Does iDSPN actually handle inputs with sets of varying size?

Hi Francesco,

Thanks for the question!

Indeed, as you point out, both DSPN and iDSPN can handle sets of varying sizes by adding a presence variable. Code-wise, it can simply be included as an extra feature in the data, i.e., by concatenating a 0/1 value to each element in the set. Both the 1st and 2nd experiments have fixed set sizes within each sub-experiment, so there is no need to include it there. In the 3rd experiment (CLEVR), the presence variable is concatenated in https://github.com/davzha/multiset-equivariance/blob/main/data.py#L154. I hope that clarifies it.
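As a rough illustration of "concatenating a 0/1 to each element", here is a minimal plain-Python sketch, assuming sets arrive as lists of equal-length feature vectors and are padded with all-zero elements up to a fixed `max_size`. The helper `pad_with_presence` and the zero padding are hypothetical choices for this example; this is not the code in the repository's `data.py`.

```python
from typing import List, Sequence


def pad_with_presence(
    sets: Sequence[Sequence[Sequence[float]]], max_size: int, dim: int
) -> List[List[List[float]]]:
    """Pad every set to max_size elements and append a 0/1 presence feature.

    Hypothetical helper for illustration: real elements get presence 1.0,
    padding elements are all-zero with presence 0.0.
    """
    padded = []
    for elements in sets:
        if len(elements) > max_size:
            raise ValueError(f"set of size {len(elements)} exceeds max_size={max_size}")
        rows = []
        for e in elements:
            if len(e) != dim:
                raise ValueError(f"expected {dim} features, got {len(e)}")
            rows.append(list(e) + [1.0])  # real element: presence = 1
        # padding elements: zero features followed by presence = 0
        rows += [[0.0] * dim + [0.0] for _ in range(max_size - len(elements))]
        padded.append(rows)
    return padded


batch = [
    [[0.2, 0.7], [0.5, 0.1]],  # a set with 2 elements
    [[0.9, 0.3]],              # a set with 1 element
]
for s in pad_with_presence(batch, max_size=3, dim=2):
    print(s)
# [[0.2, 0.7, 1.0], [0.5, 0.1, 1.0], [0.0, 0.0, 0.0]]
# [[0.9, 0.3, 1.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
```

Because the flag is just another feature, every set in the batch has the same size and the network needs no separate mask; it predicts the flag along with the other features.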

Hi David,

thank you so much for your very quick reply! Pardon me, but things are still a bit unclear.
I see that the mask value becomes one of the features of every element in each set.
But I don't see how the mask is used in the loss function. In my understanding, points in a set that are output with a mask feature ≈ 0 shouldn't contribute to the loss, or should they?

They should contribute to the loss: the model has to learn to correctly predict the presence variable as 0 for non-object elements and 1 otherwise. During training, the Hungarian matching assigns each prediction to either a non-object or an object element in the ground truth. It is possible to skip the loss on the original features (i.e., excluding the presence variable) for a prediction that gets matched to a non-object element in the ground truth. This is what DETR does, for example, but we follow the approach of DSPN/TSPN/Slot Attention, which computes the loss on all features of every matched pair. Note, though, that this differs from excluding a predicted element from the loss computation when its predicted presence = 0 (or < threshold). That doesn't work, as the case where the model mistakenly predicts presence = 0 for all elements shows: every element would be excluded, the loss would be zero, and nothing would push the model to correct the mistake.
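To make this concrete, here is a small plain-Python sketch, assuming each element is a feature list whose last entry is the presence variable and using squared error as the per-element cost. All function names are hypothetical. `best_matching` searches over every permutation, which yields the same optimal assignment as the Hungarian algorithm (e.g. `scipy.optimize.linear_sum_assignment`) but only scales to tiny sets. `detr_style_cost` is a simplified stand-in for the DETR-style option (presence-only loss for predictions matched to non-objects; DETR itself uses a classification loss for this), and `naive_masked_loss` is the flawed variant that drops predictions with low predicted presence.

```python
from itertools import permutations
from typing import Callable, Sequence, Tuple

# Hypothetical element layout: features followed by presence (1 = object, 0 = non-object).
Element = Sequence[float]
Cost = Callable[[Element, Element], float]


def squared_error(p: Element, t: Element) -> float:
    """Squared error over every entry, including the presence variable."""
    return sum((x - y) ** 2 for x, y in zip(p, t))


def detr_style_cost(p: Element, t: Element) -> float:
    """Simplified DETR-style cost: for a non-object target, only presence is penalized."""
    if t[-1] == 0.0:
        return (p[-1] - t[-1]) ** 2
    return squared_error(p, t)


def best_matching(pred: Sequence[Element], target: Sequence[Element], cost: Cost) -> Tuple[int, ...]:
    """Exhaustive minimum-cost assignment; perm[i] is the prediction matched to target[i].

    O(n!) and only usable for tiny sets; the Hungarian algorithm finds the
    same optimum in polynomial time.
    """
    n = len(target)
    return min(
        permutations(range(n)),
        key=lambda perm: sum(cost(pred[perm[i]], target[i]) for i in range(n)),
    )


def matched_set_loss(pred: Sequence[Element], target: Sequence[Element], cost: Cost = squared_error) -> float:
    """Mean cost over the optimal matching; presence is always part of the loss."""
    assert len(pred) == len(target), "pred and target must be padded to the same size"
    perm = best_matching(pred, target, cost)
    return sum(cost(pred[perm[i]], target[i]) for i in range(len(target))) / len(target)


def naive_masked_loss(pred: Sequence[Element], target: Sequence[Element], threshold: float = 0.5) -> float:
    """Flawed variant: drops matched pairs whose *predicted* presence is below threshold."""
    perm = best_matching(pred, target, squared_error)
    kept = [
        squared_error(pred[perm[i]], target[i])
        for i in range(len(target))
        if pred[perm[i]][-1] >= threshold
    ]
    return sum(kept) / len(target)


target = [[0.2, 0.7, 1.0], [0.0, 0.0, 0.0]]      # one object, one padded non-object
all_absent = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]  # model wrongly predicts nothing is present
print(matched_set_loss(all_absent, target))      # ≈ 0.765: the missed object is penalized
print(naive_masked_loss(all_absent, target))     # 0.0: the mistake goes unpunished
```

With every presence predicted as 0, the naive variant returns zero loss even though an object was missed, while the matched loss still penalizes it. Passing `cost=detr_style_cost` to `matched_set_loss` stops penalizing the features of predictions matched to non-objects but keeps their presence term.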