google-deepmind / acme

A library of reinforcement learning components and agents

run_mogmpo.py -- ValueError: Graph parent item 0 is not a Tensor;

dyth opened this issue · comments

Hi there,

Packages were installed as follows:

conda create -n acme python=3.8
conda activate acme
pip install --upgrade pip setuptools wheel
pip install .[jax,tf,testing,envs]
pip install mujoco
pip install mujoco-py
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install tensorflow-gpu
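
For reference, the jax / jaxlib / tensorflow-probability versions that ended up installed can be checked with a one-liner like the following (standard package version attributes, nothing acme-specific):

python -c "import jax, jaxlib, tensorflow_probability as tfp; print(jax.__version__, jaxlib.__version__, tfp.__version__)"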

Note that python run_sac.py runs without problems, but python run_mogmpo.py --env_name=gym:HalfCheetah-v4 fails with:

/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/tensorflow_probability/python/__init__.py:57: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
  if (distutils.version.LooseVersion(tf.__version__) <
I0113 18:54:41.923464 140485061584704 xla_bridge.py:355] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
I0113 18:54:41.976754 140485061584704 xla_bridge.py:355] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA Host Interpreter
I0113 18:54:41.977056 140485061584704 xla_bridge.py:355] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
I0113 18:54:41.977239 140485061584704 xla_bridge.py:355] Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
I0113 18:54:42.811571 140471527532288 courier_utils.py:120] Binding: run
I0113 18:54:42.812238 140471527532288 lp_utils.py:87] StepsLimiter: Starting with max_steps = 1000000 (actor_steps)
I0113 18:54:42.813285 140471535924992 savers.py:164] Attempting to restore checkpoint: None
I0113 18:54:42.816575 140469841417984 node.py:61] Reverb client connecting to: localhost:43247
I0113 18:54:42.817500 140469321332480 node.py:61] Reverb client connecting to: localhost:43247
I0113 18:54:42.818590 140469296154368 node.py:61] Reverb client connecting to: localhost:43247
I0113 18:54:42.819234 140469270976256 node.py:61] Reverb client connecting to: localhost:43247
I0113 18:54:42.830212 140469170329344 node.py:61] Reverb client connecting to: localhost:43247
I0113 18:54:42.836279 140471535924992 courier_utils.py:120] Binding: get_counts
I0113 18:54:42.836411 140471535924992 courier_utils.py:120] Binding: get_directory
I0113 18:54:42.836471 140471535924992 courier_utils.py:120] Binding: get_steps_key
I0113 18:54:42.836519 140471535924992 courier_utils.py:120] Binding: increment
I0113 18:54:42.836561 140471535924992 courier_utils.py:120] Binding: restore
I0113 18:54:42.836602 140471535924992 courier_utils.py:120] Binding: save
I0113 18:54:42.836844 140471535924992 savers.py:155] Saving checkpoint: /home/dyth/acme/20230113-185439/checkpoints/counter
I0113 18:54:43.820627 140471527532288 lp_utils.py:95] StepsLimiter: Reached 0 recorded steps
/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/gym/core.py:329: DeprecationWarning: WARN: Initializing wrapper in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.
  deprecation(
/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/gym/wrappers/step_api_compatibility.py:39: DeprecationWarning: WARN: Initializing environment in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.
  deprecation(
I0113 18:54:43.954088 140469270976256 csv.py:76] Logging to /home/dyth/acme/20230113-185439/logs/actor/logs.csv
I0113 18:54:43.954318 140469270976256 courier_utils.py:120] Binding: run
I0113 18:54:43.954387 140469270976256 courier_utils.py:120] Binding: run_episode
I0113 18:54:43.958663 140469136758528 csv.py:76] Logging to /home/dyth/acme/20230113-185439/logs/evaluator/logs.csv
I0113 18:54:43.958981 140469136758528 courier_utils.py:120] Binding: run
I0113 18:54:43.959156 140469296154368 csv.py:76] Logging to /home/dyth/acme/20230113-185439/logs/actor/logs.csv
I0113 18:54:43.959258 140469136758528 courier_utils.py:120] Binding: run_episode
I0113 18:54:43.959641 140469321332480 csv.py:76] Logging to /home/dyth/acme/20230113-185439/logs/actor/logs.csv
I0113 18:54:43.959756 140469296154368 courier_utils.py:120] Binding: run
I0113 18:54:43.959851 140469170329344 csv.py:76] Logging to /home/dyth/acme/20230113-185439/logs/actor/logs.csv
I0113 18:54:43.972820 140469321332480 courier_utils.py:120] Binding: run
I0113 18:54:43.974417 140469296154368 courier_utils.py:120] Binding: run_episode
I0113 18:54:43.977043 140471049455360 builder.py:251] Creating off-policy replay buffer with replay fraction 1 of batch 256
[reverb/cc/platform/tfrecord_checkpointer.cc:162]  Initializing TFRecordCheckpointer in /tmp/tmphf02e4lq.
[reverb/cc/platform/tfrecord_checkpointer.cc:552] Loading latest checkpoint from /tmp/tmphf02e4lq
[reverb/cc/platform/default/server.cc:71] Started replay server on port 43247
I0113 18:54:43.995313 140469170329344 courier_utils.py:120] Binding: run
I0113 18:54:44.145094 140469321332480 courier_utils.py:120] Binding: run_episode
I0113 18:54:44.192413 140469170329344 courier_utils.py:120] Binding: run_episode
I0113 18:54:44.909356 140469841417984 csv.py:76] Logging to /home/dyth/acme/20230113-185439/logs/learner/logs.csv
I0113 18:54:44.909640 140469841417984 learning.py:126] Learner process id: 0. Devices passed: [StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]
I0113 18:54:44.909731 140469841417984 learning.py:128] Learner process id: 0. Local devices from JAX API: [StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (2040306) so Table priority_table is accessed directly without gRPC.
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (2040306) so Table priority_table is accessed directly without gRPC.
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (2040306) so Table priority_table is accessed directly without gRPC.
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (2040306) so Table priority_table is accessed directly without gRPC.
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (2040306) so Table priority_table is accessed directly without gRPC.
Node ThreadWorker(thread=<Thread(learner, stopped daemon 140469841417984)>, future=<Future at 0x7fc49809c520 state=finished raised ValueError>) crashed:
Traceback (most recent call last):
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/launchpad/launch/worker_manager.py", line 474, in _check_workers
    worker.future.result()
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/concurrent/futures/_base.py", line 437, in result
    return self.__get_result()
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
    raise self._exception
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/launchpad/launch/worker_manager.py", line 250, in run_inner
    future.set_result(f())
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/launchpad/nodes/python/node.py", line 75, in _construct_function
    return functools.partial(self._function, *args, **kwargs)()
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/launchpad/nodes/courier/node.py", line 113, in run
    instance = self._construct_instance()  # pytype:disable=wrong-arg-types
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/launchpad/nodes/python/node.py", line 180, in _construct_instance
    self._instance = self._constructor(*args, **kwargs)
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/acme/jax/experiments/make_distributed_experiment.py", line 178, in build_learner
    learner = experiment.builder.make_learner(random_key, networks, iterator,
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/acme/agents/jax/mpo/builder.py", line 167, in make_learner
    learner = learning.MPOLearner(
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/acme/agents/jax/mpo/learning.py", line 221, in __init__
    network_params, _ = mpo_networks.init_params(
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/acme/agents/jax/mpo/networks.py", line 118, in init_params
    params_policy_head = networks.policy_head.init(rng_keys[2], embeddings)
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/haiku/_src/transform.py", line 114, in init_fn
    params, state = f.init(*args, **kwargs)
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/haiku/_src/transform.py", line 338, in init_fn
    f(*args, **kwargs)
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/acme/agents/jax/mpo/networks.py", line 217, in policy_fn
    return networks_lib.MultivariateNormalDiagHead(
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
    out = f(*args, **kwargs)
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/acme/jax/networks/distributional.py", line 297, in __call__
    return tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/decorator.py", line 232, in fun
    return caller(func, *(extras + args), **kw)
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py", line 342, in wrapped_init
    default_init(self_, *args, **kwargs)
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/tensorflow_probability/substrates/jax/distributions/mvn_diag.py", line 235, in __init__
    scale = diag_cls(
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/tensorflow_probability/python/internal/backend/jax/gen/linear_operator_diag.py", line 191, in __init__
    self._set_graph_parents([self._diag])
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/tensorflow_probability/python/internal/backend/jax/gen/linear_operator.py", line 1178, in _set_graph_parents
    raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))
ValueError: Graph parent item 0 is not a Tensor; [[0.500001 0.500001 0.500001 0.500001 0.500001 0.500001]].
I0113 18:54:46.687932 140485061584704 savers.py:206] Caught SIGTERM: forcing a checkpoint save.
I0113 18:54:46.688052 140485061584704 savers.py:155] Saving checkpoint: /home/dyth/acme/20230113-185439/checkpoints/counter
Error in atexit._run_exitfuncs:
Traceback (most recent call last):
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/launchpad/launch/worker_manager.py", line 428, in wait
    raise failure
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/launchpad/launch/worker_manager.py", line 474, in _check_workers
    worker.future.result()
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/concurrent/futures/_base.py", line 437, in result
    return self.__get_result()
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
    raise self._exception
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/launchpad/launch/worker_manager.py", line 250, in run_inner
    future.set_result(f())
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/launchpad/nodes/python/node.py", line 75, in _construct_function
    return functools.partial(self._function, *args, **kwargs)()
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/launchpad/nodes/courier/node.py", line 113, in run
    instance = self._construct_instance()  # pytype:disable=wrong-arg-types
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/launchpad/nodes/python/node.py", line 180, in _construct_instance
    self._instance = self._constructor(*args, **kwargs)
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/acme/jax/experiments/make_distributed_experiment.py", line 178, in build_learner
    learner = experiment.builder.make_learner(random_key, networks, iterator,
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/acme/agents/jax/mpo/builder.py", line 167, in make_learner
    learner = learning.MPOLearner(
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/acme/agents/jax/mpo/learning.py", line 221, in __init__
    network_params, _ = mpo_networks.init_params(
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/acme/agents/jax/mpo/networks.py", line 118, in init_params
    params_policy_head = networks.policy_head.init(rng_keys[2], embeddings)
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/haiku/_src/transform.py", line 114, in init_fn
    params, state = f.init(*args, **kwargs)
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/haiku/_src/transform.py", line 338, in init_fn
    f(*args, **kwargs)
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/acme/agents/jax/mpo/networks.py", line 217, in policy_fn
    return networks_lib.MultivariateNormalDiagHead(
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
    out = f(*args, **kwargs)
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/acme/jax/networks/distributional.py", line 297, in __call__
    return tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/decorator.py", line 232, in fun
    return caller(func, *(extras + args), **kw)
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py", line 342, in wrapped_init
    default_init(self_, *args, **kwargs)
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/tensorflow_probability/substrates/jax/distributions/mvn_diag.py", line 235, in __init__
    scale = diag_cls(
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/tensorflow_probability/python/internal/backend/jax/gen/linear_operator_diag.py", line 191, in __init__
    self._set_graph_parents([self._diag])
  File "/home/dyth/miniconda3/envs/acme/lib/python3.8/site-packages/tensorflow_probability/python/internal/backend/jax/gen/linear_operator.py", line 1178, in _set_graph_parents
    raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))
ValueError: Graph parent item 0 is not a Tensor; [[0.500001 0.500001 0.500001 0.500001 0.500001 0.500001]].

Any help would be much appreciated!

I recently got this error as well, so I second this issue. Any guidance would be very useful.

The issue is likely due to breaking changes in jax>=0.4, which is no longer compatible with the TFP version pinned in acme. Downgrading JAX to something like 0.3.25 should fix it.
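
For example, re-pinning against the same wheel index used in the install steps above might look like the line below (the exact pin and the "cuda" extras name are assumptions, not a tested configuration; adjust to your CUDA/cuDNN setup):

pip install --upgrade "jax[cuda]==0.3.25" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

To confirm the jax/TFP mismatch independently of acme, the failing construction from the traceback can be reproduced directly. A minimal sketch, assuming tensorflow-probability's JAX substrate; the shapes and values simply mirror those printed in the error message:

# Rebuild the distribution that acme's policy head constructs in
# acme/jax/networks/distributional.py. On an incompatible jax/TFP pairing,
# the constructor itself raises:
#   ValueError: Graph parent item 0 is not a Tensor; ...
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions
loc = jnp.zeros((1, 6))
scale = jnp.full((1, 6), 0.500001)
dist = tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
print(dist.sample(seed=jax.random.PRNGKey(0)))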