LeCAR-Lab / CoVO-MPC

Official implementation for the paper "CoVO-MPC: Theoretical Analysis of Sampling-based MPC and Optimal Covariance Design" accepted by L4DC 2024. CoVO-MPC is an optimal sampling-based MPC algorithm.

Home Page: https://lecar-lab.github.io/CoVO-MPC/


Troubleshooting: Policy Loading and Execution Issues

jc-bao opened this issue

No consistency between learned policy and test policy

I saved the policy; during training, the tracking error was about 7 cm.

But when I load it back, the controller can no longer stabilize the drone.

  • Are you using the right file? A: Yes.
  • Are you using the right function? A: Yes.
  • Are you configuring the environment correctly? A: It is configured the same way.
  • Does another network-based policy work? A: Yes, it works well.
  • Does the adaptation stage break the compressor? A: No, the compressor on its own does not work either.

system:

You are ChatGPT, a large language model trained by OpenAI.
Carefully heed the user's instructions.
Respond using Markdown.


user:

I have a policy trained in an RL environment. I saved it, but when I load it back, it does not work. During training the policy does converge. I can rule out incorrect saving/loading, because another policy with a different network type works after being loaded back.

In the following code, the nn controller works correctly while the RMA one fails. Both performed well inside the simulator during training. What could be going wrong in this case?

    elif args.controller == 'nn':
        from quadjax.train import ActorCritic
        network = ActorCritic(env.action_dim, activation='tanh')
        # Close the file handle after loading
        with open(f"{quadjax.get_package_path()}/../results/rma/nn_policy.pkl", "rb") as f:
            control_params = pickle.load(f)
        def apply_fn(train_params, last_obs, env_info):
            return network.apply(train_params, last_obs)
        controller = controllers.NetworkController(apply_fn, env, control_params)
    elif args.controller == 'RMA':
        from quadjax.train import ActorCritic, Compressor, Adaptor
        network = ActorCritic(env.action_dim, activation='tanh')
        compressor = Compressor()
        adaptor = Adaptor()
        # control_params is the list [network_params, compressor_params, adaptor_params]
        with open(f"{quadjax.get_package_path()}/../results/rma/rma_policy.pkl", "rb") as f:
            control_params = pickle.load(f)
        def apply_fn(train_params, last_obs, env_info):
            # Map the adaptation observations through the adaptor, then append the result to the policy input
            adapted_last_obs = adaptor.apply(train_params[2], env_info['obs_adapt'])
            obs = jnp.concatenate([last_obs, adapted_last_obs], axis=-1)
            pi, value = network.apply(train_params[0], obs)
            return pi, value
        controller = controllers.NetworkController(apply_fn, env, control_params)

assistant:

There are several factors that could be causing issues with your loaded policies.
Here are a few things to consider, along with some potential next steps:

  1. Input format: The saved policy may be expecting inputs in a format different from how the simulator/environment is presenting them, hence failing to give correct outputs. Make sure the same environment states and observations that were used during the training are also used during deployment/testing.

  2. Network structure: If the network definition used at evaluation time differs from the one used during training (layer sizes, activation, input dimension), the saved parameters no longer match it. network.apply may then raise a shape error or, if the shapes happen to line up, silently produce wrong outputs. Verify that the architecture is identical when saving and loading the policy.

  3. Initialization: If your networks or adaptors/compressors are initialized or reset during your setup for loading, this might interfere with the saved policy. Ensure no part of your policy network is re-initialized after loading the pre-trained weights.

  4. Extra layers such as Adaptor and Compressor: Unlike the 'nn' case, the RMA apply_fn passes extra information (the adapted observation) into the network. You concatenate the last observation with the adapted observation (adapted_last_obs = adaptor.apply(...)). If training did not use exactly the same procedure (same adaptation inputs, same concatenation order), that could be the source of the discrepancy.

  5. Code debugging: Set breakpoints or print the loaded control_params right after pickle.load(...). Check that they are not None, that they have the expected structure, and, most importantly, that they actually differ from freshly initialized parameters.

  6. Library versions and pickling: Pickle files are not guaranteed to load correctly across different versions of libraries such as NumPy or JAX. If the parameters were saved with a different library version than the one used to load them, the loaded object may not behave as expected.
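Point 5 can be made concrete by comparing the loaded parameters against both the trained values and a fresh initialization. The sketch below is an illustration under stated assumptions: it uses toy arrays in the [network, compressor, adaptor] list layout of the RMA case, an in-memory buffer in place of the .pkl file, and a hypothetical helper max_abs_diff. A loaded tree whose difference from the initialization is zero is exactly the symptom reported in this thread.

```python
import io
import pickle

import jax
import jax.numpy as jnp


def max_abs_diff(tree_a, tree_b):
    """Largest absolute elementwise difference between two pytrees with the same structure."""
    diffs = jax.tree_util.tree_map(lambda a, b: jnp.max(jnp.abs(a - b)), tree_a, tree_b)
    return max(float(d) for d in jax.tree_util.tree_leaves(diffs))


# Hypothetical toy parameters in the RMA layout: [network, compressor, adaptor].
initial = [{"w": jnp.zeros(3)}, {"w": jnp.zeros(2)}, {"w": jnp.zeros(2)}]
trained = [{"w": jnp.ones(3)}, {"w": jnp.ones(2)}, {"w": jnp.ones(2)}]

buffer = io.BytesIO()  # stands in for the .pkl file on disk
pickle.dump(jax.device_get(trained), buffer)  # device_get converts to host NumPy arrays
buffer.seek(0)
loaded = pickle.load(buffer)

print(max_abs_diff(loaded, trained))  # 0.0: the round trip preserved the trained values
print(max_abs_diff(loaded, initial))  # 1.0: the loaded params differ from the initialization
```

If the second number were 0.0 for real checkpoints, the file on disk holds the initialization, which points to the training or saving code rather than to the loading code.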

Lastly, I would strongly recommend adding error checking or exception handling where possible. This would help track the problem faster if it relates to a specific step of the code.


user:

After printing the loaded parameters, I can see that the saved parameters for the RMA case are the original ones rather than the updated ones. Here is how I define train_state for RMA:

    ppo_train_state = TrainState.create(
        apply_fn=get_pi_value_adapt,  # for test
        params=[network_params, compressor_params, adaptor_params],
        tx=tx,
    )

Here is the NN case:

    ppo_train_state = TrainState.create(
        apply_fn=get_pi_value,
        params=network_params,
        tx=tx,
    )

Here is how I apply gradients to them:

    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
    total_loss, grads = grad_fn(
        ppo_train_state.params, traj_batch, advantages, targets
    )
    ppo_train_state = ppo_train_state.apply_gradients(grads=grads)

Here is how I save them:

    with open(f"{quadjax.get_package_path()}/../results/ppo_params_{filename}.pkl", "wb") as f:
        pickle.dump(runner_state[0].params, f)

Why are the parameters saved correctly in the nn case, while in the RMA case only the original parameters are saved?


assistant:

The issue you're facing could be due to how you're handling the gradients for each element in your params parameter list in the RMA case.

In the RMA case, params is a list that contains: [network_params, compressor_params, adaptor_params].

In your gradient computation and apply_gradients call, it is not clear whether gradients for all these individual parameters (network_params, compressor_params, adaptor_params) are being correctly computed and applied. If they aren't, then the initial parameters wouldn't be updated through the training, and saving them would result in saving the original parameters, leading to the problem you're seeing.

So, ensure that gradients are being accurately computed for all these parameters:

    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
    total_loss, grads = grad_fn(ppo_train_state.params, traj_batch, advantages, targets)

In this part of your code, grads will have the same structure as params: jax.value_and_grad differentiates with respect to the first argument of _loss_fn, which may be any pytree, including a list of parameter dictionaries. The catch is that any entry _loss_fn does not actually use (for example, compressor_params or adaptor_params in a stage that only trains the policy network) receives an all-zero gradient.

Then, when calling ppo_train_state.apply_gradients(grads=grads), the gradients need to be applied to every entry of the params list. Assuming TrainState is Flax's flax.training.train_state.TrainState, this is handled automatically: apply_gradients passes grads and params to the optax optimizer tx, which also operates on arbitrary pytrees, and returns a new TrainState with the updated parameters rather than modifying the existing one.

If the gradient calculations or applications are not being handled correctly for all parameters in your params list, it could just be updating the network_params and leaving the compressor_params and adaptor_params at their initial state. This would explain why you're observing only the original parameters for the RMA case when you save and print them out.
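To illustrate the points above, here is a minimal sketch with toy parameters standing in for network_params, compressor_params, and adaptor_params. The shapes, the {"w": ...} layout, and loss_fn are hypothetical and not taken from the repository. It shows that grads mirrors the list structure of params, and that a parameter set the loss never touches receives an all-zero gradient, so plain gradient descent leaves it at its initial value.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for [network_params, compressor_params, adaptor_params].
params = [
    {"w": jnp.ones(3)},
    {"w": jnp.full(3, 2.0)},
    {"w": jnp.full(3, 0.5)},
]


def loss_fn(params, x):
    network_params, compressor_params, adaptor_params = params
    # Toy loss that uses the network and adaptor entries but never the compressor.
    latent = jnp.dot(adaptor_params["w"], x)
    out = jnp.dot(network_params["w"], x) * latent
    return jnp.sum(out ** 2), out  # (loss, aux), matching has_aux=True


grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
# With has_aux=True the first return value is the (loss, aux) pair.
(total_loss, aux), grads = grad_fn(params, jnp.arange(3.0))

# grads is a list with the same nesting as params.
print(jax.tree_util.tree_structure(grads) == jax.tree_util.tree_structure(params))
# The unused compressor entry gets an all-zero gradient.
print(grads[1]["w"])
```

The same reasoning applies to the real _loss_fn: inspecting which entries of grads are non-zero tells you which parameter sets a given training stage actually updates.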


user:

But the policy did improve in this case. Doesn't that mean the network was updated?


assistant:

If the policy is indeed improving over time, then it indicates that the parameters are being updated during training. The immediate issue could be related to the moment you are saving the policy parameters.

Consider these possibilities:

  1. Saving moment: It could be that you are saving your policy parameters after initialization but before any training loop, hence only seeing the initial parameters. Ensure you are saving the parameters post-training.

  2. Referencing: JAX arrays are immutable, and apply_gradients returns a new state instead of modifying the existing one in memory. If the variable you save still refers to a state created before training (for example, because the updated state returned by the training loop was never assigned back), you will save the initial parameters. Printing the parameters immediately before saving them is a quick way to check.

  3. Instance saved: Make sure you are saving the correct instance of the ppo_train_state which is being updated. If you have multiple instances of TrainState, ensure you are saving the one which is being used in your training.

  4. Other parts getting updated: From your code, your 'RMA' case includes both compressor_params and adaptor_params on top of the standard network_params. While the network_params might be getting updated and improving the policy, the other two parameters sets might not be getting updated properly, leading to the impression that the "original parameters" are saved always.

  5. Different class behavior: If your TrainState class or apply_gradients method makes distinctions based on the structure of the parameters (for example, behaving differently for a list of parameters versus a single parameter), you might be encountering different behaviors between your nn and RMA cases. This is worth checking.

  6. State tracking: JAX does not support in-place updates; arrays are immutable, and functions such as apply_gradients return updated values. Every update must be carried forward explicitly, for example by reassigning ppo_train_state = ppo_train_state.apply_gradients(grads=grads) or by returning the new state from the function used in a loop or jax.lax.scan. Any state that is rebuilt from scratch or not threaded through will silently fall back to its earlier values.
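Points 2 and 6 come down to the functional-update pattern, which a tiny stand-in for TrainState can show. SimpleState and its plain SGD step are hypothetical and are not Flax's implementation; the only part that carries over is the rule that the state returned by apply_gradients must be kept.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class SimpleState(NamedTuple):
    """Hypothetical minimal stand-in for Flax's TrainState (plain SGD, no optax)."""
    params: list
    lr: float

    def apply_gradients(self, grads):
        # Build new parameter arrays and return a NEW state; self is never modified.
        new_params = jax.tree_util.tree_map(lambda p, g: p - self.lr * g, self.params, grads)
        return self._replace(params=new_params)


def loss_fn(params):
    return sum(jnp.sum(p ** 2) for p in params)


state = SimpleState(params=[jnp.ones(2), jnp.ones(2)], lr=0.1)
grads = jax.grad(loss_fn)(state.params)

state.apply_gradients(grads)  # result discarded: state still holds the initial params
unchanged = state.params[0]

state = state.apply_gradients(grads)  # correct: rebind the name to the returned state
print(unchanged, state.params[0])
```

The discarded call raises no error, which is why this kind of bug tends to surface only later, for example as an unexpectedly untrained checkpoint.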


user:

Ah, I think the problem is resolved. I have two-stage training. In the second stage I should have used the parameters from the first stage, but I was creating the state again from the initial parameters. After fixing that, the problem is resolved.


assistant:

Great! I'm glad you were able to resolve the issue; that's a tricky one. Starting the second stage from the parameters produced by the first stage carries the learning progress forward, whereas re-creating the state from the initial parameters silently discards it. In multi-stage training like this, each stage should start from the parameters produced by the previous stage rather than from a fresh initialization.
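The bug and its fix can be reproduced with a toy two-stage pipeline. Everything here is hypothetical: init_params, train_stage, the {"network", "adaptor"} layout, and the quadratic loss only illustrate the control flow and are not the quadjax training code. Stage 1 trains the network; stage 2 trains the adaptor, either from the initial parameters (the bug) or from the stage-1 result (the fix).

```python
import jax
import jax.numpy as jnp

TARGET = jnp.arange(4.0)  # toy optimum for every parameter vector


def init_params(key):
    k_net, k_adapt = jax.random.split(key)
    return {"network": jax.random.normal(k_net, (4,)), "adaptor": jax.random.normal(k_adapt, (4,))}


def train_stage(params, trained_keys, steps=100, lr=0.1):
    """Hypothetical stage: gradient descent on a toy loss over `trained_keys` only."""
    def loss_fn(p):
        return sum(jnp.sum((p[k] - TARGET) ** 2) for k in trained_keys)

    grad_fn = jax.jit(jax.grad(loss_fn))
    for _ in range(steps):
        grads = grad_fn(params)  # zero for keys outside trained_keys
        params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params


init = init_params(jax.random.PRNGKey(0))
stage1 = train_stage(init, ["network"])  # stage 1: train the policy network

# Bug from the thread: stage 2 starts again from the initial parameters.
buggy = train_stage(init, ["adaptor"])
# Fix: stage 2 starts from the stage-1 result, so that progress is kept.
fixed = train_stage(stage1, ["adaptor"])
```

In the buggy run the network entry comes out identical to its initialization even though stage 1 converged, mirroring the "original parameters" seen in the saved RMA checkpoint.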

Thanks for updating the situation and good luck with the rest of your project! Should you encounter further issues or have any more questions, you're always welcome to ask.