RobertTLange / gymnax

RL Environments in JAX 🌍

Issue: vmapped CartPole input shape does not match

DriesSmit opened this issue · comments

Hello there. I am trying to run a vmapped CartPole step function. My environment state inputs are of the shape:

env_state:
[executor/0] x:  Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=2/0)>
[executor/0] x_dot:  Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=2/0)>
[executor/0] theta:  Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=2/0)>
[executor/0] theta_dot:  Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=2/0)>

When I run jnp.array([env_state.x, env_state.x_dot, env_state.theta, env_state.theta_dot]) on the state before the environment step, I get:
Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=2/0)>

However when I try to run the step function I get:

obs, env_state, rewards, done, _ = self.env.step(key_step, env_state, action, self.env_params)
[executor/0]   File "/mava/lib/python3.8/site-packages/gymnax/environments/environment.py", line 38, in step
[executor/0]     obs_st, state_st, reward, done, info = self.step_env(
[executor/0]   File "/mava/lib/python3.8/site-packages/gymnax/environments/classic_control/cartpole.py", line 83, in step_env
[executor/0]     lax.stop_gradient(self.get_obs(state)),
[executor/0]   File "/mava/lib/python3.8/site-packages/gymnax/environments/classic_control/cartpole.py", line 108, in get_obs
[executor/0]     return jnp.array([state.x, state.x_dot, state.theta, state.theta_dot])
[executor/0]   File "/mava/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 1889, in array
[executor/0]     out = stack([asarray(elt, dtype=dtype) for elt in object])
[executor/0]   File "/mava/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 1634, in stack
[executor/0]     raise ValueError("All input arrays must have the same shape.")
[executor/0] ValueError: All input arrays must have the same shape.

Do you have any idea what might be causing this issue? Are the shapes somehow changing inside the step function? Thanks.

Apologies, this turned out to have nothing to do with the environment. The cause was that I was passing an action logit array of size 2 to the environment.
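
For anyone who hits the same error, here is a rough sketch of how a size-2 logit array can produce this shape mismatch. Note that toy_step is a hypothetical stand-in written for illustration, not the gymnax source, and the constants 10.0 and 0.02 are assumed values. The idea is that only the velocity terms depend on the action, so they pick up the action's shape while the position terms stay scalar, and stacking them then fails. Taking the argmax of the logits first (one possible way to get a discrete action) avoids the problem.

```python
import jax
import jax.numpy as jnp


def toy_step(state, action):
    """Hypothetical CartPole-like update; NOT the gymnax implementation."""
    x, x_dot, theta, theta_dot = state
    # The force broadcasts to whatever shape `action` has.
    force = 10.0 * action - 10.0 * (1 - action)
    new_x = x + 0.02 * x_dot                  # independent of action: stays scalar
    new_x_dot = x_dot + 0.02 * force          # takes on the shape of `action`
    new_theta = theta + 0.02 * theta_dot      # independent of action: stays scalar
    new_theta_dot = theta_dot + 0.02 * force  # takes on the shape of `action`
    # Stacking requires all four entries to have the same shape.
    return jnp.stack([new_x, new_x_dot, new_theta, new_theta_dot])


batch = 3
# Batched state: four arrays of shape (batch,), so each is a scalar inside vmap.
state = tuple(jnp.zeros(batch) for _ in range(4))
# Policy output: one pair of logits per environment, shape (batch, 2).
logits = jnp.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])

# Passing the logits directly: each per-environment action has shape (2,).
try:
    jax.vmap(toy_step)(state, logits)
    failed = False
except ValueError as err:
    failed = True
    print("logits passed as action:", err)

# Reducing the logits to one discrete action per environment first.
actions = jnp.argmax(logits, axis=-1)  # shape (batch,)
obs = jax.vmap(toy_step)(state, actions)
print("observation shape:", obs.shape)
```

Printing the shape of the action right before the step call is a quick way to check for this, since the individual state fields look fine until the step has mixed the action into them.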