Log prob computation
anschen1994 opened this issue
At ddpo/diffusers_patch/scheduling_ddim_flax.py:359, jnp.mean is used to reduce the per-element log probs to a single log prob for the whole latent space. Wouldn't it be more appropriate to use jnp.sum here?
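Using jnp.mean means we compute the ratio as

$$r = \exp\left(\frac{1}{d}\sum_{i=1}^{d}\left[\log\pi_\theta(a_i \mid s) - \log\pi_{\theta_{\text{old}}}(a_i \mid s)\right]\right) = \left(\prod_{i=1}^{d}\frac{\pi_\theta(a_i \mid s)}{\pi_{\theta_{\text{old}}}(a_i \mid s)}\right)^{1/d},$$

where $d$ is the number of elements in the latent and $a_i$ is its $i$-th element.

In the default config, the clip range is set to 1e-4, i.e. $r$ is clipped to $[1-\epsilon, 1+\epsilon]$ with $\epsilon = 10^{-4}$. That means the true clip range on the full-latent ratio $\prod_{i=1}^{d}\pi_\theta(a_i \mid s)/\pi_{\theta_{\text{old}}}(a_i \mid s)$ is $\left[(1-\epsilon)^d, (1+\epsilon)^d\right]$.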
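A minimal JAX sketch of the two reductions, assuming a single latent of shape (4, 64, 64), so d = 16384 latent elements (this shape is an illustrative assumption, not read from the DDPO config), and a uniform per-element log-ratio of 1e-3 chosen only for illustration. The helper `ppo_ratio` is hypothetical. It shows that the sum-based ratio is the mean-based ratio raised to the power d, and computes the full-latent clip range implied by clipping the mean-based ratio with epsilon = 1e-4.

```python
import jax.numpy as jnp

# Assumed latent shape (4 channels x 64 x 64); not taken from the DDPO config.
LATENT_SHAPE = (4, 64, 64)
d = 4 * 64 * 64  # number of latent elements

def ppo_ratio(log_prob_new, log_prob_old, reduce):
    # Hypothetical helper: per-element log-ratio, collapsed to one number by
    # `reduce` (jnp.mean or jnp.sum), then exponentiated into an importance ratio.
    return jnp.exp(reduce(log_prob_new - log_prob_old))

# Illustrative update: every latent element gets the same log-ratio of 1e-3.
log_prob_old = jnp.zeros(LATENT_SHAPE)
log_prob_new = log_prob_old + 1e-3

r_mean = ppo_ratio(log_prob_new, log_prob_old, jnp.mean)  # exp(1e-3)
r_sum = ppo_ratio(log_prob_new, log_prob_old, jnp.sum)    # exp(1e-3 * d) = r_mean ** d

# Clipping the mean-based ratio to [1 - eps, 1 + eps] is equivalent to clipping
# the full-latent (sum-based) ratio to [(1 - eps) ** d, (1 + eps) ** d].
eps = 1e-4
low, high = (1 - eps) ** d, (1 + eps) ** d
print(float(r_mean), float(r_sum), low, high)
```

Under these assumptions the effective bounds on the full-latent ratio come out to roughly 0.19 and 5.15.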
Excellent question! I actually implemented it as jnp.sum first, and immediately got NaNs after the first training step. I quickly realized this is due to the unusually high dimensionality of our action space. Typical applications of PPO (e.g. for control) have far fewer action dimensions than our latent, which has $d$ elements. If a policy update makes every element more likely by the same factor $c$, the jnp.mean implementation produces a ratio of just $c$, whereas jnp.sum would produce a ratio of $c^d$.
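The following sketch illustrates the overflow, assuming a latent of 4 x 64 x 64 = 16384 elements (the issue does not state d here; this size is an illustrative assumption) and a hypothetical update that makes every element 1% more likely. In JAX's default float32, the sum-based ratio exceeds the largest representable value and becomes inf, which can then propagate into NaNs in the loss and gradients.

```python
import math
import jax.numpy as jnp

# Assumed latent size (4 x 64 x 64 = 16384 elements), for illustration only.
LATENT_SHAPE = (4, 64, 64)

# Hypothetical update: every latent element becomes 1% more likely.
per_elem_log_ratio = jnp.full(LATENT_SHAPE, jnp.log(1.01), dtype=jnp.float32)

ratio_mean = jnp.exp(jnp.mean(per_elem_log_ratio))  # ~1.01
ratio_sum = jnp.exp(jnp.sum(per_elem_log_ratio))    # 1.01**16384 overflows float32 to inf

print(float(ratio_mean), float(ratio_sum))
# Magnitude of the sum-based ratio, computed in log space with Python floats:
print(16384 * math.log10(1.01))  # ~70.8, i.e. a ratio of roughly 10**70.8
```

Even in float64 (with jax_enable_x64 turned on), the sum-based ratio would be about 6e70, far outside any sensible clip range.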
However, I think this is about more than numerical stability. Even if our numerical format could support arbitrarily large (or small) density ratios, that really doesn't seem like what we would want. Unfortunately, I can't give a better intuitive or theoretical justification than this, but it seems sensible to clip based on the "average" change per pixel rather than the total change of the image. It might also have to do with the fact that we model the policy as an isotropic Gaussian, where every pixel is independent, whereas in reality policy updates are very highly correlated across pixels. FWIW, this has probably been studied before; I'm just not familiar with the relevant literature.
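Because the policy is modeled as an isotropic Gaussian, the log-density of the whole latent is exactly the sum of the per-element log-densities, so jnp.sum gives the true joint log prob and jnp.mean gives that value scaled by 1/d. A small sketch checking this numerically with jax.scipy.stats; the dimension d = 8 and the random sample, mean, and shared std are arbitrary illustrative choices, not DDPO's values.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal, norm

# Small illustrative dimension; a real latent is much larger.
d = 8
key_x, key_mu = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (d,))    # sample ("action")
mu = jax.random.normal(key_mu, (d,))  # policy mean
sigma = 0.5                           # same std for every element (isotropic)

# Per-element Gaussian log-densities.
per_elem = norm.logpdf(x, loc=mu, scale=sigma)

# Joint log-density of the isotropic Gaussian over the whole vector.
joint = multivariate_normal.logpdf(x, mu, sigma**2 * jnp.eye(d))

print(float(jnp.sum(per_elem)), float(joint))       # equal up to float error
print(float(jnp.mean(per_elem)), float(joint) / d)  # mean is the joint scaled by 1/d
```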
Clear explanation~ Thanks!