[Bug] example supplied in readme crashing
Robokan opened this issue · comments
Important Note: We do not do technical support, nor consulting and don't answer personal questions per email.
Please post your question on the RL Discord, Reddit or Stack Overflow in that case.
If your issue is related to a custom gym environment, please use the custom gym env template.
🐛 Bug
A clear and concise description of what the bug is.
If I run the example from your README, it crashes.
To Reproduce
Steps to reproduce the behavior.
import gym
from sbx import TQC, DroQ, SAC, PPO, DQN
env = gym.make("Pendulum-v1")
model = TQC("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000, progress_bar=True)
vec_env = model.get_env()
obs = vec_env.reset()
for i in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = vec_env.step(action)
    vec_env.render()
vec_env.close()
Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful.
Please use the markdown code blocks
for both code and stack traces.
from stable_baselines3 import ...
Traceback (most recent call last): File ...
Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Input In [34], in <cell line: 7>()
3 from sbx import TQC, DroQ, SAC, PPO, DQN
5 env = gym.make("Pendulum-v1")
----> 7 model = TQC("MlpPolicy", env, verbose=1)
8 model.learn(total_timesteps=10_000, progress_bar=True)
10 vec_env = model.get_env()
File ~/miniconda3/envs/py39/lib/python3.9/site-packages/sbx/tqc/tqc.py:102, in TQC.__init__(self, policy, env, learning_rate, qf_learning_rate, buffer_size, learning_starts, batch_size, tau, gamma, train_freq, gradient_steps, policy_delay, top_quantiles_to_drop_per_net, action_noise, ent_coef, use_sde, sde_sample_freq, use_sde_at_warmup, tensorboard_log, policy_kwargs, verbose, seed, device, _init_setup_model)
99 self.policy_kwargs["top_quantiles_to_drop_per_net"] = top_quantiles_to_drop_per_net
101 if _init_setup_model:
--> 102 self._setup_model()
File ~/miniconda3/envs/py39/lib/python3.9/site-packages/sbx/tqc/tqc.py:115, in TQC._setup_model(self)
107 if self.policy is None:
108 self.policy = self.policy_class( # pytype:disable=not-instantiable
109 self.observation_space,
110 self.action_space,
111 self.lr_schedule,
112 **self.policy_kwargs, # pytype:disable=not-instantiable
113 )
--> 115 self.key = self.policy.build(self.key, self.lr_schedule, self.qf_learning_rate)
117 self.key, ent_key = jax.random.split(self.key, 2)
119 self.actor = self.policy.actor
File ~/miniconda3/envs/py39/lib/python3.9/site-packages/sbx/tqc/policies.py:142, in TQCPolicy.build(self, key, lr_schedule, qf_learning_rate)
137 # Hack to make gSDE work without modifying internal SB3 code
138 self.actor.reset_noise = self.reset_noise
140 self.actor_state = TrainState.create(
141 apply_fn=self.actor.apply,
--> 142 params=self.actor.init(actor_key, obs),
143 tx=self.optimizer_class(learning_rate=lr_schedule(1), **self.optimizer_kwargs),
144 )
146 self.qf = Critic(
147 dropout_rate=self.dropout_rate,
148 use_layer_norm=self.layer_norm,
149 n_units=self.n_units,
150 n_quantiles=self.n_quantiles,
151 )
153 self.qf1_state = RLTrainState.create(
154 apply_fn=self.qf.apply,
155 params=self.qf.init(
(...)
165 tx=optax.adam(learning_rate=qf_learning_rate),
166 )
[... skipping hidden 9 frame]
File ~/miniconda3/envs/py39/lib/python3.9/site-packages/sbx/tqc/policies.py:66, in Actor.__call__(self, x)
63 log_std = nn.Dense(self.action_dim)(x)
64 log_std = jnp.clip(log_std, self.log_std_min, self.log_std_max)
65 dist = TanhTransformedDistribution(
---> 66 tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)),
67 )
68 return dist
File ~/miniconda3/envs/py39/lib/python3.9/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
230 if not kwsyntax:
231 args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)
File ~/miniconda3/envs/py39/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:342, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
339 # Note: if we ever want to have things set in `self` before `__init__` is
340 # called, here is the place to do it.
341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
343 # Note: if we ever want to override things set in `self` by subclass
344 # `__init__`, here is the place to do it.
345 if self_._parameters is None:
346 # We prefer subclasses will set `parameters = dict(locals())` because
347 # this has nearly zero overhead. However, failing to do this, we will
348 # resolve the input arguments dynamically and only when needed.
File ~/miniconda3/envs/py39/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/mvn_diag.py:235, in MultivariateNormalDiag.__init__(self, loc, scale_diag, scale_identity_multiplier, validate_args, allow_nan_stats, experimental_use_kahan_sum, name)
232 if scale_diag is not None:
233 diag_cls = (KahanLogDetLinOpDiag if experimental_use_kahan_sum else
234 tf.linalg.LinearOperatorDiag)
--> 235 scale = diag_cls(
236 diag=scale_diag,
237 is_non_singular=True,
238 is_self_adjoint=True,
239 is_positive_definite=False)
240 else:
241 # Deprecated behavior; breaks variable-safety rules by calling
242 # `tf.shape(loc)`.
243 num_rows = tf.compat.dimension_value(loc.shape[-1])
File ~/miniconda3/envs/py39/lib/python3.9/site-packages/tensorflow_probability/python/internal/backend/jax/gen/linear_operator_diag.py:191, in LinearOperatorDiag.__init__(self, diag, is_non_singular, is_self_adjoint, is_positive_definite, is_square, name)
182 super(LinearOperatorDiag, self).__init__(
183 dtype=self._diag.dtype,
184 is_non_singular=is_non_singular,
(...)
188 parameters=parameters,
189 name=name)
190 # TODO(b/143910018) Remove graph_parents in V3.
--> 191 self._set_graph_parents([self._diag])
File ~/miniconda3/envs/py39/lib/python3.9/site-packages/tensorflow_probability/python/internal/backend/jax/gen/linear_operator.py:1177, in LinearOperator._set_graph_parents(self, graph_parents)
1174 for i, t in enumerate(graph_parents):
1175 if t is None or not (linear_operator_util.is_ref(t) or
1176 ops.is_tensor(t)):
-> 1177 raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))
1178 self._graph_parents = graph_parents
ValueError: Graph parent item 0 is not a Tensor; [[0.48654944]].
### Expected behavior
A clear and concise description of what you expected to happen.
If I import PPO or other algorithms from Stable-Baselines3, they train on the example without problems. I expect SBX to do the same.
### System Info
Describe the characteristic of your environment:
Operating system is macOS Monterey
* Describe how the library was installed (pip, docker, source, ...)
pip
* GPU models and configuration
* Python version
Python 3.9.15
* PyTorch version
* torch 1.13.1
* Gym version
gym 0.21.0
* Versions of any other relevant libraries
You can use `sb3.get_system_info()` to print relevant packages info:
```python
import stable_baselines3 as sb3
sb3.get_system_info()
```
Additional context
Add any other context about the problem here.
Checklist
- [x] I have checked that there is no similar issue in the repo (required)
- [x] I have read the documentation (required)
- [x] I have provided a minimal working example to reproduce the bug (required)
Hello,
I guess it works if you use another OS (for instance Linux in a Google Colab)?
Make sure you have the latest versions of TensorFlow Probability and JAX; support for macOS might be experimental.
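To compare the macOS setup with a Colab one, it can help to list the installed versions of the relevant packages. The sketch below uses only the standard library. The distribution names in `PACKAGES` are assumptions about how each package is published on PyPI, and `installed_versions` is a hypothetical helper, not part of SBX.

```python
from importlib import metadata

# Assumed PyPI distribution names; adjust if your installation differs.
PACKAGES = ["sbx-rl", "jax", "jaxlib", "flax", "optax", "tensorflow-probability"]


def installed_versions(packages=PACKAGES):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Running this in both environments and comparing the output should show whether the two setups differ in their JAX or TensorFlow Probability versions.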
I have the latest JAX working great on macOS. I am even running the Brax simulator on top of it. Your repo would be great for speeding up my stable-baselines/Gym code. I really appreciate you posting it. Any ideas of how to go about debugging this?
I did try it on Google Colab and your example does work there.
> I did try it on Google Colab and your example does work there.

> I have the latest JAX working great on macOS.
Then it seems the problem comes from TensorFlow Probability (which, despite the name, is used here with JAX only); you should probably open an issue there.
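For such an issue, a reproduction that does not involve SBX is usually more useful. The following is a minimal sketch that builds the same `MultivariateNormalDiag` as `Actor.__call__` in the traceback, assuming a batch of one and a one-dimensional action (as in Pendulum-v1). `check_mvn_diag` is a hypothetical helper and the zero-valued inputs are placeholders, not the values SBX would compute.

```python
def check_mvn_diag():
    """Try to build the distribution from the traceback, outside of SBX.

    Returns a short status string instead of raising, so the result is easy
    to paste into a bug report.
    """
    try:
        import jax.numpy as jnp
        import tensorflow_probability.substrates.jax as tfp
    except Exception as exc:  # broad on purpose: version mismatches can fail at import
        return f"import failed: {type(exc).__name__}: {exc}"

    tfd = tfp.distributions
    # Placeholder inputs with shape (batch=1, action_dim=1).
    mean = jnp.zeros((1, 1))
    log_std = jnp.zeros((1, 1))
    try:
        dist = tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std))
    except Exception as exc:
        return f"{type(exc).__name__}: {exc}"
    return f"ok: event_shape={dist.event_shape}"


if __name__ == "__main__":
    print(check_mvn_diag())
```

If this reports the same `ValueError` on macOS but `ok` on Colab, the problem can be described without any reference to SBX.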
Thanks. I will see if there is anything about this in JAX.