tcbegley / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.

[nanoChatGPT] How to represent reward model

tcbegley opened this issue

The reward model is trained on proposed answers to a prompt, which come in pairs: one marked as chosen, the other as rejected. The reward model should output a high score for the chosen answer and a low score for the rejected one.
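For reference, the pairwise objective typically used to train such reward models (for example in the InstructGPT paper) is the negative log-sigmoid of the score gap:

$$\mathcal{L}(\theta) = -\log \sigma\left(r_\theta(x, y_\text{chosen}) - r_\theta(x, y_\text{rejected})\right)$$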

It seems tricky to come up with a clean programming pattern for this using tensorclasses. Ideally we would represent the data with a tensorclass and use TensorDictModule to perform a single forward pass on it.

We have a tensorclass roughly of the form

import torch
from tensordict import tensorclass

@tensorclass
class Data:
    prompt: torch.Tensor    # tokenised prompt
    chosen: torch.Tensor    # tokenised answer marked as chosen
    rejected: torch.Tensor  # tokenised answer marked as rejected
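A batch of this class can then be constructed directly; the shapes and the batch_size argument below are purely illustrative (token indices of shape [batch, seq_len]):

batch = Data(
    prompt=torch.randint(0, 50257, (8, 128)),    # vocab size is illustrative
    chosen=torch.randint(0, 50257, (8, 128)),
    rejected=torch.randint(0, 50257, (8, 128)),
    batch_size=[8],
)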

We need to do two forward passes, subtract the results and backpropagate. So we end up doing something roughly like this

chosen_loss = model(batch.prompt, batch.chosen)      # score for the chosen answer
rejected_loss = model(batch.prompt, batch.rejected)  # score for the rejected answer
loss = -torch.sigmoid(chosen_loss - rejected_loss)
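Concretely, a full training step under this pattern might look roughly as follows (a sketch only: model is assumed to return a per-example score and opt is an ordinary torch.optim optimizer):

chosen_score = model(batch.prompt, batch.chosen)
rejected_score = model(batch.prompt, batch.rejected)
loss = -torch.sigmoid(chosen_score - rejected_score).mean()  # reduce to a scalar before backprop
opt.zero_grad()
loss.backward()
opt.step()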

which doesn't make use of TensorDictModule. One possibility would be to do something like

from tensordict.nn import TensorDictModule

chosen_model = TensorDictModule(model, in_keys=["prompt", "chosen"], out_keys=["chosen_loss"])
rejected_model = TensorDictModule(model, in_keys=["prompt", "rejected"], out_keys=["rejected_loss"])
chosen_model(batch)
rejected_model(batch)
loss = -torch.sigmoid(batch.chosen_loss - batch.rejected_loss)

We could even combine these into a single call with TensorDictSequential, as sketched below. The only problem is that this feels more complicated and harder to follow.
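For example, a rough sketch reusing the two modules defined above:

from tensordict.nn import TensorDictSequential

# run both forward passes with a single call; each sub-module writes its
# output back into the batch
reward_model = TensorDictSequential(chosen_model, rejected_model)
reward_model(batch)
loss = -torch.sigmoid(batch.chosen_loss - batch.rejected_loss)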

Alternatively, we could combine the forward passes over the chosen and rejected examples into a single forward pass by adding a flag that indicates the sign to be used for each example when aggregating the scores, but that too becomes more complex and harder to follow.
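To illustrate, one way that could be wired up (again just a sketch; it assumes model scores a batch of (prompt, response) pairs and returns a score per row):

# stack chosen and rejected responses along the batch dimension with a
# +1/-1 sign flag, score everything in one forward pass, then aggregate
# each pair using the flag
prompts = torch.cat([batch.prompt, batch.prompt], dim=0)
responses = torch.cat([batch.chosen, batch.rejected], dim=0)
signs = torch.cat(
    [torch.ones(batch.chosen.shape[0]), -torch.ones(batch.rejected.shape[0])]
)

scores = model(prompts, responses)                      # shape [2 * B]
pair_diff = (signs * scores).reshape(2, -1).sum(dim=0)  # chosen score minus rejected score
loss = -torch.sigmoid(pair_diff).mean()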