XiangLi1999 / Diffusion-LM
Losses for E2E Training

zanussbaum opened this issue · comments

Hi, thanks for releasing the code! I had a quick question about the different loss functions in the code.

I'm trying to wrap my head around the loss function presented in the paper and compare it to what's in the code. I'm taking a look at the function
$$\mathcal{L}^{\mathrm{e2e}}_{\mathrm{simple}}(\mathbf{w}) \;=\; \mathbb{E}_{q_\phi(\mathbf{x}_{0:T}\mid \mathbf{w})}\Big[\,\mathcal{L}_{\mathrm{simple}}(\mathbf{x}_0) \;+\; \lVert \mathrm{EMB}(\mathbf{w}) - \mu_\theta(\mathbf{x}_1, 1)\rVert^2 \;-\; \log p_\theta(\mathbf{w}\mid \mathbf{x}_0)\,\Big]$$

$\mathcal{L}_{\mathrm{simple}}$ appears to be this line

The loss between the embeddings seems to be these lines

And the cross-entropy loss between the logits and the input tokens appears to be here

However, I'm a little confused about what these lines account for. From my debugging, this just seems to be taking the embeddings, multiplied with noise and with `sqrt_alphas_cumprod`, across all timesteps.

Am I misinterpreting what's in the code versus what's in the paper?
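
For context, here's roughly how I'm picturing the three terms fitting together, as a pseudo-PyTorch sketch with placeholder names (not the actual functions from this repo):

```python
import torch
import torch.nn.functional as F

def e2e_simple_loss(model, embed, lm_head, w, x_0, x_1, x_t, t):
    # L_simple(x_0): regression between the model's prediction at step t and
    # its target (written here against x_0; "model" stands in for the network
    # that predicts x_0 / the posterior mean).
    l_simple = ((model(x_t, t) - x_0) ** 2).mean()

    # ||EMB(w) - mu_theta(x_1, 1)||^2: ties the predicted mean at t=1
    # back to the word embeddings.
    ones = torch.ones_like(t)
    emb_term = ((embed(w) - model(x_1, ones)) ** 2).mean()

    # -log p_theta(w | x_0): cross-entropy between the logits decoded
    # from x_0 and the input token ids (the rounding term).
    logits = lm_head(x_0)                                # (B, L, V)
    nll = F.cross_entropy(logits.transpose(1, 2), w)     # w: (B, L) token ids

    return l_simple + emb_term + nll
```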

Hi,

Thanks for the great question!

We are simplifying a bit in the main text of the paper, and you can regard this `tT_loss` term as the part of Lsimple at t=T. It's a slight misuse of notation to keep the main text simple: by assumption $\mu_\theta(x_T, T) = 0$, since $p(x_T) = \mathcal{N}(0, I)$, so this Lsimple term at t=T reduces to an L2 norm of x_T.
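
Roughly, up to constants and paraphrasing the appendix: the mean of $q(x_T \mid x_0)$ is $\sqrt{\bar\alpha_T}\, x_0 = \sqrt{\bar\alpha_T}\,\mathrm{EMB}(\mathbf{w})$, and with $\mu_\theta(x_T, T) = 0$ the $t{=}T$ piece contributes

$$\big\lVert\, \sqrt{\bar\alpha_T}\,\mathrm{EMB}(\mathbf{w}) - \mu_\theta(x_T, T) \,\big\rVert^2 \;=\; \big\lVert \sqrt{\bar\alpha_T}\,\mathrm{EMB}(\mathbf{w}) \big\rVert^2 .$$

That scaled-embedding squared norm is what `tT_loss` corresponds to, and it matches the `sqrt_alphas_cumprod`-times-embeddings pattern you saw while debugging.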

The Lsimple term is derived from the Lvlb term, and we put a detailed derivation in Appendix E (page 17); this line in particular justifies the `tT_loss` term.
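
In pseudo-PyTorch, that term boils down to something like the following sketch (not the exact code in this repo; `sqrt_alphas_cumprod` is the usual precomputed $\sqrt{\bar\alpha_t}$ schedule):

```python
import torch

def tT_loss_sketch(x_start, sqrt_alphas_cumprod):
    """Sketch of the t=T term: squared norm of the mean of q(x_T | x_0).

    x_start:             (B, L, D) word embeddings EMB(w)
    sqrt_alphas_cumprod: (T,) precomputed sqrt(alpha_bar_t) schedule
    """
    # The mean of q(x_T | x_0) is sqrt(alpha_bar_T) * x_0; since the prior
    # p(x_T) = N(0, I) is centered at zero, the mean part of the KL is just
    # the squared norm of that (scaled) embedding.
    mean_T = sqrt_alphas_cumprod[-1] * x_start
    return (mean_T ** 2).mean(dim=list(range(1, mean_T.ndim)))

# e.g. with a linear beta schedule:
betas = torch.linspace(1e-4, 0.02, steps=2000)
sqrt_alphas_cumprod = torch.sqrt(torch.cumprod(1.0 - betas, dim=0))
```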

Intuitively, this term exists to keep the embedding norms from growing too large.

Thanks for the great explanation! I missed that part when originally looking over the paper, but it intuitively makes sense to have an L2 norm of x_T!