CarperAI / trlx

A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF)

Attention mask when calculating log ratio for PPO

kmy17518 opened this issue · comments

Hi, I have a question about calculating the log ratio for PPO.
I'm very new to this area and would be really grateful if you could help me.

In accelerate_ppo_trainer.py, def make_experience, line 457
log_ratio = (logprobs - ref_logprobs) * attention_mask[:, :-1]

but according to the comment # NOTE: logprob[i] is (log)prob at which all_token[i+1] was sampled,
shouldn't it be attention_mask[:, 1:] ?
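To make the question concrete, here is a minimal sketch of the off-by-one I mean (the tensors and shapes are hypothetical, not trlx's actual values). Since logprobs[i] scores all_token[i+1], masking with attention_mask[:, :-1] masks each logprob by its input position, while attention_mask[:, 1:] masks it by the position of the token it actually scores:

```python
import torch

# Hypothetical example: 4 tokens, the last one is padding.
all_tokens = torch.tensor([[10, 11, 12, 0]])
attention_mask = torch.tensor([[1, 1, 1, 0]])

# logprobs has length len(all_tokens) - 1, and logprobs[:, i]
# scores all_tokens[:, i + 1].
logprobs = torch.randn(1, all_tokens.shape[1] - 1)

# Mask by the *input* position of each logprob (what the code does):
mask_by_input = attention_mask[:, :-1]   # tensor([[1, 1, 1]])
# Mask by the position of the token actually scored (what I'd expect):
mask_by_target = attention_mask[:, 1:]   # tensor([[1, 1, 0]])

# With padding only at the end, the two differ exactly at the last
# position: mask_by_input keeps the logprob that scores the pad token.
print(mask_by_input, mask_by_target)
```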

In accelerate_ppo_trainer.py, def loss, line 188:

logprobs, values_pred, mask = (
                logprobs[:, start:end],
                values_pred[:, start:end],
                attention_mask[:, start + 1 : end + 1],
            )

Here I think the attention mask is shifted the correct way, so why is it different in def make_experience?
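For reference, this is the slicing identity I have in mind (a hypothetical check, not trlx code): taking attention_mask[:, start + 1 : end + 1] for logprobs[:, start:end] is the same as first shifting the whole mask by one, then slicing it with the same [start:end] indices as the logprobs, i.e. it matches the attention_mask[:, 1:] convention:

```python
import torch

# Hypothetical mask; padding only at the end.
attention_mask = torch.tensor([[1, 1, 1, 0]])
start, end = 1, 3

# Shifting first aligns mask position i with logprobs position i,
# so slicing the shifted mask at [start:end] equals slicing the
# original mask at [start + 1 : end + 1].
shifted = attention_mask[:, 1:]
same = torch.equal(shifted[:, start:end],
                   attention_mask[:, start + 1 : end + 1])
print(same)
```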

Thanks for your help in advance!