araffin / sbx

SBX: Stable Baselines Jax (SB3 + Jax)

[Question] Extending sbx algorithms (e.g. via a callback)

asmith26 opened this issue

Hi there,

I'm trying to experiment with "RL while learning Minmax penalty" (paper, code), and I thought I'd try adding it to an sbx DroQ setup. From the paper, the implementation looks quite straightforward, essentially:

for each step:
    # update the learned penalty online; needs the critic's Q-value for the current state
    penalty = minmaxpenalty.update(reward, Q[state])
    # on unsafe transitions, replace the environment reward with the penalty
    if info["unsafe"]:
        reward = penalty

Hence I need to obtain the Q-value. I've been looking into the DroQ code, and I believe the Q-value is computed here (?)

next_target_quantiles = next_quantiles[:, :n_target_quantiles]
I've also been looking into implementing this via a Stable Baselines callback, but I can't seem to get it to work (I'm not sure if this is a suitable use case?).
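
For context, this is roughly the skeleton I was playing with (MinMaxPenaltyCallback and penalty_estimator are just placeholder names, and the self.locals keys are my assumption about what SB3 exposes during collection; it's not working code):

from stable_baselines3.common.callbacks import BaseCallback


class MinMaxPenaltyCallback(BaseCallback):
    """Rough skeleton: swap the reward for the learned penalty on unsafe steps."""

    def __init__(self, penalty_estimator, verbose: int = 0):
        super().__init__(verbose)
        self.penalty_estimator = penalty_estimator

    def _on_step(self) -> bool:
        # During collection, self.locals should expose the rollout variables.
        rewards = self.locals["rewards"]
        infos = self.locals["infos"]
        # Missing piece: I still need the critic's Q-value for the current state here
        # before I can call self.penalty_estimator.update(...) and overwrite the reward.
        return True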

Many thanks for any help, and for this fantastic lib! :)

Checklist

  • I have read the documentation (required)
  • I have checked that there is no similar issue in the repo (required)

Hence I need to obtain the Q-value. I've been looking into the DroQ code, and I believe the Q-value is computed here (?)

You should probably take a look at SAC first; this DroQ implementation is based on TQC (with quantiles), which makes getting the Q-value slightly more complicated (it's just a mean along the correct axis).
SAC in SBX can also be used in the DroQ configuration (you need to activate dropout + layer norm and add a policy delay with several gradient steps).
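
For the quantile case, here is a minimal sketch of what "a mean along the correct axis" means (the shapes and variable names are illustrative, not the exact sbx internals):

import jax
import jax.numpy as jnp

# Toy shapes: a batch of 4 states with 25 quantile estimates each (illustrative only).
next_quantiles = jnp.sort(jax.random.normal(jax.random.PRNGKey(0), (4, 25)), axis=-1)
n_target_quantiles = 20

# Drop the highest quantiles (the TQC truncation), then average over the quantile axis
# to recover one scalar Q-value estimate per state.
next_target_quantiles = next_quantiles[:, :n_target_quantiles]
q_values = jnp.mean(next_target_quantiles, axis=-1)  # shape: (4,)

And a sketch of SAC in the DroQ configuration, following the pattern from the sbx README (the keyword names come from there; the exact values are up to you):

from sbx import SAC

model = SAC(
    "MlpPolicy",
    "Pendulum-v1",
    policy_kwargs=dict(dropout_rate=0.01, layer_norm=True),  # dropout + layer norm on the critics
    gradient_steps=20,  # high update-to-data ratio, as in DroQ
    policy_delay=20,    # delayed policy updates
    learning_starts=0,
    verbose=1,
)
model.learn(total_timesteps=5_000)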

From the paper, the implementation looks quite straightforward, essentially:

is that during data collection only?
Otherwise, the easiest option would probably be to fork the repo.

Thanks for the tip about looking at SAC first; good idea.

is that during data collection only?

The paper says: "We propose a simple model-free algorithm for estimating this penalty online, which can be integrated into any RL pipeline that learns value functions."

Many thanks again for your help!