num_env_steps does not work in RolloutWrapper
jinPrelude (Euijin Jeong) opened this issue
Issue
Rollouts produced by RolloutWrapper ignore the num_env_steps argument passed to the RolloutWrapper constructor. They run for the environment's default episode length instead.
Code for reproduction
from gymnax.experimental import RolloutWrapper
import jax

ENV_NUM = 3  # number of parallel rollouts

# model_forward=None: the wrapper samples random actions from the action space
manager = RolloutWrapper(None, env_name='CartPole-v1', num_env_steps=100)
rng, rollout_rng = jax.random.split(jax.random.key(0))
rollout_rng = jax.random.split(rollout_rng, ENV_NUM)  # one key per rollout
obs, action, reward, next_obs, done, cum_ret = manager.batch_rollout(rollout_rng, None)
print(done.shape)  # expected (3, 100), but prints (3, 500), CartPole-v1's max_steps_in_episode
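For comparison, here is a minimal sketch of the behavior the report expects: in a lax.scan-based rollout, the scan length fixes the time dimension of every returned array, so it should come from num_env_steps and fall back to the episode limit only when num_env_steps is None. This is not gymnax's code. The toy dynamics (toy_step), the rollout/batch_rollout helpers, and MAX_STEPS_IN_EPISODE are hypothetical stand-ins chosen for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for an environment's episode limit
# (assumed to mirror CartPole-v1's default of 500 steps).
MAX_STEPS_IN_EPISODE = 500


def toy_step(state, rng):
    """Hypothetical toy dynamics: a random walk that flags done whenever |state| > 3."""
    noise = jax.random.normal(rng)
    next_state = state + noise
    done = jnp.abs(next_state) > 3.0
    return next_state, done


def rollout(rng, num_env_steps=None):
    """Run one rollout with lax.scan; the scan length sets the time dimension."""
    # Use the episode limit only when no explicit length was requested.
    length = MAX_STEPS_IN_EPISODE if num_env_steps is None else num_env_steps
    step_rngs = jax.random.split(rng, length)  # one key per time step

    def body(state, step_rng):
        next_state, done = toy_step(state, step_rng)
        return next_state, done  # (new carry, per-step output)

    _, dones = jax.lax.scan(body, jnp.float32(0.0), step_rngs)
    return dones  # shape (length,)


def batch_rollout(rngs, num_env_steps=None):
    # num_env_steps is a static Python int, so close over it instead of vmapping it.
    return jax.vmap(lambda r: rollout(r, num_env_steps))(rngs)


rngs = jax.random.split(jax.random.key(0), 3)
print(batch_rollout(rngs, num_env_steps=100).shape)  # (3, 100)
print(batch_rollout(rngs).shape)                     # (3, 500)
```

Because num_env_steps must be a concrete Python integer for the scan length, a wrapper that stores it at construction time has to pass that stored value (not the environment's max_steps_in_episode) to the scan.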