araffin / sbx

SBX: Stable Baselines Jax (SB3 + Jax)

[Feature Request] Support Optax Optimizer Schedules

bradleypick opened this issue · comments

🚀 Feature

Support optax optimizer schedules (as argument for learning_rate).

Motivation

Learning rate scheduling can be essential to achieving good results when training agents.

Pitch

Stable Baselines 3 supports learning rate scheduling, and it appears that sbx could do the same by allowing users to pass optax optimizer schedules when specifying a model.
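
For reference, SB3 already accepts a callable for learning_rate, invoked with progress_remaining (which goes from 1.0 at the start of training to 0.0 at the end). A minimal example:

from stable_baselines3 import SAC

def linear_schedule(progress_remaining: float) -> float:
    # progress_remaining goes from 1.0 (start) to 0.0 (end of training)
    return 1e-3 * progress_remaining

model = SAC("MlpPolicy", "Pendulum-v1", learning_rate=linear_schedule)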

Alternatives

Unsure (None?)

Additional context

Below is an example of what happens when trying to pass an optax optimizer schedule during construction of TQC.

Running

import optax
import gymnasium as gym

from sbx import TQC, DroQ, SAC, TD3, DDPG

env = gym.make("Pendulum-v1")

lr_schedule = optax.piecewise_constant_schedule(1e-3, boundaries_and_scales={5000: 0.1})

model = TQC("MlpPolicy", env, learning_rate=lr_schedule, verbose=1)
model.learn(total_timesteps=10_000, progress_bar=True)

Yields the following assertion error:

Traceback (most recent call last):
  File "/mnt/sb3/sbx_reprex.py", line 10, in <module>
    model = TQC("MlpPolicy", env, learning_rate=lr_schedule, verbose=1)
  File "/opt/conda/lib/python3.10/site-packages/sbx/tqc/tqc.py", line 111, in __init__
    self._setup_model()
  File "/opt/conda/lib/python3.10/site-packages/sbx/tqc/tqc.py", line 125, in _setup_model
    assert isinstance(self.qf_learning_rate, float)
AssertionError

As far as I can tell, the same behaviour is present in TD3, SAC, DroQ, and DDPG.

Can the assertion that the qf_learning_rate is a float (here for TQC) be removed/relaxed to support the use of optax schedules?
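
If it helps, optax optimizers already accept a schedule (any callable mapping the optimizer step count to a learning rate) wherever a float is accepted, so a relaxed check might look roughly like this (just a sketch, not tested against the sbx internals):

import optax

def make_optimizer(learning_rate):
    # Hypothetical relaxed check: accept a float or any callable schedule,
    # since optax.adam takes either directly.
    assert isinstance(learning_rate, float) or callable(learning_rate)
    return optax.adam(learning_rate=learning_rate)

make_optimizer(1e-3)  # ok today
make_optimizer(optax.linear_schedule(1e-3, 1e-5, transition_steps=10_000))  # ok with the relaxed check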

Checklist

  • [✓] I have checked that there is no similar issue in the repo (required)

Hello,

Yields the following assertion error:

Yes, this is a known limitation of SBX vs SB3.

The current problem with optax schedules is that they depend on the number of optimizer steps, whereas SB3 schedules depend on total_timesteps.
I tried in the past but couldn't manage to dynamically change the learning rate of an optimizer once it was created.
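
To make the distinction concrete: an optax schedule is a function of the optimizer's internal step count and is fixed when the optimizer is created. A small sketch:

import optax

# A schedule maps the optimizer's step count to a learning rate;
# optax evaluates it on every optimizer update.
schedule = optax.piecewise_constant_schedule(1e-3, boundaries_and_scales={5000: 0.1})

print(schedule(0))     # ~1e-3: before the boundary
print(schedule(5000))  # ~1e-4: scaled by 0.1 from step 5000 onwards

# The schedule is baked in at construction time, which is why the
# learning rate cannot simply be swapped out afterwards:
optimizer = optax.adam(learning_rate=schedule)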

Alternatives

forking SBX...

Hi,

Thanks for your reply.

The current problem with optax schedules is that they depend on the number of optimizer steps, whereas SB3 schedules depend on total_timesteps.

Makes sense - I failed to notice that distinction.

Just curious: in the case where gradient_steps=1, update_freq=1, policy_delay=1, do we end up with equivalence between the count of optimizer steps and total_timesteps?

Am I missing something in thinking a schedule defined in terms of total_timesteps could (at least in theory) be transformed into a schedule defined in terms of optimizer steps using gradient_steps, update_freq, policy_delay?

Practically, this sounds like a pain given how many different types of schedules there are in optax (and how many of them have different signatures).
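
For concreteness, the kind of transformation I have in mind is something like this (a rough sketch; the helper name and the progress_remaining-style input are just for illustration):

import jax.numpy as jnp

# Hypothetical helper: turn an SB3-style schedule, defined on
# progress_remaining in [1.0, 0.0], into an optax step-count schedule.
def sb3_to_optax_schedule(sb3_schedule, total_timesteps, train_freq=1, gradient_steps=1):
    # gradient_steps optimizer updates happen every train_freq environment steps
    total_opt_steps = (total_timesteps // train_freq) * gradient_steps

    def schedule(count):
        progress_remaining = jnp.maximum(0.0, 1.0 - count / total_opt_steps)
        return sb3_schedule(progress_remaining)

    return schedule

# e.g. a linear decay from 1e-3 to 0, usable as optax.adam(learning_rate=lr)
lr = sb3_to_optax_schedule(lambda p: 1e-3 * p, total_timesteps=10_000)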

Alternatives

forking SBX...

Absolutely - I was mainly referring to alternative ways of scheduling the learning rate since I also failed to change it after the optimizer was created.
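
The closest mechanism I'm aware of in optax for changing the learning rate after creation is optax.inject_hyperparams, which exposes the learning rate as mutable optimizer state (a sketch; I haven't tried wiring this into sbx):

import jax.numpy as jnp
import optax

# inject_hyperparams stores the learning rate in the optimizer state,
# so it can be modified after the optimizer has been created.
optimizer = optax.inject_hyperparams(optax.adam)(learning_rate=1e-3)

params = {"w": jnp.zeros(3)}
opt_state = optimizer.init(params)

# ... later, e.g. between training phases:
opt_state.hyperparams["learning_rate"] = 1e-4

grads = {"w": jnp.ones(3)}
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)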

Just curious: in the case where gradient_steps=1, update_freq=1, policy_delay=1, do we end up with equivalence between the count of optimizer steps and total_timesteps?

It should be, as long as you don't call learn() multiple times.

Am I missing something in thinking a schedule defined in terms of total_timesteps could (at least in theory) be transformed into a schedule defined in terms of optimizer steps using gradient_steps, update_freq, policy_delay?

Another issue is that the optimizer is created before total_timesteps is known... (it's not a hard constraint, but this is how it is currently).
But yes, in theory, with train_freq, gradient_steps, and total_timesteps, you could potentially have something equivalent (policy_delay is trickier).

My use case may work with an optax schedule, but it doesn't sound like it'll work well in general.

Thank you for the replies and thank you for these implementations (sbx, sb3, and sb3-contrib) - they are much appreciated.