danijar / dreamerv2

Mastering Atari with Discrete World Models

Home Page: https://danijar.com/dreamerv2

ValueError: . Tensor must have rank 4. Received rank 3, shape (208, 64, 64)

ipsec opened this issue

Hi,

Congrats on the excellent project!

I'm using a custom Gym environment whose observations are images of shape [21, 4] (there is no channel dimension because they are grayscale), and I'm getting the following error:
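The shape in the error follows from the missing channel axis. A minimal NumPy sketch, assuming the episode's frames end up as 64x64 uint8 arrays (as the reported shape (208, 64, 64) suggests), shows how stacking them gives a rank-3 array, and how appending an axis of size 1 gives the rank-4 layout [k, height, width, channels] that tf.summary.image checks for:

```python
import numpy as np

# A hypothetical episode of 208 grayscale frames, already resized to 64x64
# (the shape reported in the traceback). Each frame has no channel axis.
frames = [np.zeros((64, 64), dtype=np.uint8) for _ in range(208)]
episode = np.stack(frames)
print(episode.shape, episode.ndim)  # (208, 64, 64) 3

# tf.summary.image expects rank 4: [k, height, width, channels].
# Appending a trailing channel axis of size 1 gives a valid grayscale layout.
episode_with_channel = episode[..., np.newaxis]
print(episode_with_channel.shape, episode_with_channel.ndim)  # (208, 64, 64, 1) 4
```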

Episode has 207 steps and return -19.3.
[2679] return -19.31 / length 207 / total_steps 2679 / total_episodes 13 / loaded_steps 2691 / loaded_episodes 13
Traceback (most recent call last):
  File "/usr/local/Caskroom/miniconda/base/envs/ml/lib/python3.9/site-packages/tensorflow/python/ops/check_ops.py", line 1231, in assert_rank
    assert_op = _assert_rank_condition(x, rank, static_condition,
  File "/usr/local/Caskroom/miniconda/base/envs/ml/lib/python3.9/site-packages/tensorflow/python/ops/check_ops.py", line 1131, in _assert_rank_condition
    raise ValueError(
ValueError: ('Static rank condition failed', 3, 4)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/Fernando/dev/dreamerv2/examples/test.py", line 16, in <module>
    dv2.train(env, config)
  File "/usr/local/Caskroom/miniconda/base/envs/ml/lib/python3.9/site-packages/dreamerv2/api.py", line 94, in train
    driver(random_agent, steps=prefill, episodes=1)
  File "/usr/local/Caskroom/miniconda/base/envs/ml/lib/python3.9/site-packages/dreamerv2/common/driver.py", line 57, in __call__
    [fn(ep, **self._kwargs) for fn in self._on_episodes]
  File "/usr/local/Caskroom/miniconda/base/envs/ml/lib/python3.9/site-packages/dreamerv2/common/driver.py", line 57, in <listcomp>
    [fn(ep, **self._kwargs) for fn in self._on_episodes]
  File "/usr/local/Caskroom/miniconda/base/envs/ml/lib/python3.9/site-packages/dreamerv2/api.py", line 74, in per_episode
    logger.write()
  File "/usr/local/Caskroom/miniconda/base/envs/ml/lib/python3.9/site-packages/dreamerv2/common/logger.py", line 44, in write
    output(self._metrics)
  File "/usr/local/Caskroom/miniconda/base/envs/ml/lib/python3.9/site-packages/dreamerv2/common/logger.py", line 117, in __call__
    tf.summary.image(name, value, step)
  File "/usr/local/Caskroom/miniconda/base/envs/ml/lib/python3.9/site-packages/tensorboard/plugins/image/summary_v2.py", line 140, in image
    return tf.summary.write(
  File "/usr/local/Caskroom/miniconda/base/envs/ml/lib/python3.9/site-packages/tensorflow/python/ops/summary_ops_v2.py", line 762, in write
    op = smart_cond.smart_cond(
  File "/usr/local/Caskroom/miniconda/base/envs/ml/lib/python3.9/site-packages/tensorflow/python/framework/smart_cond.py", line 56, in smart_cond
    return true_fn()
  File "/usr/local/Caskroom/miniconda/base/envs/ml/lib/python3.9/site-packages/tensorflow/python/ops/summary_ops_v2.py", line 750, in record
    summary_tensor = tensor() if callable(tensor) else array_ops.identity(
  File "/usr/local/Caskroom/miniconda/base/envs/ml/lib/python3.9/site-packages/tensorboard/util/lazy_tensor_creator.py", line 66, in __call__
    self._tensor = self._tensor_callable()
  File "/usr/local/Caskroom/miniconda/base/envs/ml/lib/python3.9/site-packages/tensorboard/plugins/image/summary_v2.py", line 112, in lazy_tensor
    tf.debugging.assert_rank(data, 4)
  File "/usr/local/Caskroom/miniconda/base/envs/ml/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py", line 206, in wrapper
    return target(*args, **kwargs)
  File "/usr/local/Caskroom/miniconda/base/envs/ml/lib/python3.9/site-packages/tensorflow/python/ops/check_ops.py", line 1178, in assert_rank_v2
    return assert_rank(x=x, rank=rank, message=message, name=name)
  File "/usr/local/Caskroom/miniconda/base/envs/ml/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py", line 206, in wrapper
    return target(*args, **kwargs)
  File "/usr/local/Caskroom/miniconda/base/envs/ml/lib/python3.9/site-packages/tensorflow/python/ops/check_ops.py", line 1236, in assert_rank
    raise ValueError(
ValueError: .  Tensor  must have rank 4.  Received rank 3, shape (208, 64, 64)
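Reading the traceback, the logger's output `__call__` (logger.py, line 117) hands the rank-3 episode array to `tf.summary.image`, whose `assert_rank(data, 4)` then fails. The sketch below is a hypothetical rank-based dispatch that illustrates this; the function name `choose_summary` and the rank-4 "video" branch are illustrative assumptions, not dreamerv2's actual code:

```python
import numpy as np

def choose_summary(value):
    """Hypothetical rank-based dispatch, loosely modeled on the traceback.

    Returns the name of the summary a logger might write for an array.
    """
    value = np.asarray(value)
    if value.ndim == 0:
        return "scalar"
    if value.ndim == 3:
        # Treated as a single image (height, width, channels). In the
        # traceback, this path calls tf.summary.image, which requires rank 4.
        return "image"
    if value.ndim == 4:
        # Assumed: treated as a sequence of frames (time, height, width, channels).
        return "video"
    return "unsupported"

# 208 grayscale frames without a channel axis are misread as one image.
print(choose_summary(np.zeros((208, 64, 64))))     # image
print(choose_summary(np.zeros((208, 64, 64, 1))))  # video
```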

How can I fix this?

My test file is:

import gym
import dreamerv2.api as dv2

config = dv2.defaults.update({
    'logdir': '~/logdir/test',
    'log_every': 1e3,
    'train_every': 10,
    'prefill': 1e5,
    'actor_ent': 3e-3,
    'loss_scales.kl': 1.0,
    'discount': 0.99,
}).parse_flags()


env = gym.make('Test-v0')  # custom grayscale environment, observations of shape [21, 4]
dv2.train(env, config)

Thanks! Could you just add a channel dimension of size 1? Otherwise you can do it in the preprocess() function of the agent.
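One way to follow the first suggestion is to append the channel axis in the environment itself, before observations reach dreamerv2. This is a minimal sketch, assuming observations are NumPy arrays of shape [21, 4] and the observation space is a `gym.spaces.Box`; the helper `add_channel_dim`, the `image` key, and the `AddChannel` wrapper are hypothetical names, and the sketch has not been checked against how dreamerv2 resizes images internally:

```python
import numpy as np

def add_channel_dim(obs, key="image"):
    """Append a size-1 channel axis to a grayscale observation.

    Accepts a bare (height, width) array, or a dict observation whose image
    is stored under `key` (the key name is an assumption).
    """
    if isinstance(obs, dict):
        obs = dict(obs)  # shallow copy so the caller's dict is not mutated
        obs[key] = add_channel_dim(obs[key])
        return obs
    obs = np.asarray(obs)
    if obs.ndim == 2:
        obs = obs[..., np.newaxis]  # (H, W) -> (H, W, 1)
    return obs

# Hypothetical hookup in a Gym observation wrapper (not executed here):
#
#   class AddChannel(gym.ObservationWrapper):
#       def __init__(self, env):
#           super().__init__(env)
#           low = self.observation_space.low[..., None]
#           high = self.observation_space.high[..., None]
#           self.observation_space = gym.spaces.Box(
#               low, high, dtype=self.observation_space.dtype)
#
#       def observation(self, observation):
#           return add_channel_dim(observation)
#
#   env = AddChannel(gym.make('Test-v0'))

frame = np.zeros((21, 4), dtype=np.uint8)
print(add_channel_dim(frame).shape)  # (21, 4, 1)
```

Keeping `observation_space` in sync with the returned arrays follows the general Gym convention, since downstream code may read the declared shape rather than inspect the observations themselves.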