kvablack / ddpo-pytorch

DDPO for finetuning diffusion models, implemented in PyTorch with LoRA support

Support for other schedulers

desaixie opened this issue · comments

This code currently only supports DDIM. In the recently released SD-XL, the default scheduler is EulerDiscrete. From the paper and the code, it seems that prev_sample is no longer sampled from a Gaussian distribution but is instead the result of an ODE solver step (correct me if I am wrong here). How can the log_prob of prev_sample be calculated from the noise_pred in this case?
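For context on why DDIM makes this tractable: with eta > 0, each DDIM denoising step draws prev_sample from a Gaussian whose mean is computed from noise_pred and whose standard deviation depends only on the timestep, so the log-probability has a closed form. Below is a minimal sketch of that Gaussian log-density computation (the function name and the mean-reduction over pixel dimensions are illustrative, not the exact code in this repo):

```python
import math

import torch


def ddim_gaussian_log_prob(prev_sample: torch.Tensor,
                           prev_sample_mean: torch.Tensor,
                           std_dev_t: torch.Tensor) -> torch.Tensor:
    """Log-probability of `prev_sample` under N(prev_sample_mean, std_dev_t**2 * I).

    With DDIM and eta > 0, each denoising step samples prev_sample from this
    Gaussian: the mean is derived from noise_pred and the std is a fixed
    function of the timestep, so the log-prob is available in closed form.
    """
    var = std_dev_t ** 2
    log_prob = (
        -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * var)
        - torch.log(std_dev_t)
        - 0.5 * math.log(2 * math.pi)
    )
    # Reduce over all non-batch dimensions to get one scalar per sample.
    return log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
```

With a pure ODE sampler (eta = 0, or solvers like EulerDiscrete), prev_sample is a deterministic function of the current latent, so there is no such per-step density to evaluate.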

I'm not very familiar with other schedulers, which is why we stuck with DDIM. I'd imagine it's possible to compute a policy gradient with respect to ODE-based samplers such as EulerDiscrete, but it's going to require some in-depth knowledge of the math (which I frankly don't have). Sounds like a great direction for future work!