Troubleshooting: Policy Loading and Execution Issues
jc-bao opened this issue · comments
No consistency between learned policy and test policy
I saved the policy; during training the tracking error indicated it was about 7 cm.
But when I load it back, the controller cannot stabilize the drone anymore.
- Are you using the right file? A: yes
- Are you using the right function? A: yes
- Are you configuring the environment correctly? A: yes, the configuration is the same
- Do other network-based policies work? A: they work well
- Does the adaptation stage break the compressor? A: no, the compressor itself also does not work
system:
You are ChatGPT, a large language model trained by OpenAI.
Carefully heed the user's instructions.
Respond using Markdown.
user:
I have a policy trained in an RL environment. I saved it, but when I load it back it does not work. During training the policy does converge. I can rule out a saving/loading bug, since another policy with a different type of network does work after being loaded back.
In the following code, the `nn` controller works correctly while the `RMA` one fails. Both tested well inside the simulator. I am wondering what could go wrong in this case.
```python
elif args.controller == 'nn':
    from quadjax.train import ActorCritic

    network = ActorCritic(env.action_dim, activation='tanh')
    control_params = pickle.load(open(f"{quadjax.get_package_path()}/../results/rma/nn_policy.pkl", "rb"))

    def apply_fn(train_params, last_obs, env_info):
        return network.apply(train_params, last_obs)

    controller = controllers.NetworkController(apply_fn, env, control_params)
elif args.controller == 'RMA':
    from quadjax.train import ActorCritic, Compressor, Adaptor

    network = ActorCritic(env.action_dim, activation='tanh')
    compressor = Compressor()
    adaptor = Adaptor()
    control_params = pickle.load(open(f"{quadjax.get_package_path()}/../results/rma/rma_policy.pkl", "rb"))

    def apply_fn(train_params, last_obs, env_info):
        adapted_last_obs = adaptor.apply(train_params[2], env_info['obs_adapt'])
        obs = jnp.concatenate([last_obs, adapted_last_obs], axis=-1)
        pi, value = network.apply(train_params[0], obs)
        return pi, value

    controller = controllers.NetworkController(apply_fn, env, control_params)
```
assistant:
There are several factors that could be causing issues with your loaded policies.
Here are a few things to consider, along with some potential next steps:
- Input format: The saved policy may expect inputs in a different format from the one the environment presents at test time, causing incorrect outputs. Make sure the same environment states and observations used during training are also used during deployment/testing.
- Network structure: If the architecture used at evaluation/inference time differs from the one used during training, loading the policy will likely fail or misbehave. Verify that the architecture is identical when saving and loading the policy.
- Initialization: If your networks or adaptors/compressors are initialized or reset during your setup for loading, this can clobber the saved policy. Ensure no part of the policy network is re-initialized after loading the pre-trained weights.
- Extra layers such as `Adaptor` and `Compressor`: In the `RMA` case, extra information (the adapted observation) is passed into your network's `apply_fn`, in contrast to the `nn` case. You are concatenating the last observation and an adapted observation (`adapted_last_obs = adaptor.apply(...)`). If this exact procedure is not also carried out during training, that could be the source of the discrepancy.
- Code debugging: Try setting breakpoints or printing the weights/`control_params` after they are loaded, e.g. `print(control_params)` right after `control_params = pickle.load(...)`. Make sure they are not `None` and do not contain unexpected values.
- Differences in version or pickling method: Some packages such as numpy, PyTorch, and TensorFlow are not always backward compatible. The saved object might have been written with a different version of these third-party libraries than the one being used when loading.
Lastly, I would strongly recommend adding error checking or exception handling where possible. This would help track the problem faster if it relates to a specific step of the code.
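As a concrete version of the "print and inspect" suggestion, here is a minimal sketch of a checkpoint sanity check. It uses plain `pickle` and numpy in place of the real quadjax checkpoint, and the parameter shapes and the `tree_leaves` helper are hypothetical, but the same walk-and-compare pattern applies to a loaded `control_params`:

```python
import pickle
import numpy as np

def tree_leaves(tree, prefix=""):
    """Yield (path, array) pairs from a nested dict/list of arrays."""
    if isinstance(tree, dict):
        for k, v in tree.items():
            yield from tree_leaves(v, f"{prefix}/{k}")
    elif isinstance(tree, (list, tuple)):
        for i, v in enumerate(tree):
            yield from tree_leaves(v, f"{prefix}[{i}]")
    else:
        yield prefix, np.asarray(tree)

# Hypothetical params list mirroring [network_params, compressor_params, adaptor_params]
params = [
    {"dense": np.ones((4, 2))},
    {"dense": np.zeros((2, 2))},
    {"dense": np.full((2, 1), 0.5)},
]

blob = pickle.dumps(params)
loaded = pickle.loads(blob)

# Compare leaf-by-leaf: same paths, same values after the round trip
for (p1, a1), (p2, a2) in zip(tree_leaves(params), tree_leaves(loaded)):
    assert p1 == p2 and np.allclose(a1, a2), f"mismatch at {p1}"
    print(p1, a1.shape, float(np.abs(a1).mean()))
```

Printing a per-leaf norm or mean like this makes it immediately obvious when one branch of the parameter tree (e.g. the adaptor's) still holds its initial values.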
user:
After printing out the result, I can now see that the saved parameters for the `RMA` case are the original ones rather than the updated ones. Here is how I define the train state for RMA:
```python
ppo_train_state = TrainState.create(
    apply_fn=get_pi_value_adapt,  # for test
    params=[network_params, compressor_params, adaptor_params],
    tx=tx,
)
```

Here is the `nn` case:

```python
ppo_train_state = TrainState.create(
    apply_fn=get_pi_value,
    params=network_params,
    tx=tx,
)
```
Here is how I apply gradients to them:

```python
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
total_loss, grads = grad_fn(
    ppo_train_state.params, traj_batch, advantages, targets
)
ppo_train_state = ppo_train_state.apply_gradients(grads=grads)
```

Here is how I save them:

```python
with open(f"{quadjax.get_package_path()}/../results/ppo_params_(unknown).pkl", "wb") as f:
    pickle.dump(runner_state[0].params, f)
```
Why does saving work for `nn`, while for `RMA` it just saves the original parameters?
assistant:
The issue you're facing could be due to how the gradients for each element of the `params` list are handled in the `RMA` case.
In the `RMA` case, `params` is a list containing `[network_params, compressor_params, adaptor_params]`.
In your gradient computation and `apply_gradients` call, it is not clear whether gradients for all of these individual parameters (`network_params`, `compressor_params`, `adaptor_params`) are being correctly computed and applied. If they aren't, the initial parameters would not be updated during training, and saving them would store the original parameters, leading to the problem you're seeing.
So, ensure that gradients are being accurately computed for all of these parameters:

```python
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
total_loss, grads = grad_fn(
    ppo_train_state.params, traj_batch, advantages, targets
)
```

In this part of your code, make sure `grads` is a list of gradients, one for each set of parameters in your `params` list. This should happen automatically, since `jax.value_and_grad` differentiates with respect to the entire pytree passed as the first argument and returns gradients with the same structure.
Then, when calling `ppo_train_state.apply_gradients(grads=grads)`, make sure the gradients are applied to all of the individual parameters (`network_params`, `compressor_params`, `adaptor_params`) in the `params` list. Again, this should happen automatically, since `apply_gradients` updates the full parameter pytree with a matching pytree of gradients.
If the gradient calculation or application is not handled correctly for all parameters in your `params` list, training could be updating only `network_params` while leaving `compressor_params` and `adaptor_params` at their initial state. This would explain why you're observing only the original parameters for the `RMA` case when you save and print them out.
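The pytree behavior described above can be verified in isolation. This is a hedged sketch with a toy loss and hypothetical parameter shapes (not the quadjax code): `jax.value_and_grad` differentiates the whole first argument, so a list of parameter dicts yields a matching list of gradient dicts.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter list mirroring [network_params, compressor_params, adaptor_params]
params = [
    {"w": jnp.ones((3, 2))},
    {"w": jnp.full((2, 2), 0.3)},
    {"w": jnp.full((2, 1), 0.5)},
]

def loss_fn(params, x):
    # Toy loss that touches every parameter group in the list
    h = x @ params[0]["w"]
    h = h @ params[1]["w"]
    y = h @ params[2]["w"]
    return jnp.sum(y ** 2)

x = jnp.ones((1, 3))
loss, grads = jax.value_and_grad(loss_fn)(params, x)

# grads mirrors the structure of params: a list of dicts, one per group
assert jax.tree_util.tree_structure(grads) == jax.tree_util.tree_structure(params)
assert all(float(jnp.abs(g["w"]).sum()) > 0 for g in grads)
```

If a check like this passes but one parameter group still never changes during training, the problem is upstream of the gradient step, e.g. the state being rebuilt from initial parameters.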
user:
But the policy did indeed improve in this case. I think the network should be updated, right?
assistant:
If the policy is indeed improving over time, then it indicates that the parameters are being updated during training. The immediate issue could be related to the moment you are saving the policy parameters.
Consider these possibilities:
- Saving moment: It could be that you are saving your policy parameters after initialization but before the training loop runs, hence only seeing the initial parameters. Ensure you save the parameters post-training.
- Referencing: Python works with references, so when saving you might actually be saving a reference to the initial state object rather than the updated one. One way to debug this is to print the parameters immediately before saving them.
- Instance saved: Make sure you are saving the correct instance of `ppo_train_state`, the one that is actually being updated. If you have multiple instances of `TrainState`, ensure you are saving the one used in your training loop.
- Other parts getting updated: Your `RMA` case includes `compressor_params` and `adaptor_params` on top of the standard `network_params`. While `network_params` may be getting updated and improving the policy, the other two parameter sets might not be updated properly, giving the impression that the "original parameters" are always saved.
- Different class behavior: If your `TrainState` class or `apply_gradients` method behaves differently depending on the structure of the parameters (for example, for a list of parameters versus a single parameter tree), you might be seeing different behavior between your `nn` and `RMA` cases. This is worth checking.
- State tracking: In JAX, arrays are immutable and parameter updates must flow through pure functions that return the new state. If you mutate arrays in place, or change trainable parameters outside the traced computation, the changes won't be visible to `jax.grad` or the JAX tracer.
user:
Ah, I think the problem is resolved. I have two-stage training, but in the second stage I was re-creating the train state from the initial parameters instead of using the parameters from the first stage. After fixing that, the problem is resolved.
assistant:
Great! I'm glad you were able to resolve the issue; that's definitely a tricky one. Starting the second stage with the updated parameters from the first stage ensures the learning progress carries forward, and not re-initializing the state keeps the parameters updated as expected. In multi-stage training, the stages should run sequentially, with the progress from each stage carried into the next.
Thanks for updating the situation and good luck with the rest of your project! Should you encounter further issues or have any more questions, you're always welcome to ask.
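For reference, the resolved pitfall can be reduced to a minimal sketch. Plain Python dicts stand in for flax's `TrainState` here, and all names (`make_state`, `train`) are hypothetical; only the re-initialization bug itself is taken from the thread:

```python
# Minimal sketch of the two-stage pitfall: plain Python stands in for
# flax's TrainState; all names here are hypothetical.

def make_state(params):
    # Stand-in for TrainState.create(...)
    return {"params": dict(params)}

def train(state, steps):
    # Stand-in for a training loop that improves the parameters
    new_params = {k: v + steps for k, v in state["params"].items()}
    return {"params": new_params}

init_params = {"w": 0.0}

# Stage 1 trains from scratch
stage1 = train(make_state(init_params), steps=10)

# Buggy stage 2: state re-created from the *initial* params,
# so stage-1 progress is silently discarded
buggy_stage2 = train(make_state(init_params), steps=5)

# Fixed stage 2: state carries the stage-1 params forward
fixed_stage2 = train(make_state(stage1["params"]), steps=5)

print(buggy_stage2["params"]["w"])  # 5.0 — stage-1 progress lost
print(fixed_stage2["params"]["w"])  # 15.0 — progress carried forward
```

The buggy variant still "trains" and even improves within stage 2, which is exactly why the symptom looked like a saving problem rather than an initialization problem.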