jannerm / ddpo

Code for the paper "Training Diffusion Models with Reinforcement Learning"

Home Page: https://rl-diffusion.github.io


Log prob Computation

anschen1994 opened this issue

At ddpo/diffusers_patch/scheduling_ddim_flax.py:359, jnp.mean is used to compute the log prob over the whole latent space.
Wouldn't it be more proper to use jnp.sum here?
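For context, the computation I'm referring to looks roughly like this (a sketch from memory, not the exact code in the repo, and the names are mine; I'm assuming a diagonal Gaussian over the latent):

```python
import jax.numpy as jnp

def ddim_step_log_prob(prev_sample, prev_sample_mean, std_dev):
    # Per-dimension Gaussian log density of the sampled latent under the
    # DDIM posterior, shape (batch, 64, 64, 4).
    log_prob = (
        -((prev_sample - prev_sample_mean) ** 2) / (2 * std_dev**2)
        - jnp.log(std_dev)
        - 0.5 * jnp.log(2 * jnp.pi)
    )
    # Current implementation: average over the latent dimensions.
    return jnp.mean(log_prob, axis=(1, 2, 3))
    # My suggestion: the joint log prob, i.e. the sum over dimensions.
    # return jnp.sum(log_prob, axis=(1, 2, 3))
```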

Since jnp.mean means we compute the ratio as
$r = \left(\prod_{i=1}^N \frac{p_i}{q_i}\right)^{\frac{1}{N}},$
where $N$ is the total dimension of the latent space, $p_i$ is the new per-dimension probability, and $q_i$ is the old one.
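To make that concrete, here is a tiny numerical check (made-up log probs and a small $N$, purely for illustration) showing that the mean reduction yields the geometric mean of the per-dimension ratios, while the sum reduction yields their full product:

```python
import jax
import jax.numpy as jnp

N = 16  # stand-in for the real 64 * 64 * 4 = 16,384 latent dimensions
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
log_p = 0.01 * jax.random.normal(k1, (N,))  # new per-dimension log probs
log_q = 0.01 * jax.random.normal(k2, (N,))  # old per-dimension log probs

# Ratio under the mean reduction: exp(mean(log p) - mean(log q))
r_mean = jnp.exp(jnp.mean(log_p) - jnp.mean(log_q))
# Geometric mean of the per-dimension ratios: (prod_i p_i / q_i)^(1/N)
r_geom = jnp.prod(jnp.exp(log_p - log_q)) ** (1.0 / N)
# Ratio under the sum reduction: the full product of per-dimension ratios
r_sum = jnp.exp(jnp.sum(log_p) - jnp.sum(log_q))

print(r_mean, r_geom)      # identical up to float error
print(r_sum, r_mean ** N)  # also identical: r_sum = r_mean ** N
```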

And in the default config, the clip range is set to 1e-4, which means the true clip range is $10^{-4N}$, which is extremely small compared to other reinforcement learning applications.

Excellent question! I actually implemented it as jnp.sum first, and immediately got NaNs after the first training step. I quickly realized this is due to the unusually high dimensionality of our action space. Typical applications of PPO (e.g. for control) have $N < 100$, whereas our actions are entire latent images, so $N = 64 \times 64 \times 4 = 16,384$. That means if our current jnp.mean implementation produces a ratio of just $r = 1.01$ for a given action, switching to jnp.sum would produce a ratio of $1.01^{16,384} \approx 6.33 \times 10^{70}$ (way above the maximum value for float32). That means our "true" clip range is actually much larger. If currently we clip to the ratio $(0.9999, 1.0001)$, we're clipping the "true" ratio to $(0.9999^{16,384}, 1.0001^{16,384}) \approx (0.19, 5.15)$.
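
To put concrete numbers on that (a quick sanity check, not code from the repo):

```python
import jax.numpy as jnp

N = 64 * 64 * 4  # 16,384 latent dimensions per "action"

# Under the sum reduction, a per-pixel ratio of 1.01 gets raised to the full
# latent dimensionality, which overflows float32 (JAX's default dtype):
print(jnp.asarray(1.01) ** N)    # inf (float32 max is ~3.4e38)
print(1.01 ** N)                 # ~6.33e70 in Python's float64

# Effective clip range on the "true" (joint) ratio implied by clipping the
# mean-reduced ratio to (1 - 1e-4, 1 + 1e-4):
print(0.9999 ** N, 1.0001 ** N)  # ~(0.19, 5.15)
```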

However, I think this is about more than numerical stability. Even if our numerical format could support arbitrarily large (or small) density ratios, that really doesn't seem like what we would want. It seems to make sense to clip based on the "average" change per pixel rather than the total change over the whole image, though unfortunately I can't give a much better intuitive or theoretical justification than that. It might also have to do with the fact that we model the policy as an isotropic Gaussian, where every pixel is independent, whereas in reality policy updates are very highly correlated across pixels. FWIW this has probably been studied by somebody in the past; I'm just not familiar with the relevant literature.

Clear explanation~ Thanks!