Sanity check: SFT Model should be frozen (PPO)
Apsod opened this issue
🐛 Describe the bug
When training a Hydra model, the SFT model used for the reference distribution is not fully frozen: because of weight tying, the embedding layers are left trainable:
```
           v------------------ WEIGHT TIED ---------------v
Tokens ~~> Embeddings ==> Hidden state ==> Output state ~~> Token Distribution (SFT)
                               \
                                ~~> Output state ~~> Token Distribution (Policy)
                                \
                                 ~~> Value (Value Head)

~~> : Not frozen
==> : Frozen
```
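To make the failure mode concrete, here is a minimal sketch in plain PyTorch (not trlX code) of the weight tying shown above: the output head and the token embedding share a single tensor, so any gradient update that reaches the embedding also moves the "frozen" reference head.

```python
import torch
import torch.nn as nn

# Minimal sketch (not trlX code): GPT-style weight tying between the
# token embedding and the output projection.
emb = nn.Embedding(10, 4)
head = nn.Linear(4, 10, bias=False)
head.weight = emb.weight  # tie: both modules now share one tensor

# Freezing only the transformer blocks is not enough: if the embedding
# stays trainable, the tied output head changes along with it.
print(head.weight.data_ptr() == emb.weight.data_ptr())  # prints True
```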
Reproduction
I made a fork with a metric function that uses the reference model to compute the sum of logits for a static dummy input: Apsod@a7e6fc6
Since the SFT model should be frozen, this reference logit sum should be constant (or vary only slightly due to dropout), but this is not the case when running the `examples/ppo_sentiments.py` script. I then explicitly froze the embedding layers in `freeze_bottom_causal_layers`, which resulted in a constant reference logit sum: Apsod@7c91b45
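The fix in the linked commit amounts to also disabling gradients on the embedding modules. A rough sketch of that idea (the helper name `freeze_embeddings` is mine for illustration, not trlX's actual `freeze_bottom_causal_layers`):

```python
import torch.nn as nn

def freeze_embeddings(model: nn.Module) -> None:
    """Hypothetical helper (illustration only): disable gradients for
    every embedding module, so the weight-tied output head of the
    frozen reference model can no longer drift during PPO updates."""
    for module in model.modules():
        if isinstance(module, nn.Embedding):
            module.requires_grad_(False)

# Tiny usage example on a stand-in model.
model = nn.Sequential(nn.Embedding(10, 4), nn.Linear(4, 10))
freeze_embeddings(model)
print(model[0].weight.requires_grad, model[1].weight.requires_grad)  # False True
```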
Which trlX version are you using?
Commit hash: 9c83cee
Additional system and package information
No response
@cat-state Have you gotten a chance to look at this yet?