Attention mask when calculating log ratio for PPO
kmy17518 opened this issue
Minyeong Kim commented
Hi, I have a question about calculating the log ratio for PPO.
I'm very new to this area and would be really grateful if you could help me.
In `accelerate_ppo_trainer.py`, `def make_experience`, line 457:

```python
log_ratio = (logprobs - ref_logprobs) * attention_mask[:, :-1]
```
But according to the comment `# NOTE: logprob[i] is (log)prob at which all_token[i+1] was sampled`, `logprobs[:, i]` lines up with `all_tokens[:, i+1]`, so shouldn't the mask be `attention_mask[:, 1:]`?
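To illustrate what I mean, here is a minimal sketch of the alignment I have in mind. The `tokens`, `attention_mask`, and `logits` tensors below are made-up stand-ins, not trlx's actual data, and the gather is just my reconstruction of how the per-token logprobs are obtained:

```python
import torch
import torch.nn.functional as F

batch, seq_len, vocab = 2, 5, 10
tokens = torch.randint(vocab, (batch, seq_len))  # stand-in for all_tokens, length T
attention_mask = torch.ones(batch, seq_len)      # 1 = real token, 0 = padding
logits = torch.randn(batch, seq_len, vocab)      # stand-in for model(tokens).logits

# logprobs[:, i] = log p(tokens[:, i + 1] | tokens[:, : i + 1]), shape (batch, T - 1):
# the logits at the last position predict a token beyond the sequence, so they are
# dropped, and the targets are the tokens shifted one to the left.
logprobs = torch.gather(
    F.log_softmax(logits[:, :-1, :], dim=-1),
    dim=2,
    index=tokens[:, 1:].unsqueeze(-1),
).squeeze(-1)

# Since logprobs[:, i] scores tokens[:, i + 1], the mask saying which of those
# *sampled* tokens are real seems to be attention_mask[:, 1:], also (batch, T - 1).
mask_for_sampled_tokens = attention_mask[:, 1:]
```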
In `accelerate_ppo_trainer.py`, `def loss`, line 188:

```python
logprobs, values_pred, mask = (
    logprobs[:, start:end],
    values_pred[:, start:end],
    attention_mask[:, start + 1 : end + 1],
)
```
Here I think the attention mask is shifted the correct way (one position to the right). So why is it shifted differently in `make_experience`?
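For concreteness, here is a small numeric check (with made-up masks, not trlx's data) showing that the two slicings disagree at exactly one boundary position per sequence:

```python
import torch

# Hypothetical masks, just to illustrate; 1 = real token, 0 = padding.
left_padded = torch.tensor([[0, 0, 1, 1, 1]])   # pads before the sequence
right_padded = torch.tensor([[1, 1, 1, 0, 0]])  # pads after the sequence

print(left_padded[:, :-1], left_padded[:, 1:])
# tensor([[0, 0, 1, 1]]) tensor([[0, 1, 1, 1]])
# -> [:, 1:] also counts the first real token's logprob; [:, :-1] drops it.

print(right_padded[:, :-1], right_padded[:, 1:])
# tensor([[1, 1, 1, 0]]) tensor([[1, 1, 0, 0]])
# -> [:, :-1] also counts the logprob at the first pad position; [:, 1:] drops it.
```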
Thanks in advance for your help!