danijar / dreamerv2

Mastering Atari with Discrete World Models

Home Page: https://danijar.com/dreamerv2

Outdated dependencies and broken examples

Namnodorel opened this issue

I wanted to try out dreamerv2 on our own environment (or at least on the examples), but unfortunately ran into some issues along the way.

The README example and the Dockerfile use outdated versions of TensorFlow (and other libraries); in some cases pip no longer even distributes those older versions.

I attempted to run the minigrid example with a recent TensorFlow version.

In nets.py, the import "from tensorflow.keras.mixed_precision import experimental as prec" fails, because the mixed-precision API is no longer available under experimental.
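A minimal compatibility sketch, assuming the code only needs Policy, set_policy and global_policy from the old experimental module. In the newer API, tensorflow.keras.mixed_precision exposes Policy, set_global_policy and global_policy directly. The helper load_precision_api is hypothetical and not part of dreamerv2; it simply picks whichever layout the installed TensorFlow provides.

```python
from types import SimpleNamespace


def load_precision_api(mixed_precision):
    """Return an object exposing Policy, set_policy and global_policy.

    `mixed_precision` is the tensorflow.keras.mixed_precision module.
    Newer TensorFlow releases provide set_global_policy at the top level;
    older ones only provide set_policy under the experimental submodule.
    """
    if hasattr(mixed_precision, "set_global_policy"):
        # Non-experimental API: map set_global_policy onto the old name.
        return SimpleNamespace(
            Policy=mixed_precision.Policy,
            set_policy=mixed_precision.set_global_policy,
            global_policy=mixed_precision.global_policy,
        )
    # Fall back to the old experimental API.
    exp = mixed_precision.experimental
    return SimpleNamespace(
        Policy=exp.Policy,
        set_policy=exp.set_policy,
        global_policy=exp.global_policy,
    )


# Hypothetical replacement for the old import:
#   from tensorflow.keras import mixed_precision
#   prec = load_precision_api(mixed_precision)
#   prec.set_policy(prec.Policy("mixed_float16"))
```

Keeping the old attribute names means the rest of the code (calls such as prec.global_policy()) would not need to change.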

MiniGrid has moved to the Farama Foundation and now uses Gymnasium: https://github.com/Farama-Foundation/Minigrid
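Gymnasium changed the environment interface. Its reset() returns (obs, info) and step() returns (obs, reward, terminated, truncated, info), whereas the older Gym API returned only obs from reset() and a four-tuple (obs, reward, done, info) from step(). Assuming dreamerv2's gym wrapper still expects the older interface, a hypothetical adapter like this one could sit between a Gymnasium MiniGrid env and that wrapper. It is a sketch, not part of dreamerv2.

```python
class OldGymAPI:
    """Hypothetical adapter exposing a Gymnasium env through the old Gym API."""

    def __init__(self, env):
        self._env = env

    def __getattr__(self, name):
        # Forward public attributes (observation_space, action_space, ...).
        if name.startswith("_"):
            raise AttributeError(name)
        return getattr(self._env, name)

    def reset(self, **kwargs):
        obs, _info = self._env.reset(**kwargs)  # Gymnasium returns (obs, info)
        return obs

    def step(self, action):
        obs, reward, terminated, truncated, info = self._env.step(action)
        # The old API had a single done flag; record truncation in info.
        info = dict(info)
        info["TimeLimit.truncated"] = truncated and not terminated
        return obs, reward, terminated or truncated, info
```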

Using the new MiniGrid environment raises the following error:

  ...
  File ".../dreamerv2/api.py", line 77, in train
    env = common.ResizeImage(env)
  File ".../dreamerv2/common/envs.py", line 447, in __init__
    self._keys = [
  File ".../dreamerv2/common/envs.py", line 449, in <listcomp>
    if len(v.shape) > 1 and v.shape[:2] != size]
TypeError: object of type 'NoneType' has no len()
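The failing list comprehension calls len(v.shape) on every entry of the observation space. In Gymnasium, the shape of a non-array space can be None; MiniGrid's text "mission" entry is presumably such a space. A hedged sketch of a guard, written as a hypothetical helper image_keys with stand-in space objects rather than the real dreamerv2 classes:

```python
from collections import namedtuple

# Stand-in for a space object; only the .shape attribute matters here.
FakeSpace = namedtuple("FakeSpace", "shape")


def image_keys(obs_space, size):
    """Keys whose spaces are images (2+ dims) not already at `size`.

    Spaces without a shape (shape is None), such as text spaces, are
    skipped instead of raising TypeError on len(None).
    """
    return [
        k for k, v in obs_space.items()
        if v.shape is not None and len(v.shape) > 1 and v.shape[:2] != size
    ]
```

The same "v.shape is not None" check could presumably be added to the condition in ResizeImage.__init__.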

Calling train.py with our own environment (whose observation is a 42x30 intensity image stored in a NumPy array) collects a prefill dataset but then fails with this error:

  File ".../dreamerv2/api.py", line 101, in train
    train_agent(next(dataset))
  File ".../dreamerv2/common/other.py", line 201, in __call__
    self._state, out = self._fn(*args, self._state)
  File ".../tensorflow/python/util/traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File ".../dreamerv2/agent.py", line 60, in train
    state, outputs, mets = self.wm.train(data, state)
  File ".../dreamerv2/agent.py", line 100, in train
    model_loss, state, outputs, metrics = self.loss(data, state)
  File ".../dreamerv2/agent.py", line 108, in loss
    post, prior = self.rssm.observe(
  File ".../dreamerv2/common/nets.py", line 50, in observe
    post, prior = common.static_scan(
  File ".../dreamerv2/common/other.py", line 41, in static_scan
    last = fn(last, inp)
  File ".../dreamerv2/common/nets.py", line 51, in <lambda>
    lambda prev, inputs: self.obs_step(prev[0], *inputs),
  File ".../dreamerv2/common/nets.py", line 96, in obs_step
    prior = self.img_step(prev_state, prev_action, sample)
  File ".../dreamerv2/common/nets.py", line 124, in img_step
    dist = self.get_dist(stats)
  File ".../dreamerv2/common/nets.py", line 81, in get_dist
    dist = tfd.Independent(common.OneHotDist(logit), 1)
  File ".../decorator.py", line 232, in fun
    return caller(func, *(extras + args), **kw)
  File ".../tensorflow_probability/python/distributions/distribution.py", line 342, in wrapped_init
    default_init(self_, *args, **kwargs)
  File ".../tensorflow_probability/python/distributions/independent.py", line 162, in __init__
    super(_Independent, self).__init__(
  File ".../tensorflow_probability/python/distributions/distribution.py", line 603, in __init__
    d for d in self._parameter_control_dependencies(is_init=True)
  File ".../tensorflow_probability/python/distributions/independent.py", line 337, in _parameter_control_dependencies
    raise ValueError('reinterpreted_batch_ndims({}) cannot exceed '
ValueError: reinterpreted_batch_ndims(1) cannot exceed distribution.batch_ndims(0)

I couldn't figure out what exactly the problem is, or whether it is caused by the updated dependency versions, by our environment, or by something else entirely.
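One thing that may be worth ruling out: a 42x30 intensity image is a 2-D array, while convolutional image encoders usually expect a trailing channel axis, i.e. (height, width, channels). Assuming dreamerv2's encoder only treats 3-D observations as images, a 2-D frame could end up ignored or mis-shaped. This is an unverified guess, not a confirmed cause. A minimal sketch of adding the axis (add_channel_axis is a hypothetical helper; the environment's declared observation space would need the matching (42, 30, 1) shape as well):

```python
import numpy as np


def add_channel_axis(obs):
    """Return an (H, W) image as (H, W, 1); leave other arrays unchanged."""
    obs = np.asarray(obs)
    if obs.ndim == 2:
        obs = obs[..., np.newaxis]  # single grayscale channel
    return obs
```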

Could you update the dependencies and examples so that it is possible to try out dreamerv2 again?

I found that the code is not compatible with tensorflow==2.6.0. One workaround is to install older versions:

pip install tensorflow==2.4.2 tensorflow_probability==0.12.2

and to make sure your Python version is 3.8 or lower.
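A small pre-flight check along these lines could catch a mismatched setup before training starts. It is a sketch that hard-codes the versions from the pip command above and compares only major and minor numbers; check_versions and major_minor are hypothetical helpers.

```python
import sys


def major_minor(version):
    """Parse a version string like '2.4.2' into (2, 4); the patch part is ignored."""
    major, minor = version.split(".")[:2]
    return int(major), int(minor)


def check_versions(tensorflow, tfp, python=sys.version_info):
    """Return a list of mismatches against the pinned setup (empty if all match)."""
    problems = []
    if tuple(python[:2]) > (3, 8):
        problems.append(f"Python {python[0]}.{python[1]} is newer than 3.8")
    if major_minor(tensorflow) != (2, 4):
        problems.append(f"tensorflow {tensorflow} is not 2.4.x")
    if major_minor(tfp) != (0, 12):
        problems.append(f"tensorflow_probability {tfp} is not 0.12.x")
    return problems


# Usage:
#   import tensorflow as tf
#   import tensorflow_probability as tfp
#   for problem in check_versions(tf.__version__, tfp.__version__):
#       print("warning:", problem)
```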

Also be careful with the CUDA installation: if your CUDA version is 11.0 or newer, create a symbolic link named libcusolver.so.10 in the directory where TensorFlow is installed. For example, for CUDA 11.x.x:

ln -s /usr/local/your-cuda/lib64/libcusolver.so.11 ~/path-to-your-python/site-packages/tensorflow/libcusolver.so.10
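The same link could also be created from Python, which avoids typing the site-packages path by hand when TensorFlow itself can be imported. A sketch under that assumption; link_cusolver is a hypothetical helper and both paths are placeholders, as in the ln command:

```python
import os


def link_cusolver(cuda_lib_dir, tf_dir,
                  have="libcusolver.so.11", want="libcusolver.so.10"):
    """Create tf_dir/want as a symlink to cuda_lib_dir/have; return the link path."""
    target = os.path.join(cuda_lib_dir, have)
    link = os.path.join(tf_dir, want)
    if not os.path.exists(target):
        raise FileNotFoundError(target)
    # lexists is also True for an existing (even broken) link, so reruns are no-ops.
    if not os.path.lexists(link):
        os.symlink(target, link)
    return link


# Usage (placeholder CUDA path):
#   import tensorflow as tf
#   link_cusolver("/usr/local/your-cuda/lib64", os.path.dirname(tf.__file__))
```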