danijar / dreamerv2

Mastering Atari with Discrete World Models

Home Page: https://danijar.com/dreamerv2

Why stop-grad on the actor's input state in the imagine() function?

tominku opened this issue

Hi,

While taking a close look at the imagine() function in the world model,
I started wondering why the gradient from the input feature to the actor is stopped.

WorldModel's imagine function (agent.py):

def imagine(self, policy, start, is_terminal, horizon):
  flatten = lambda x: x.reshape([-1] + list(x.shape[2:]))
  start = {k: flatten(v) for k, v in start.items()}
  start['feat'] = self.rssm.get_feat(start)
  start['action'] = tf.zeros_like(policy(start['feat']).mode())
  seq = {k: [v] for k, v in start.items()}
  for _ in range(horizon):
    action = policy(tf.stop_gradient(seq['feat'][-1])).sample()

In my opinion, for the full gradient to flow from the initial state to the last step of the sequence, shouldn't 'feat' pass through the computation graph without the stop gradient? I just wonder why the stop gradient is there. Have you tried the code without it? What was the result like?

I'm struggling to figure out the reason for the stop gradient, so I'm asking here for help.
Thanks!
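
To make the question concrete, here is a minimal, self-contained sketch of the effect being asked about. This is not the repo's code: 'cell' and 'policy' below are stand-in dense layers rather than DreamerV2's RSSM and actor, and the loss is just a sum of actions. It shows how tf.stop_gradient on the policy input cuts every backward path from the actor's outputs to earlier imagined states.

import tensorflow as tf

# Stand-ins for the transition model and the actor (not the repo's classes).
cell = tf.keras.layers.Dense(4)
policy = tf.keras.layers.Dense(2)

start = tf.random.normal([1, 4])
with tf.GradientTape() as tape:
  tape.watch(start)
  feat, loss = start, 0.0
  for _ in range(3):
    # stop_gradient makes the actor see 'feat' as a constant, so the
    # loss on the actions cannot backpropagate into earlier states.
    action = policy(tf.stop_gradient(feat))
    feat = cell(tf.concat([feat, action], -1))
    loss += tf.reduce_sum(action)

# Prints None: there is no differentiable path from the summed actions
# back to 'start'. Dropping the stop_gradient yields a nonzero gradient,
# because each action then backpropagates through every earlier state.
print(tape.gradient(loss, start))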

This just seemed like the safer choice. In practice, it didn't seem to make a difference the last time I tried it, a few years ago :)
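
For reference, the variant the question asks about would simply drop the tf.stop_gradient call (a hypothetical one-line change, not something the repo ships), letting actor gradients flow back through the full imagined sequence:

action = policy(seq['feat'][-1]).sample()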