danijar / dreamerv2

Mastering Atari with Discrete World Models

Home Page: https://danijar.com/dreamerv2


Question about Plan2Explore

TachikakaMin opened this issue

For Plan2Explore, the Plan2Explore class in expl.py holds a world model.

class Plan2Explore(common.Module):

  def __init__(self, config, act_space, wm, tfstep, reward):
    self.config = config
    self.reward = reward
    self.wm = wm  # world model instance passed in from the Agent

And this wm is the same WorldModel that the DreamerV2 agent itself uses:

class Agent(common.Module):

  def __init__(self, config, obs_space, act_space, step):
    self.config = config
    self.obs_space = obs_space
    self.act_space = act_space['action']
    self.step = step
    self.tfstep = tf.Variable(int(self.step), tf.int64)
    self.wm = WorldModel(config, obs_space, self.tfstep)
    self._task_behavior = ActorCritic(config, self.act_space, self.tfstep)
    if config.expl_behavior == 'greedy':
      self._expl_behavior = self._task_behavior
    else:
      # Exploration behaviors such as Plan2Explore receive the same self.wm,
      # plus a reward function based on the world model's reward head.
      self._expl_behavior = getattr(expl, config.expl_behavior)(
          self.config, self.act_space, self.wm, self.tfstep,
          lambda seq: self.wm.heads['reward'](seq['feat']).mode())
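For reference, here is a minimal self-contained sketch, with hypothetical stand-in classes rather than the repository code, of the same getattr-style dispatch; it only shows that the exploration behavior ends up holding the very same world model object as the agent:

class FakeWorldModel:
    pass

class FakePlan2Explore:
    def __init__(self, wm):
        self.wm = wm  # stored by reference, not copied

# getattr(expl, 'Plan2Explore') resolves the class by name in much the same
# way this hypothetical lookup table does.
behaviors = {'Plan2Explore': FakePlan2Explore}
wm = FakeWorldModel()
expl_behavior = behaviors['Plan2Explore'](wm)
assert expl_behavior.wm is wm  # one shared world model instance, trained once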

For world model training, the code seems to encode all of the information, including the reward, via the encoder:

def loss(self, data, state=None):
    data = self.preprocess(data)
    embed = self.encoder(data)

def preprocess(self, obs):
    dtype = prec.global_policy().compute_dtype
    obs = obs.copy()
    for key, value in obs.items():
      if key.startswith('log_'):
        continue
      if value.dtype == tf.int32:
        value = value.astype(dtype)
      if value.dtype == tf.uint8:
        value = value.astype(dtype) / 255.0 - 0.5
      obs[key] = value
    obs['reward'] = {
        'identity': tf.identity,
        'sign': tf.sign,
        'tanh': tf.tanh,
    }[self.config.clip_rewards](obs['reward'])
    obs['discount'] = 1.0 - obs['is_terminal'].astype(dtype)
    obs['discount'] *= self.config.discount
    return obs

class Encoder(common.Module):
  def _cnn(self, data):
    # Concatenate the selected image observations along the channel axis.
    x = tf.concat(list(data.values()), -1)
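As a side note, here is a small sketch of what the clip_rewards dispatch above does to the reward entry, using a made-up reward value and numpy in place of TensorFlow:

import numpy as np

clip_rewards = 'tanh'  # one of the three configurable options
reward = np.float32(3.0)  # hypothetical raw environment reward
reward = {
    'identity': lambda r: r,
    'sign': np.sign,
    'tanh': np.tanh,
}[clip_rewards](reward)
print(reward)  # ~0.995; the transformed reward is written back into the obs dict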

But the Plan2Explore paper says there should be no environment reward.

The encoder line you're showing is only applied to image observations, not to the reward. You can choose which observation keys are used via these config options:

encoder.mlp_keys: '.*'
encoder.cnn_keys: '.*'
decoder.mlp_keys: '.*'
decoder.cnn_keys: '.*'

But only images (rank-3 tensors) and vectors (rank-1 tensors) are supported, and the reward is a scalar (rank-0 tensor), so it never reaches the encoder.
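For illustration, here is a minimal sketch, not the repository code, of how regex key patterns combined with a rank check keep the scalar reward out of the encoder input; the observation names and shapes below are made up:

import re
import numpy as np

# Hypothetical Atari-style observation dict.
obs = {
    'image': np.zeros((64, 64, 3), np.float32),  # rank 3 -> candidate CNN input
    'ram': np.zeros((128,), np.float32),         # rank 1 -> candidate MLP input
    'reward': np.float32(1.0),                   # rank 0 -> never encoded
}

cnn_keys, mlp_keys = r'.*', r'.*'  # the default patterns shown above
cnn_inputs = {k: v for k, v in obs.items()
              if re.match(cnn_keys, k) and np.ndim(v) == 3}
mlp_inputs = {k: v for k, v in obs.items()
              if re.match(mlp_keys, k) and np.ndim(v) == 1}

print(sorted(cnn_inputs), sorted(mlp_inputs))  # ['image'] ['ram'] -- 'reward' is filtered out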