luchris429 / purejaxrl

Really Fast End-to-End Jax RL Implementations

RNNs hidden resets

esraaelelimy opened this issue

In the rnn_ppo implementations, the RNN uses the done signal at time t to reset the hidden state. Shouldn't it use the done signal at time t-1 instead?
As I understand it, the hidden state should be reset at the beginning of an episode, and to know whether an observation o_t is the start of an episode we should check done_{t-1}, not done_t.

Actually, I think the implementation already does that; it just wasn't clear at first. In the Gymnax environment implementations, when an episode terminates, the returned observation is the first observation of the new episode, not the terminal observation. So at the terminal step T we get (Observation_0, done_T, ...) rather than (Observation_T, done_T, ...). Because the done flag comes back together with the first observation of the new episode, using the currently returned done signal to reset the hidden state makes sense.
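To make the indexing concrete, here is a minimal sketch of that pairing. It is not the purejaxrl code: `toy_env_step`, `rnn_cell`, `rollout_step`, and `EPISODE_LEN` are hypothetical names, the environment only mimics the auto-reset behaviour described above, and the "RNN" is just a counter of how many observations it has consumed since its last reset.

```python
import jax
import jax.numpy as jnp

EPISODE_LEN = 3  # hypothetical toy episode length


def toy_env_step(t):
    """Toy auto-resetting env (an assumption, not Gymnax itself).

    `t` is the in-episode index of the current observation. Returns the
    index of the next observation and the done flag for this transition.
    """
    next_t = t + 1
    done = next_t >= EPISODE_LEN
    # On termination, hand back the first observation of the next episode
    # instead of the terminal one.
    next_t = jnp.where(done, 0, next_t)
    return next_t, done


def rnn_cell(h, reset):
    """Stand-in for an RNN: counts observations seen since the last reset."""
    h = jnp.where(reset, jnp.zeros_like(h), h)  # reset before consuming the obs
    return h + 1


def rollout_step(carry, _):
    h, obs_t, last_done = carry
    # `last_done` was returned by the same env step that produced `obs_t`.
    h = rnn_cell(h, last_done)
    record = (obs_t, last_done, h)
    next_obs_t, done = toy_env_step(obs_t)
    return (h, next_obs_t, done), record


init = (jnp.int32(0), jnp.int32(0), jnp.bool_(False))
_, (obs_idx, dones, hs) = jax.lax.scan(rollout_step, init, None, length=7)

print(obs_idx.tolist())  # in-episode index of each observation fed to the RNN
print(dones.tolist())    # done flag paired with that observation
print(hs.tolist())       # hidden counter after consuming that observation
```

In this sketch the counter always equals the in-episode index plus one, so the hidden state restarts exactly on the first observation of each new episode, even though the reset uses the done flag returned alongside that observation rather than a separately shifted done_{t-1}.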