RNaD: Possible Error in calculation of NeuRD Loss
spktrm opened this issue
In this line of the RNaD algorithm, should the logits instead be centered as below? That way we subtract only the mean computed over the valid (legal) logits.
```python
# Center the logits using the mean over legal actions only.
logits = logit_pi - (
    jnp.sum(logit_pi * legal_actions, axis=-1, keepdims=True)
    / jnp.sum(legal_actions, axis=-1, keepdims=True))
```
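A minimal sketch of what the first change does. It assumes the existing line subtracts a plain mean over all actions of `logit_pi * legal_actions` (the referenced line itself isn't reproduced here), and `naive_mean_center` / `masked_mean_center` are hypothetical helper names used only for illustration.

```python
import jax.numpy as jnp


def naive_mean_center(logit_pi, legal_actions):
    # Assumed "before" version: mean over ALL actions of the masked logits,
    # so illegal actions still count in the denominator.
    return logit_pi - jnp.mean(logit_pi * legal_actions, axis=-1, keepdims=True)


def masked_mean_center(logit_pi, legal_actions):
    # Proposed version: mean over legal actions only.
    legal_mean = (jnp.sum(logit_pi * legal_actions, axis=-1, keepdims=True)
                  / jnp.sum(legal_actions, axis=-1, keepdims=True))
    return logit_pi - legal_mean


# One state with 4 actions, of which only the first 2 are legal.
logit_pi = jnp.array([[2.0, 4.0, 0.0, 0.0]])
legal_actions = jnp.array([[1.0, 1.0, 0.0, 0.0]])

naive = naive_mean_center(logit_pi, legal_actions)
centered = masked_mean_center(logit_pi, legal_actions)
print(naive)
print(centered)
```

Here the legal logits become 0.5 and 2.5 under the plain mean (which subtracts 6/4 = 1.5) but -1 and 1 under the masked mean (which subtracts 6/2 = 3), so only the latter is actually centered on the legal actions. The values at the illegal positions don't matter, since `legal_actions` masks them out of the loss.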
As a result, should the line below it compute an average over legal actions rather than a sum? i.e.:
```python
# Average the per-action NeuRD term over legal actions instead of summing it.
nerd_loss = jnp.sum(
    legal_actions *
    apply_force_with_threshold(logits, adv_pi, threshold, threshold_center),
    axis=-1) / jnp.sum(legal_actions, axis=-1)
```
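To see what the second change does to the gradients, here is a minimal JAX sketch. `force_with_threshold_sketch` is a hypothetical, simplified stand-in for `apply_force_with_threshold` (assumed here to scale each logit by a stop-gradient, threshold-clipped advantage), and `nerd_loss_per_state` is likewise illustrative; neither is the OpenSpiel code.

```python
import jax
import jax.numpy as jnp


def force_with_threshold_sketch(logits, adv, threshold, center):
    # Hypothetical stand-in for apply_force_with_threshold (not the OpenSpiel
    # source): only push a logit further up/down while it stays within
    # `threshold` of `center`, and block gradients through the force itself.
    can_decrease = logits - center > -threshold
    can_increase = logits - center < threshold
    force = (can_decrease * jnp.minimum(adv, 0.0)
             + can_increase * jnp.maximum(adv, 0.0))
    return logits * jax.lax.stop_gradient(force)


def nerd_loss_per_state(logits, adv, legal_actions, threshold=2.0, average=False):
    # Per-state NeuRD-style term; the sign flip and the renormalization over
    # valid timesteps used in the real loss are omitted here.
    center = jnp.zeros_like(logits)
    per_action = legal_actions * force_with_threshold_sketch(
        logits, adv, threshold, center)
    total = jnp.sum(per_action, axis=-1)
    if average:
        # Proposed variant: divide by the number of legal actions per state.
        return total / jnp.sum(legal_actions, axis=-1)
    return total


# Two states: the first has 2 legal actions, the second has 4.
logits = jnp.zeros((2, 4))
adv = jnp.array([[1.0, -1.0, 0.0, 0.0], [1.0, -1.0, 1.0, -1.0]])
legal_actions = jnp.array([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]])

g_sum = jax.grad(
    lambda z: jnp.sum(nerd_loss_per_state(z, adv, legal_actions)))(logits)
g_avg = jax.grad(
    lambda z: jnp.sum(nerd_loss_per_state(z, adv, legal_actions, average=True)))(logits)

# Total per-state gradient magnitude under each variant.
print(jnp.sum(jnp.abs(g_sum), axis=-1))
print(jnp.sum(jnp.abs(g_avg), axis=-1))
```

With the summed loss, each state's total gradient magnitude grows with the number of legal actions carrying a non-zero advantage (2 vs. 4 here); with the average, both states get a total of 1. When every state has the same number of legal actions k, the average is just the sum scaled by 1/k, which under plain SGD amounts to a smaller learning rate; the two variants only really diverge when k varies from state to state.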
This is particularly relevant in games where a large number of actions are frequently invalid.
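Assuming the current code subtracts the plain mean of the masked logits over all $n$ actions, the size of the error can be written out. Let $z_a$ be the logit (`logit_pi`) for action $a$, $m_a \in \{0, 1\}$ the legal-action mask, $k = \sum_a m_a$ the number of legal actions, and $\bar z_{\text{legal}} = \frac{1}{k}\sum_a m_a z_a$. Then

$$
\frac{1}{n}\sum_a m_a z_a = \frac{k}{n}\,\bar z_{\text{legal}},
\qquad
z_a - \frac{k}{n}\,\bar z_{\text{legal}} = \left(z_a - \bar z_{\text{legal}}\right) + \left(1 - \frac{k}{n}\right)\bar z_{\text{legal}},
$$

so each legal logit keeps a residual offset of $(1 - k/n)\,\bar z_{\text{legal}}$. The offset vanishes only when every action is legal ($k = n$) or the legal logits already average to zero, and it approaches the full $\bar z_{\text{legal}}$ as the fraction of legal actions shrinks.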
Hi @spktrm, I spoke to Julien.
He said you're correct about the first one. Can you submit a PR?
The second one could go either way; it's just a matter of finding out what works. It's not clear whether one works better than the other, and they might end up behaving similarly but require different hyperparameters. Maybe you can try it and let us know?
I have submitted a PR for the first point: #1157. Thank you for the opportunity to contribute :)
As for the second point, I will experiment further with the change I suggested and let you know how it goes.
Meanwhile, is it possible to provide clarity on these other issues? Namely:
Hi @spktrm,
Yeah, I will make Julien aware of those (sorry, I thought they had already been resolved).
I think it may also be useful to contact him directly by email, though, since I'm mostly just relaying messages between here and him :)
Thank you. What is his best email?
It's still the same one as in the Mastering Stratego paper.