Kaixhin / PlaNet

Deep Planning Network: Control from pixels by latent planning with learned dynamics

Potential error in the latent overshooting objective

jan1854 opened this issue

Hi, I think there might be a small error in the implementation of the latent overshooting objective. In the following line, the prior state (overshooting_vars[4]) is passed as the initial state to the transition model.

PlaNet/main.py

Line 187 in dacf418

beliefs, prior_states, prior_means, prior_std_devs = transition_model(torch.cat(overshooting_vars[4], dim=0), torch.cat(overshooting_vars[0], dim=1), torch.cat(overshooting_vars[3], dim=0), None, torch.cat(overshooting_vars[1], dim=1))

If I understand the code correctly, this prior state corresponds to s_{t-d} in equation (7) of the PlaNet paper. However, in the paper s_{t-d} is sampled from the posterior distribution q(s_{t-d} | o_{≤t-d}), not the prior.
The original implementation seems to use the posterior as initial state as well (see https://github.com/google-research/planet/blob/c04226b6db136f5269625378cd6a0aa875a92842/planet/tools/overshooting.py#L126-L134).

So I think the posterior rather than the prior should be used as initial state here.
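
To make the distinction concrete, here is a minimal sketch on a toy 1-D Gaussian model. It is not the repository's code: every name in it (prior_step, posterior_step, filter_sequence, overshoot_means) is hypothetical, and the linear transition and precision-weighted update are stand-ins I am assuming in place of the learned networks. The only point is that the multi-step rollout should start from a sample of the filtering posterior, which has seen the observations up to that step, rather than from the prior sample.

```python
import random


def prior_step(state, action, a=0.9, std=0.5):
    """Toy transition p(s_t | s_{t-1}, a_{t-1}); returns (mean, std). Assumed linear."""
    return a * state + action, std


def posterior_step(prior_mean, prior_std, obs, obs_std=0.2):
    """Toy filtering posterior q(s_t | o_{<=t}): precision-weighted Gaussian update."""
    prior_prec = 1.0 / prior_std ** 2
    obs_prec = 1.0 / obs_std ** 2
    post_var = 1.0 / (prior_prec + obs_prec)
    post_mean = post_var * (prior_prec * prior_mean + obs_prec * obs)
    return post_mean, post_var ** 0.5


def filter_sequence(observations, actions, rng):
    """One-step filtering pass; returns prior and posterior samples for every step."""
    state = 0.0
    prior_samples, posterior_samples = [], []
    for obs, act in zip(observations, actions):
        p_mean, p_std = prior_step(state, act)
        prior_samples.append(rng.gauss(p_mean, p_std))
        q_mean, q_std = posterior_step(p_mean, p_std, obs)
        state = rng.gauss(q_mean, q_std)  # the filter itself continues from the posterior
        posterior_samples.append(state)
    return prior_samples, posterior_samples


def overshoot_means(start_state, actions):
    """Multi-step prior means rolled out from start_state without using observations."""
    means, state = [], start_state
    for act in actions:
        state, _ = prior_step(state, act)
        means.append(state)
    return means


rng = random.Random(0)
observations = [1.0, 1.2, 0.8, 1.1]
actions = [0.1, 0.0, -0.1, 0.05]
prior_samples, posterior_samples = filter_sequence(observations, actions, rng)

t = 1  # hypothetical index of the step the overshooting rollout starts from
from_posterior = overshoot_means(posterior_samples[t], actions[t + 1:])  # as in eq. (7)
from_prior = overshoot_means(prior_samples[t], actions[t + 1:])  # what the line above does
```

In this sketch from_posterior is the rollout that matches the paper's objective, while from_prior corresponds to what I believe the current line computes.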

I think you're right, thanks for spotting. I won't be able to work on this now, so would you be able to send in a PR that fixes this?

Yes, sure