uber-research / PPLM

Plug and Play Language Model implementation. Allows steering the topic and attributes of GPT-2 models.

Question about future tokens in perturbation

yumoxu opened this issue · comments

Hi there,

Thanks again for your great work and kind response last time.

I have another question about this loop for obtaining future token representations. Let's say the current generation step is t. At the start, we have the unperturbed past, about [0, t-1], and we perturb it. Then, we use the perturbed past (about [0, t-1]) and the last generated token, from t-1, to generate the hidden state at t (see this line).

We then use the hidden state at t to obtain the input embedding for t+1, and finally, the future hidden state at t+1. To do this, you run the forward pass again in the loop, based on

  • The unperturbed past, about [0, t-1], and
  • The input embedding at t+1.

What seems missing to me in this forward pass is the generated token at t-1. The past you use here is about [0, t-1], and the next input, from an autoregressive perspective, should be the generated token at t-1. However, what is actually applied in your code is the input embedding at t+1, which contains information from step t.
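To make sure we are talking about the same computation, here is a rough sketch of the data flow as I read it. It is written against the current Hugging Face transformers API (names like past_key_values, inputs_embeds, and output_hidden_states are my assumption; the repo itself targets the older pytorch-transformers interface), and all variable names are mine, not yours:

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Rough sketch of the data flow as I read it -- not the repo's code.
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

output_so_far = tokenizer.encode("The chicken tastes", return_tensors="pt")
prefix, last = output_so_far[:, :-1], output_so_far[:, -1:]

with torch.no_grad():
    # Two caches over the prefix [0, t-1]: in PPLM the first would be the
    # gradient-perturbed past, while the second stays unperturbed.
    pert_past = model(prefix, use_cache=True).past_key_values
    unpert_past = model(prefix, use_cache=True).past_key_values

    # Pass 1: (perturbed) past + last generated token -> hidden state at t.
    out_t = model(last, past_key_values=pert_past, output_hidden_states=True)
    hidden_t = out_t.hidden_states[-1][:, -1, :]

    # The distribution predicted at t becomes a soft input embedding for t+1.
    probs = torch.softmax(out_t.logits[:, -1, :], dim=-1)   # (1, vocab)
    wte = model.get_input_embeddings().weight               # (vocab, hidden)
    soft_embeds = torch.matmul(probs, wte).unsqueeze(1)     # (1, 1, hidden)

    # Pass 2 (the one I am asking about): unperturbed past + soft embedding
    # for t+1 -> future hidden state. As I read it, the generated token at
    # t-1 never enters this pass as an explicit input.
    out_future = model(inputs_embeds=soft_embeds,
                       past_key_values=unpert_past,
                       output_hidden_states=True)
    hidden_future = out_future.hidden_states[-1][:, -1, :]
```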

I personally don't think there's anything wrong with this implementation, but I would like to confirm whether I have misunderstood anything here. Also, if my understanding is right, it would be great if you could elaborate on the motivation for choosing not to encode the generated token at t-1 explicitly. For instance, an easy way would be to append the generated token at t-1 to the past and then do the forward pass.

Thanks a lot!

Sorry, the paper doesn't do the best job at explaining this bit.

Lines 199-204 sum up the perturbed latents (from the last layer) from 0 to t (since these tokens are fixed). To these perturbed latents, we add the future hidden state at t+1 -- this comes from a forward pass with the unpert_past as defined here (which includes embeddings from 0 to t) and the future distribution for the token at t+1. We do already encode the token at t-1 explicitly, in 566, to produce the embedding at t.
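In other words, the quantity that goes into the attribute model is roughly the following. This is only a sketch with placeholder tensors and names (latents_0_to_t, hidden_future, classifier), not the variables in the repo:

```python
import torch
import torch.nn.functional as F

# Placeholder sketch of the accumulated hidden state scored by the attribute
# model -- stand-in tensors and names, not the repo's code.
hidden_dim = 768                                    # GPT-2 small width
latents_0_to_t = torch.randn(1, 8, hidden_dim)      # last-layer latents for tokens 0..t
hidden_future = torch.randn(1, hidden_dim)          # future hidden state at t+1, from
                                                    # unpert_past + the soft next-token embedding

# Sum the latents from 0 to t (these tokens are fixed) and add the future state.
new_accumulated_hidden = latents_0_to_t.sum(dim=1) + hidden_future

# The discriminator scores the average of this sum; its gradient with respect
# to the perturbation is what steers generation.
classifier = torch.nn.Linear(hidden_dim, 2)         # stand-in attribute model
num_states = latents_0_to_t.shape[1] + 1            # states contributing to the sum
logits = classifier(new_accumulated_hidden / num_states)
loss = F.cross_entropy(logits, torch.tensor([1]))   # label of the desired attribute
```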

We found that this gave us more reliable gradients than using future embeddings generated from the perturbed latents (hence the duplicated forward pass), likely because the discriminator itself is trained on unperturbed representations. Alternatively, you could also use the unperturbed latents from 0 to t in the term that goes into the discriminator here, and that may work slightly better.
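For context on that intuition: the attribute model is just a small head trained on averaged, unperturbed GPT-2 representations. A stripped-down stand-in for that kind of training loop (placeholder data and names, written against the current transformers API, not the repo's discriminator script) would look like:

```python
import torch
import torch.nn.functional as F
from transformers import GPT2Model, GPT2Tokenizer

# Stand-in sketch of training an attribute classifier on unperturbed GPT-2
# representations -- placeholder data, not the repo's training code.
gpt2 = GPT2Model.from_pretrained("gpt2").eval()
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
classifier = torch.nn.Linear(gpt2.config.hidden_size, 2)
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-4)

texts = ["The food was delicious.", "The service was awful."]
labels = [1, 0]                                       # e.g. positive / negative
for text, label in zip(texts, labels):
    ids = tokenizer.encode(text, return_tensors="pt")
    with torch.no_grad():
        hidden = gpt2(ids).last_hidden_state          # unperturbed representations
    logits = classifier(hidden.mean(dim=1))           # head sees averaged states
    loss = F.cross_entropy(logits, torch.tensor([label]))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```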

Many thanks for your prompt reply!

After double-checking the code, I think the two tensors I got confused by were past and unpert_past, which the function perturb_past() takes as inputs.

My understanding now is that past contains information about [0, t-1] (defined here), and it is what the perturbation is based on. unpert_past, which contains information about [0, t], is specifically for obtaining future token representations, where you want no perturbation involved so that the gradients are more stable.
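In code terms, my reading is roughly the following. Again a sketch only, using the current transformers API rather than the repo's exact calls, and the cache shape layout is my assumption:

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Sketch of my reading of the two caches -- not the repo's code.
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

output_so_far = tokenizer.encode("The chicken tastes", return_tensors="pt")

with torch.no_grad():
    # past: cache over everything but the last token, i.e. [0, t-1];
    # this is the tensor the perturbation is applied to.
    past = model(output_so_far[:, :-1], use_cache=True).past_key_values
    # unpert_past: cache over the full sequence so far, i.e. [0, t];
    # this is what the future-token forward pass runs on, untouched.
    unpert_past = model(output_so_far, use_cache=True).past_key_values

# With cached keys/values of shape (batch, heads, seq, head_dim) per layer,
# unpert_past should be exactly one token longer than past.
assert unpert_past[0][0].shape[-2] == past[0][0].shape[-2] + 1
```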

Can you please confirm this?

Besides, this is an interesting idea:

Alternatively, you could also use the unperturbed latents from 0 to t in the term that goes into the discriminator here and that may work slightly better.

If my understanding is right, new_accumulated_hidden in this line is the summation of

  1. the unperturbed hidden states in [0, t-1],
  2. the hidden state at t based on the perturbed past, and
  3. the hidden state at t+1 based on unpert_past and the output at t.

So by

use the unperturbed latents from 0 to t

Do you mean one can possibly replace the 2nd item with the hidden state at t based on unpert_past?

If so, the only component that carries the effect of the perturbation is the hidden state at t+1, via the output at t. Indeed, we could then expect more stable gradients.

Can you please confirm this?

Yes. That is correct.

Do you mean one can possibly replace the 2nd item with the hidden state at t based on unpert_past?

Yes. My main intuition behind this is that the discriminator is trained on unperturbed representations, so it makes sense to use only the unperturbed representations to obtain gradients.
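Spelled out with placeholder names (stand-in tensors, not the repo's variables), the change is just which hidden state at t enters the sum:

```python
import torch

# Placeholder sketch of the variant -- stand-in tensors, not the repo's code.
hidden_dim = 768
sum_unpert_0_to_tm1 = torch.randn(1, hidden_dim)  # unperturbed states summed over [0, t-1]
pert_hidden_t = torch.randn(1, hidden_dim)        # hidden state at t from the perturbed past
unpert_hidden_t = torch.randn(1, hidden_dim)      # hidden state at t from unpert_past
hidden_future_tp1 = torch.randn(1, hidden_dim)    # hidden state at t+1 from the future pass

# Current behaviour: the state at t comes from the perturbed past.
new_accumulated_hidden = sum_unpert_0_to_tm1 + pert_hidden_t + hidden_future_tp1

# Variant: only unperturbed latents from 0 to t go into the discriminator term,
# so the perturbation reaches it solely through the future state at t+1
# (via the output distribution at t).
new_accumulated_hidden_variant = (sum_unpert_0_to_tm1
                                  + unpert_hidden_t
                                  + hidden_future_tp1)
```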

My main intuition behind this is that the discriminator is trained on unperturbed representations, so it makes sense to use only the unperturbed representations to obtain gradients.

It makes sense to me too! I will give it a try. Thanks again for your clarification!