`Pong-misc`: TypeError: select cases must have the same shapes, got [(30, 40), ()].
HelgeS opened this issue
When running the `Pong-misc` environment, the following error is raised from `move_paddles`.
I tried both the example notebook and gymnax-blines to make sure it is not a usage error.
Below are the stack trace and the gymnax-blines configuration I used.
```
$ python train.py -config agents/Pong-misc/ppo.yaml
PPO: 0%| | 0/18751 [00:00<?, ?it/s]
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/helge/Sandbox/pt/gymnax-blines/train.py", line 76, in <module>
    main(
  File "/home/helge/Sandbox/pt/gymnax-blines/train.py", line 24, in main
    log_steps, log_return, network_ckpt = train_fn(
  File "/home/helge/Sandbox/pt/gymnax-blines/utils/ppo.py", line 271, in train_ppo
    train_state, obs, state, batch, rng_step = get_transition(
  File "/home/helge/Sandbox/pt/gymnax-blines/utils/ppo.py", line 252, in get_transition
    next_obs, next_state, reward, done, _ = rollout_manager.batch_step(
  File "/home/helge/Sandbox/pt/gymnax-blines/utils/ppo.py", line 138, in batch_step
    return jax.vmap(self.env.step, in_axes=(0, 0, 0, None))(
  File "/home/helge/Sandbox/pt/code/.venv/lib/python3.10/site-packages/gymnax/environments/environment.py", line 45, in step
    obs_st, state_st, reward, done, info = self.step_env(key, state, action, params)
  File "/home/helge/Sandbox/pt/code/.venv/lib/python3.10/site-packages/gymnax/environments/misc/pong.py", line 75, in step_env
    state = move_paddles(
  File "/home/helge/Sandbox/pt/code/.venv/lib/python3.10/site-packages/gymnax/environments/misc/pong.py", line 356, in move_paddles
    new_center_p2 = jax.lax.select(use_ai_policy, new_center_ai, new_center_self)
TypeError: select cases must have the same shapes, got [(30, 40), ()].
```
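The error comes from `jax.lax.select`, which, unlike `jnp.where`, does not broadcast: both cases must have exactly the same shape, and the predicate must be a scalar or match that shape. The sketch below reproduces the message in isolation with pure JAX. The helper `select_with_check` and the toy `toy_step` function are hypothetical names made up for illustration (they are not part of gymnax or gymnax-blines), and the (30, 40) shape is only borrowed from the error message. `toy_step` is vmapped with the same `in_axes=(0, 0, 0, None)` pattern as `batch_step`, so inside it every per-env value is a scalar; a (30, 40) operand therefore suggests that one of the two center computations in `move_paddles` picks up an array-shaped value, although this sketch does not show where that happens in `pong.py`.

```python
import jax
import jax.numpy as jnp

# Shapes borrowed from the error message: one select case is (30, 40), the other ().
GRID_LIKE = jnp.zeros((30, 40), dtype=jnp.float32)
SCALAR = jnp.float32(20.0)
PRED = jnp.bool_(True)


def lax_select_error():
    """Return the TypeError message raised by jax.lax.select on mismatched cases."""
    try:
        jax.lax.select(PRED, SCALAR, GRID_LIKE)
    except TypeError as err:
        return str(err)
    return None


def where_result_shape():
    """jnp.where broadcasts the cases instead of raising, hiding the mismatch."""
    return jnp.where(PRED, SCALAR, GRID_LIKE).shape


def select_with_check(pred, on_true, on_false):
    """Hypothetical helper: fail early with both shapes named before lax.select.

    Shapes are static under jit/vmap, so this Python check runs at trace time.
    """
    on_true, on_false = jnp.asarray(on_true), jnp.asarray(on_false)
    if on_true.shape != on_false.shape:
        raise ValueError(
            f"select branches differ: on_true {on_true.shape}, on_false {on_false.shape}"
        )
    return jax.lax.select(pred, on_true, on_false)


def toy_step(key, center, action, params):
    """Hypothetical per-env step; `key` is unused and only mirrors the signature."""
    new_center_self = center + action * params["speed"]  # shape () per env
    new_center_ai = center  # shape () per env
    return select_with_check(action > 0, new_center_ai, new_center_self)


# Same in_axes pattern as batch_step: key, state and action mapped, params broadcast.
batched_step = jax.vmap(toy_step, in_axes=(0, 0, 0, None))

keys = jax.random.split(jax.random.PRNGKey(0), 8)
centers = jnp.zeros(8, dtype=jnp.float32)
actions = jnp.arange(8, dtype=jnp.int32) - 4
params = {"speed": jnp.float32(1.0)}
```

Because `jax.lax.select` passes its cases on to `select_n` in (on_false, on_true) order, the shapes in the error message may be listed in reverse argument order, so it is worth printing both `new_center_ai.shape` and `new_center_self.shape` rather than assuming which one is (30, 40). Swapping in `jnp.where` would also make the error disappear, but only by silently broadcasting the scalar to (30, 40), not by fixing the underlying mismatch.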
Configuration (copied from CartPole-v1):
```yaml
train_config:
  train_type: "PPO"
  num_train_steps: 150000
  evaluate_every_epochs: 1000

  env_name: "Pong-misc"
  env_kwargs: {}
  env_params: {}
  num_test_rollouts: 164

  num_train_envs: 8 # Number of parallel env workers
  max_grad_norm: 0.5 # Global norm to clip gradients by
  gamma: 0.99 # Discount factor
  n_steps: 32 # "GAE n-steps"
  n_minibatch: 4 # "Number of PPO minibatches"
  lr_begin: 5e-04 # Start PPO learning rate
  lr_end: 5e-04 # End PPO learning rate
  lr_warmup: 0.05 # Prop epochs until warmup is completed
  epoch_ppo: 4 # "Number of PPO epochs on a single batch"
  clip_eps: 0.2 # "Clipping range"
  gae_lambda: 0.95 # "GAE lambda"
  entropy_coeff: 0.01 # "Entropy loss coefficient"
  critic_coeff: 0.5 # "Value loss coefficient"

  network_name: "Categorical-MLP"
  network_config:
    num_hidden_units: 64
    num_hidden_layers: 2

log_config:
  time_to_track: ["num_steps"]
  what_to_track: ["return"]
  verbose: false
  print_every_k_updates: 1
  overwrite: 1
  model_type: "jax"

device_config:
  num_devices: 1
  device_type: "gpu"
```