jannerm / ddpo

Code for the paper "Training Diffusion Models with Reinforcement Learning"

Home Page: https://rl-diffusion.github.io

Question about finetune.py

chaojiewang94 opened this issue · comments

Is it true that the only difference between the method used in finetune.py and traditional training of a diffusion model is that the former multiplies batch-level weights (normalized rewards) into the per-sample reconstruction loss provided by the UNet in Stable Diffusion?

Is the following code in diffusion.py the key to the success of the REINFORCE version of your method?
```python
if weights is None:
    ## average over batch dimension
    loss = loss.mean()
else:
    ## multiply loss by weights
    assert loss.size == weights.size
    loss = (loss * weights).sum()
```

It looks like you might be looking at the RWR train step defined in training/diffusion.py. The DDPO update step is defined in training/policy_gradient.py:L63.
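
For intuition, here is a minimal PyTorch-style sketch of a REINFORCE policy-gradient update over the denoising steps. This is illustrative only, not the code in training/policy_gradient.py; the names `log_probs` and `advantages` and the tensor shapes are assumptions.

```python
# Minimal sketch of a REINFORCE-style policy-gradient update for a diffusion
# sampler, written in PyTorch for concreteness. Illustrative only; the names
# `log_probs` and `advantages` and their shapes are assumptions, not the
# repository's actual interface.
import torch

def policy_gradient_loss(log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
    """
    log_probs:  (batch, num_denoising_steps) log-probabilities of the sampled
                denoising transitions p_theta(x_{t-1} | x_t) under the current model.
    advantages: (batch,) normalized rewards r(x_0) for each final sample.
    """
    # Score-function estimator: weight each sampled trajectory's log-likelihood
    # by its reward-derived advantage, then average over the batch. The key
    # difference from the RWR-style weighted reconstruction loss is that the
    # log-probabilities here are those of the sampling transitions actually taken.
    per_sample = log_probs.sum(dim=1) * advantages
    return -per_sample.mean()

# Toy usage with random tensors standing in for real model outputs.
log_probs = torch.randn(4, 50, requires_grad=True)
advantages = torch.randn(4)
loss = policy_gradient_loss(log_probs, advantages)
loss.backward()
```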

Thanks, I am reproducing your work on a GPU and have almost finished. I have one more question about your work.

What is the difference between the policy-gradient version of DDPO and a GAN (maximizing a weighted likelihood)?

Assume we have a fixed discriminator r(x_0), which is also treated as a reward model.
The objective function of the policy-gradient version of DDPO and of the GAN should then be the same, formulated as
E_{x_0 \sim p_\theta(x_0)}[r(x_0)]; the only difference is how this function is estimated or optimized. Is that correct?
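
To spell out what I mean, here is one way to write the shared objective and the two gradient estimators. This is a sketch of the standard formulations, assuming x_0 denotes the final sample and \tau = (x_T, \dots, x_0) the denoising trajectory.

```latex
% Shared objective under a fixed reward / discriminator r(x_0):
J(\theta) = \mathbb{E}_{x_0 \sim p_\theta(x_0)}\left[ r(x_0) \right]

% Policy-gradient (REINFORCE) estimator over the denoising trajectory:
\nabla_\theta J(\theta)
  = \mathbb{E}_{\tau \sim p_\theta(\tau)}\!\left[ r(x_0)\, \nabla_\theta \log p_\theta(\tau) \right],
\qquad
\log p_\theta(\tau) = \sum_{t=1}^{T} \log p_\theta(x_{t-1} \mid x_t)

% GAN-style (pathwise) estimator, when sampling is reparameterized as
% x_0 = g_\theta(\epsilon) and r is differentiable:
\nabla_\theta J(\theta)
  = \mathbb{E}_{\epsilon}\!\left[ \nabla_\theta\, r\big(g_\theta(\epsilon)\big) \right]
```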

There is indeed a connection between policy gradient algorithms and GANs, though it requires a little more machinery. I recommend checking out @caseychu's great paper on Probability Functional Descent for the full story.