kvablack / ddpo-pytorch

DDPO for finetuning diffusion models, implemented in PyTorch with LoRA support


Prompt Alignment with LLaVA-server: Client-side prompt and image doesn't match server side reward

desaixie opened this issue · comments

I am running the prompt alignment experiment with LLaVA-server, although I am using BLIP2 instead of LLaVA.
I wanted to see the VLM's caption of the image alongside the prompt, image, and reward, so I added this additional logging to wandb. To pass the caption strings from the server back to the main training process, I converted the fixed-length strings into ASCII integers with ord(), so they could be converted to a torch.tensor before calling accelerator.gather at this line, and then back to strings with chr(). As shown in the image below, the prompts and the VLM captions that I received from the server do not match.
[screenshot: wandb log where prompts and VLM captions do not match]
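The ord()/chr() round trip described above can be sketched as follows. This is a minimal illustration, not the code from the issue; the helper names `encode_strings`/`decode_strings` and the fixed width of 128 characters are assumptions:

```python
import torch

def encode_strings(strings, max_len=128, pad=0):
    """Encode a batch of strings as a fixed-width integer tensor, since
    accelerator.gather only handles tensors, not Python strings."""
    rows = []
    for s in strings:
        codes = [ord(c) for c in s[:max_len]]
        codes += [pad] * (max_len - len(codes))  # right-pad to a fixed width
        rows.append(codes)
    return torch.tensor(rows, dtype=torch.int64)

def decode_strings(tensor, pad=0):
    """Invert encode_strings: drop the padding and map code points back to characters."""
    return ["".join(chr(c) for c in row if c != pad) for row in tensor.tolist()]
```

The resulting tensor has shape `(batch, max_len)`, so it can be gathered across processes like any other per-sample tensor and decoded back on process 0.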

Then I tried a trick to match the client-side input prompts with the server's responses. For each prompt generated with prompt_fn, I generate a random 5-digit id. This id is passed to the server and prepended to the VLM's outputs, and I then use the prompts' ids to retrieve the corresponding captions. As shown below, the prompts and the captions now match after using my "id" trick. I also appended the computed rewards to the captions on the server side before sending the response to the client. However, the rewards appended at the end of the captions do not match the rewards from the client side (code). It seems that the server's responses don't preserve the order of the queries it receives.
[screenshot: prompts and captions match after the id trick, but the appended rewards do not]
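The id trick can be sketched like this. The helper names `tag_prompts`/`match_responses` and the exact tagging format (`"<id> <text>"`) are hypothetical, chosen only to illustrate the idea of making responses order-independent:

```python
import random

def tag_prompts(prompts):
    """Attach a random 5-digit id to each prompt so the server's responses
    can be matched back to the queries that produced them."""
    ids = [f"{random.randint(0, 99999):05d}" for _ in prompts]
    return ids, [f"{i} {p}" for i, p in zip(ids, prompts)]

def match_responses(ids, responses):
    """Reorder responses (assumed to come back with the id prepended,
    possibly shuffled) so they align with the original query order."""
    by_id = {r.split(" ", 1)[0]: r.split(" ", 1)[1] for r in responses}
    return [by_id[i] for i in ids]
```

With this indirection, the client no longer relies on the server preserving query order, which is why the captions line up even if the responses come back shuffled.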

Could you verify whether the current code has this problem, where the order of the server's responses doesn't match that of the client's queries? I am seeing clear training progress, which shouldn't be the case if the rewards' order is messed up.
[screenshot: training progress curves]

Thanks for the detailed explanation! I don't think this has anything to do with the LLaVA server, but instead just sloppy coding on my part. When the wandb logging happens on this line, the prompts variable (as well as the images variable) comes from the last iteration of the for loop above. This corresponds to one sample batch from one process. However, the rewards variable contains all of the rewards gathered across all sample batches and processes. This is why the prompt and image always match, but the reward and reward metadata are totally wrong.

Do you think you could take a shot at fixing this? You don't actually need to gather the prompts or reward_metadata across processes. The wandb logging only happens from one process (process 0), so you can just log the images, prompts, rewards, and reward_metadata from a single sample batch from process 0.

Are you saying that the non-matching-order problem only affects the wandb logging, meaning that during training the advantages, log_probs, etc. still have the correct order? Then I guess the fix is to log the rewards of process 0 without calling accelerator.gather()?

Yes, this only affects the logging. I was already only logging the prompts and images from process 0, so the fix would be to also log the corresponding rewards and metadata from process 0. If you get it working, please feel free to open a pull request!
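The agreed fix can be sketched as follows. This is an illustrative sketch, not the code from the pull request; the function name `log_samples`, the variable names, and the wandb caption format are all assumptions:

```python
def log_samples(accelerator, wandb_run, images, prompts, rewards, reward_metadata):
    """Log one sample batch from process 0 only, without gathering.

    All four fields come from the same local batch, so they stay aligned.
    Gathering rewards across all processes while keeping the local prompts
    and images is what caused the mismatch described in this issue.
    """
    if accelerator.is_main_process:  # wandb logging happens on process 0 only
        wandb_run.log({
            "images": [
                wandb_run.Image(img, caption=f"{p} | {r:.2f} | {m}")
                for img, p, r, m in zip(images, prompts, rewards, reward_metadata)
            ]
        })
```

The key design point is that nothing here calls accelerator.gather: since only the main process logs, it can use its own ungathered batch, and prompt, image, reward, and metadata are guaranteed to refer to the same sample.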

Closed by #9