perrin-isir / xpag

a modular reinforcement learning library with JAX agents


error about SDQN

KID0031 opened this issue

Hi, I'm a beginner with xpag. When running SDQN on HalfCheetah-v4 using train_mujoco.ipynb from xpag-tutorial, I get this error:

ScopeParamShapeError                      Traceback (most recent call last)
<ipython-input-9-366b04f4f1a8> in <cell line: 1>()
----> 1 learn(
      2     env,
      3     eval_env,
      4     env_info,
      5     agent,

7 frames
/usr/local/lib/python3.10/dist-packages/xpag/tools/learn.py in learn(env, eval_env, env_info, agent, buffer, setter, batch_size, gd_steps_per_step, start_training_after_x_steps, max_steps, evaluate_every_x_steps, save_agent_every_x_steps, save_dir, save_episode, plot_projection, custom_eval_function, additional_step_keys, seed)
    138             if i > 0:
    139                 for _ in range(max(gd_steps_per_step * env_info["num_envs"], 1)):
--> 140                     _ = agent.train_on_batch(buffer.sample(batch_size))
    141 
    142         action = datatype_convert(action, env_datatype)

/usr/local/lib/python3.10/dist-packages/xpag/agents/sdqn/sdqn.py in train_on_batch(self, batch)
    544         mask = 1 - batch["terminated"]
    545 
--> 546         self.training_state, metrics = self.update_step(
    547             self.action_bins,
    548             self.action_dim,

    [... skipping hidden 12 frame]

/usr/local/lib/python3.10/dist-packages/xpag/agents/sdqn/sdqn.py in update_step(action_bins, action_dim, state, observations, actions, rewards, new_observations, mask)
    414             key, key_critic = jax.random.split(state.key, 2)
    415 
--> 416             critic_up_l, critic_up_grads = self.critic_up_grad(
    417                 state.critic_up_params,
    418                 state.target_critic_up_params,

    [... skipping hidden 8 frame]

/usr/local/lib/python3.10/dist-packages/xpag/agents/sdqn/sdqn.py in critic_up_loss(critic_up_params, target_critic_up_params, critic_low_params, observations, actions, new_observations, rewards, mask)
    305 
    306             # Get current Q estimates
--> 307             current_q1, current_q2 = self.critic_up.apply(
    308                 critic_up_params, observations, actions, 1
    309             )

    [... skipping hidden 6 frame]

/usr/local/lib/python3.10/dist-packages/xpag/agents/sdqn/sdqn.py in __call__(self, obs, actions, output_size)
    164                     res = []
    165                     for _ in range(self.n_critics):
--> 166                         q = CustomMLP(
    167                             layer_sizes=hidden_layer_sizes + (output_size,),
    168                             activation=linen.relu,

    [... skipping hidden 2 frame]

/usr/local/lib/python3.10/dist-packages/xpag/agents/sdqn/sdqn.py in __call__(self, data)
    113                 hidden = data
    114                 for i, hidden_size in enumerate(self.layer_sizes):
--> 115                     hidden = linen.Dense(
    116                         hidden_size,
    117                         name=f"hidden_{i}",

    [... skipping hidden 2 frame]

/usr/local/lib/python3.10/dist-packages/flax/linen/linear.py in __call__(self, inputs)
    194       The transformed input.
    195     """
--> 196     kernel = self.param('kernel',
    197                         self.kernel_init,
    198                         (jnp.shape(inputs)[-1], self.features),

    [... skipping hidden 1 frame]

/usr/local/lib/python3.10/dist-packages/flax/core/scope.py in param(self, name, init_fn, unbox, *init_args)
    833         # for inference to a half float type for example.
    834         if jnp.shape(val) != jnp.shape(abs_val):
--> 835           raise errors.ScopeParamShapeError(name, self.path_text,
    836                                             jnp.shape(abs_val), jnp.shape(val))
    837     else:

ScopeParamShapeError: Initializer expected to generate shape (47, 256) but got shape (23, 256) instead for parameter "kernel" in "/CustomMLP_0/hidden_0". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamShapeError)

The only change I made was switching the agent to SDQN:

agent = SDQN(
    env_info['observation_dim'] if not env_info['is_goalenv']
    else env_info['observation_dim'] + env_info['desired_goal_dim'],
    env_info['action_dim'],
    {}
)

Is there any way to solve it? Thanks :)

Hi,
SDQN is a bit special because it changes the representation of actions: the action space is discretized into bins, and actions are represented as one-hot encodings relative to those bins. For this reason it comes with its own "setter" that takes care of this conversion: SDQNSetter.
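To give a quick intuition (a minimal sketch with a hypothetical helper, not xpag's actual code): if each of HalfCheetah's 6 action dimensions is discretized into, say, 5 bins, an action becomes a 6 × 5 = 30-dimensional one-hot block instead of a 6-dimensional vector, so the critic sees inputs of size 17 + 30 = 47 rather than 17 + 6 = 23, which is consistent with the shape mismatch in your traceback.

import numpy as np

def one_hot_bins(action, low, high, n_bins):
    # Hypothetical helper: map each continuous component to its nearest bin,
    # then one-hot encode it (result has size action_dim * n_bins).
    idx = np.clip(
        np.round((action - low) / (high - low) * (n_bins - 1)), 0, n_bins - 1
    ).astype(int)
    encoding = np.zeros((action.shape[0], n_bins))
    encoding[np.arange(action.shape[0]), idx] = 1.0
    return encoding.reshape(-1)

# HalfCheetah-v4: 17-dimensional observations, 6-dimensional actions in [-1, 1]
action = np.random.uniform(-1.0, 1.0, size=6)
print(one_hot_bins(action, low=-1.0, high=1.0, n_bins=5).shape)  # (30,)

The agent works with the encoded representation while the environment expects raw continuous actions, and SDQNSetter is what keeps the two representations consistent.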
So, you should import it:

from xpag.agents import SDQNSetter

and then replace the line setter = DefaultSetter() with:

setter = SDQNSetter(agent)

That should work.
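For reference, the relevant part of the notebook then looks roughly like this (assuming SDQN is imported from xpag.agents like the other agents; everything else in train_mujoco.ipynb stays unchanged):

from xpag.agents import SDQN, SDQNSetter

agent = SDQN(
    env_info['observation_dim'] if not env_info['is_goalenv']
    else env_info['observation_dim'] + env_info['desired_goal_dim'],
    env_info['action_dim'],
    {}
)

# SDQNSetter replaces DefaultSetter(); it keeps the environment's continuous
# actions and the agent's discretized (one-hot) actions consistent.
setter = SDQNSetter(agent)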

But FYI, the development of SDQN is still ongoing, as I am not 100% satisfied with it. It kind of works, but based on the results in the original paper I believe it can be improved (of course it still won't be competitive with recent algorithms, but that's not the point).
That's why I haven't put an SDQN example in the tutorials yet; I will add one when I'm happy with the results.

Here is an example of a run I did today on HalfCheetah-v4:

[ 0 steps] [training time (ms) += 0 ] [ep reward: -345.233 ]
[ 5000 steps] [training time (ms) += 5324 ] [ep reward: -294.460 ]
[ 10000 steps] [training time (ms) += 4092 ] [ep reward: -373.125 ]
[ 15000 steps] [training time (ms) += 40518 ] [ep reward: -48.540 ]
[ 20000 steps] [training time (ms) += 25724 ] [ep reward: 1844.875 ]
[ 25000 steps] [training time (ms) += 25391 ] [ep reward: 2075.772 ]
[ 30000 steps] [training time (ms) += 26777 ] [ep reward: 3475.885 ]
[ 35000 steps] [training time (ms) += 28484 ] [ep reward: 3386.266 ]
[ 40000 steps] [training time (ms) += 26404 ] [ep reward: 1223.198 ]
[ 45000 steps] [training time (ms) += 25733 ] [ep reward: 2478.572 ]
[ 50000 steps] [training time (ms) += 25467 ] [ep reward: 1755.038 ]
[ 55000 steps] [training time (ms) += 25508 ] [ep reward: 3465.681 ]
[ 60000 steps] [training time (ms) += 25220 ] [ep reward: 3445.830 ]
[ 65000 steps] [training time (ms) += 25258 ] [ep reward: 3554.391 ]
[ 70000 steps] [training time (ms) += 25641 ] [ep reward: 3545.330 ]
[ 75000 steps] [training time (ms) += 25769 ] [ep reward: 3914.521 ]
[ 80000 steps] [training time (ms) += 25483 ] [ep reward: 3647.373 ]
[ 85000 steps] [training time (ms) += 25428 ] [ep reward: 3819.018 ]
...
[ 280000 steps] [training time (ms) += 30184 ] [ep reward: 3863.470 ]
[ 285000 steps] [training time (ms) += 39707 ] [ep reward: 3794.748 ]
[ 290000 steps] [training time (ms) += 38752 ] [ep reward: 3495.448 ]
[ 295000 steps] [training time (ms) += 25772 ] [ep reward: 4094.522 ]
...
[ 605000 steps] [training time (ms) += 25124 ] [ep reward: 4213.886 ]
[ 610000 steps] [training time (ms) += 25106 ] [ep reward: 4163.989 ]
[ 615000 steps] [training time (ms) += 25136 ] [ep reward: 4243.774 ]
[ 620000 steps] [training time (ms) += 25121 ] [ep reward: 4112.370 ]
[ 625000 steps] [training time (ms) += 25096 ] [ep reward: 4061.563 ]
...
[ 710000 steps] [training time (ms) += 25072 ] [ep reward: 4194.871 ]
[ 715000 steps] [training time (ms) += 25058 ] [ep reward: 4050.767 ]
[ 720000 steps] [training time (ms) += 25031 ] [ep reward: 4099.980 ]
[ 725000 steps] [training time (ms) += 25077 ] [ep reward: 3991.280 ]
[ 730000 steps] [training time (ms) += 25092 ] [ep reward: 3816.944 ]
[ 735000 steps] [training time (ms) += 25085 ] [ep reward: 4006.426 ]

In the original paper, on HalfCheetah (not v4, but that shouldn't make a big difference), the reward goes above 6000 after ~600k steps, so it should not be stuck around 4000 as in my case. My implementation does not exactly match the one presented in the paper, so maybe I'll have to remove some of the differences. I'd like the results to be at least as good as the original ones.

I also need to write some documentation about setters...

Thanks, it works! I have also been trying to reimplement SDQN myself but can't match its reported performance. That's why I searched for it on GitHub, but there is very little SDQN code out there :(

Anyway, thanks again and I will close this issue.