CarperAI / trlx

A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF)


Sanity check: SFT Model should be frozen (PPO)

Apsod opened this issue · comments


πŸ› Describe the bug

When training a Hydra model, the SFT model used for the reference distribution is not fully frozen, because the embedding layers are left trainable:

        v------------------  WEIGHT TIED  ---------------v
Tokens ~~> Embeddings ==> Hidden state ==> Output state ~~> Token Distribution (SFT)
                                      \
                                       ~~> Output state ~~> Token Distribution (Policy)
                                                       \
                                                        ~~> Value              (Value Head)
~~> Not frozen
==> Frozen
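
To see why the tied weights matter, here is an illustrative check on a plain Hugging Face GPT-2 checkpoint (used only as an example of the GPT-style layout trlX wraps; attribute names differ for other architectures): the output projection and the input embedding are the same parameter, so a gradient update to the embedding also moves the output head of the supposedly frozen SFT branch.

```python
from transformers import AutoModelForCausalLM

# Illustrative check on a GPT-2 checkpoint (assumed architecture).
model = AutoModelForCausalLM.from_pretrained("gpt2")

# The LM head is weight-tied to the token embedding: same Parameter object.
assert model.lm_head.weight is model.transformer.wte.weight

# Unless frozen explicitly, the embedding stays trainable, so policy updates
# also shift the output distribution of the "frozen" SFT branch.
print(model.transformer.wte.weight.requires_grad)  # True by default
```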

Reproduction

I made a fork with a metric function that uses the reference model to compute the sum of logits for a static dummy input: Apsod@a7e6fc6
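A minimal sketch of such a metric, assuming a Hugging Face-style reference model and tokenizer (the function name and dummy prompt are placeholders, not the exact code in the fork):

```python
import torch

@torch.no_grad()
def reference_logit_sum(ref_model, tokenizer, device="cuda"):
    # Sum of the reference model's logits for a fixed dummy input. If the
    # SFT/reference branch is truly frozen, this value should stay
    # (near-)constant across PPO steps, modulo dropout.
    dummy = tokenizer("a static dummy input", return_tensors="pt").to(device)
    return ref_model(**dummy).logits.sum().item()
```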

Since the SFT model should be frozen, this reference logit sum should be constant (or vary only slightly due to dropout), but this is not the case when running the examples/ppo_sentiments.py script. I then explicitly froze the embedding layers in freeze_bottom_causal_layers, which resulted in a constant reference logit sum: Apsod@7c91b45

[screenshot attached]
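
For reference, a minimal sketch of that fix, assuming a GPT-2-style module layout (transformer.h, transformer.wte, transformer.wpe); the actual freeze_bottom_causal_layers in trlX may differ in its details:

```python
import torch.nn as nn

def freeze_bottom_causal_layers(model: nn.Module, num_layers_unfrozen: int = 0):
    """Freeze the bottom transformer blocks *and* the (weight-tied) embeddings."""
    hidden_layers = list(model.transformer.h)
    layers_to_freeze = (
        hidden_layers if num_layers_unfrozen <= 0
        else hidden_layers[:-num_layers_unfrozen]
    )
    # Freezing wte/wpe is the crucial addition: because the LM head shares its
    # weight with wte, leaving the embeddings trainable lets PPO updates leak
    # into the reference (SFT) distribution.
    layers_to_freeze += [model.transformer.wte, model.transformer.wpe]
    for layer in layers_to_freeze:
        layer.requires_grad_(False)
```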

Which trlX version are you using?

Commit hash: 9c83cee

Additional system and package information

No response

@cat-state Have you gotten a chance to look at this yet?

Thanks, this should be fixed by: #543