araffin / sbx

SBX: Stable Baselines Jax (SB3 + Jax)

self.key is never updated

theovincent opened this issue

Thank you for your work on this cool repo! It is really useful for my research :)

🐛 Bug

Why is self.key always the same after each self._train call? More precisely, why is this part of the code written like this:

sbx/sbx/sac/sac.py

Lines 446 to 449 in fcd647e

update_carry["actor_state"],
update_carry["ent_coef_state"],
key,
(update_carry["info"]["actor_loss"], update_carry["info"]["qf_loss"], update_carry["info"]["ent_coef_loss"]),

and not like this

update_carry["actor_state"], 
update_carry["ent_coef_state"], 
update_carry["key"],  # Return the new updated key
(update_carry["info"]["actor_loss"], update_carry["info"]["qf_loss"], update_carry["info"]["ent_coef_loss"]), 

?
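To make the question concrete, here is a minimal JAX sketch of the pattern. It assumes, as the excerpt suggests, that the key passed into the update is stored in a carry dict and split inside a `jax.lax.fori_loop`; `train_buggy`, `train_fixed`, `_one_update` and `N_STEPS` are hypothetical stand-ins, not the actual sbx code. Returning the input `key` throws away the key that was advanced inside the loop, so the caller's key never changes; returning the carried key does advance it.

```python
import jax
import jax.numpy as jnp

N_STEPS = 4  # hypothetical number of gradient steps per call


def _one_update(i, carry):
    # Split the carried key: `key` is carried forward, `subkey` is consumed now.
    key, subkey = jax.random.split(carry["key"])
    noise = jax.random.normal(subkey, ())
    return {"key": key, "total": carry["total"] + noise}


@jax.jit
def train_buggy(key):
    carry = {"key": key, "total": jnp.zeros(())}
    carry = jax.lax.fori_loop(0, N_STEPS, _one_update, carry)
    # Bug pattern: return the key that came in, not the one carried out of the loop.
    return key, carry["total"]


@jax.jit
def train_fixed(key):
    carry = {"key": key, "total": jnp.zeros(())}
    carry = jax.lax.fori_loop(0, N_STEPS, _one_update, carry)
    # Proposed fix: return the updated key from the carry.
    return carry["key"], carry["total"]


key = jax.random.PRNGKey(0)
key_after_buggy, _ = train_buggy(key)  # identical to `key`
key_after_fixed, _ = train_fixed(key)  # advanced by N_STEPS splits
```

Storing the output of `train_buggy` back into `self.key` therefore feeds the same key into every call, which matches the repeated `self.key` values in the log below.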

To Reproduce

Edit the train method in sbx/sac/sac.py and add the following line right after the call to _train:

print("self.key", self.key)

Example: https://github.com/theovincent/sbx/blob/8327b98463c89b68f17ec0431d0cf3069cb7d7a7/sbx/sac/sac.py#L236

Create the following file, called train.py, at the top level of the project:

import gymnasium as gym

from sbx import SAC

env = gym.make("Pendulum-v1")

model = SAC("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=110, progress_bar=True)

Running python train.py in the terminal yields:

>>> python train.py
Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
self.key [2110677572 2465855137]
self.key [2110677572 2465855137]
self.key [2110677572 2465855137]
self.key [2110677572 2465855137]
self.key [2110677572 2465855137]
self.key [2110677572 2465855137]
self.key [2110677572 2465855137]
self.key [2110677572 2465855137]
self.key [2110677572 2465855137]
self.key [2110677572 2465855137]

Expected behavior

Changing the code, as suggested earlier, fixes the problem. Here are the logs when the change is implemented:

>>> python train.py
Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
self.key [3440514203 2996688322]
self.key [ 507603733 1743734701]
self.key [1106737823 3095002064]
self.key [ 372788615 2111558586]
self.key [1808065049 3808616220]
self.key [1837019053 2754803453]
self.key [1740029140 3438719296]
self.key [1088489055 1273990256]
self.key [3718340890 2050508589]
self.key [1872112782 1422931421]
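The difference between the two logs matters because a JAX PRNG key is a plain value: passing the same key into every `_train` call replays exactly the same random draws each time. The sketch below assumes a hypothetical `one_train_call` that splits the key once per gradient step and draws one noise sample per step; it is an illustration of JAX key semantics, not sbx code. With a static key the samples are identical from call to call, while storing the advanced key gives fresh samples. Whether this measurably hurts learning is a separate question that only training runs can answer.

```python
import jax
import jax.numpy as jnp


def one_train_call(key, n_steps=3):
    """Hypothetical stand-in for one _train call: split the key once per
    gradient step and draw one noise sample per step."""
    samples = []
    for _ in range(n_steps):
        key, subkey = jax.random.split(key)
        samples.append(jax.random.normal(subkey, ()))
    return key, jnp.stack(samples)


def run(n_calls, update_key):
    key = jax.random.PRNGKey(42)
    history = []
    for _ in range(n_calls):
        new_key, samples = one_train_call(key)
        if update_key:
            key = new_key  # store the advanced key, as in the proposed fix
        # otherwise `key` stays unchanged, mirroring the reported behavior
        history.append(samples)
    return history


static = run(3, update_key=False)   # same samples in every call
updated = run(3, update_key=True)   # different samples in each call
```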

System Info

  • Describe how the library was installed (pip, docker, source, ...)
    Fork the repo, clone it, create a Python virtual environment, and install the dependencies:
      python3 -m venv env
      source env/bin/activate
      pip install -e .
      pip install gymnasium[classic-control]
  • GPU models and configuration
    The GPU is not used
  • pip version
    23.2.1
>>> import stable_baselines3 as sb3
>>> sb3.get_system_info()
- OS: Linux-6.5.0-27-generic-x86_64-with-glibc2.35 # 28~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Fri Mar 15 10:51:06 UTC 2
- Python: 3.11.5
- Stable-Baselines3: 2.3.0
- PyTorch: 2.2.2+cu121
- GPU Enabled: True
- Numpy: 1.26.4
- Cloudpickle: 3.0.0
- Gymnasium: 0.29.1

Additional context

Before commit e564074, the key was updated each time the function self._train was called, as you can see here:

sbx/sbx/sac/sac.py

Lines 389 to 392 in 0f9163d

actor_state,
ent_coef_state,
key,
(actor_loss_value, qf_loss_value, ent_coef_value),

This bug seems to be present for:

  • CrossQ
  • SAC
  • TD3
  • TQC

Checklist

  • [X] I have checked that there is no similar issue in the repo (required)
  • [X] I have read the documentation (required)
  • [X] I have provided a minimal working example to reproduce the bug (required)

Hello,
thanks for reporting the bug =)
Could you take a quick look at whether it impacts performance too?

I guess it was introduced in #21

Hi @araffin,

Thank you for your quick answer. I have launched some runs here; they are still running: https://wandb.ai/theovincent/update_key/workspace?nw=nwusertheovincent

SAC_static_key: runs where the key is never updated [old version]
SAC_updated_key: runs where the key gets updated [new version]