ml-jku / helm

input observations for CNN and PLT

alstn7 opened this issue · comments

Hello, thanks for the great repo!

I was reading your paper, and the overall architecture (Figure 1) shows three timesteps of observations (o_t, o_t-1, o_t-2) being fed into the pretrained language transformer (PLT), while only the current observation (o_t) goes to the CNN. The features extracted by both parts are then concatenated and fed into the actor-critic head.

However, I noticed that in Algorithm 1 of the paper, only the current observation (o_t) is passed to the PLT, just as it is for the CNN. Your code also seems to do this:
model.py, lines 142-143
obs_query = self.query_encoder(observations)
vocab_encoding = self.frozen_hopfield.forward(observations)
Here self.query_encoder is the CNN part and self.frozen_hopfield is the FrozenHopfield mechanism.

Could you please clarify this? Should I understand that only the current observation is passed to the PLT, just as with the CNN?

Edit:
Maybe the three timesteps of observations (o_t, o_t-1, o_t-2) in Figure 1 are specific to Transformer-XL and how it caches the previous segment's hidden states as memory to capture long-term dependencies?
If so, would a different pretrained language model such as BERT, which does not use the previous segment as memory, fail to take the previous observations (o_t-1, o_t-2) into account?
In that case, would a model with TrXL replaced by BERT disregard past observations during learning, since both the PLT and the CNN would receive only o_t as input?

Sorry for such a long question.
Thanks!

Hi!

Figure 1 in the paper is simply meant to illustrate that the PLT receives a sequence of observations, whereas the CNN operates only on the current timestep.
You are correct: our implementation is specific to TransformerXL, which carries a memory register that stores the activations of previous timesteps.
The memory register is tracked in helm_trainer.py and passed to the TrXL here:

self.policy.memory = self._last_mems

This is very neat, because you only pass the current timestep but still perform attention over the entire sequence of observations.
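To make the memory idea concrete, here is a minimal pure-Python sketch. It is not the HELM code: `ToyMemoryPolicy`, `attend`, and `mem_len` are hypothetical names, and the toy caches raw observation embeddings in a single layer, whereas TransformerXL caches per-layer hidden states. It only illustrates that passing one timestep plus a cached memory can give the same attention result as passing the whole sequence.

```python
import math
from collections import deque


def attend(query, keys, values):
    """Single-head scaled dot-product attention for one query vector."""
    scale = math.sqrt(len(query))
    scores = [sum(q * k for q, k in zip(query, key)) / scale for key in keys]
    peak = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[i] for w, v in zip(weights, values)) / total
            for i in range(dim)]


class ToyMemoryPolicy:
    """Hypothetical stand-in for a model with a TrXL-style memory register."""

    def __init__(self, mem_len):
        # Cached representations of past timesteps (assumption: raw embeddings).
        self.memory = deque(maxlen=mem_len)

    def step(self, obs_embedding):
        # Only the current timestep is passed in, but the attention context
        # is the cached memory followed by the current embedding.
        context = list(self.memory) + [obs_embedding]
        out = attend(obs_embedding, context, context)
        self.memory.append(obs_embedding)
        return out


policy = ToyMemoryPolicy(mem_len=8)
sequence = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
step_outputs = [policy.step(obs) for obs in sequence]

# Attending over the full sequence at once, for comparison with the last step.
full_output = attend(sequence[-1], sequence, sequence)
```

In this toy setting the last entry of `step_outputs` matches `full_output`, because the memory already holds the two earlier timesteps when the third one arrives.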
For BERT, for example, you would need to keep track of the entire observation sequence and propagate it through BERT.
As the accumulated history representation, you could then take the output of the CLS token, which attends across the entire sequence.
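A rough sketch of that bookkeeping, under stated assumptions: `HistoryBuffer`, `history_representation`, and the `encoder` callable are hypothetical and not part of the HELM repository or any BERT library. The encoder is assumed to return one output per input position, with the CLS position at index 0. The `dummy_encoder` below is only a placeholder so the example runs without a real model.

```python
from collections import deque

CLS = "[CLS]"  # placeholder marker; a real model would use its own CLS embedding


class HistoryBuffer:
    """Keeps the observation sequence that a memory-free encoder needs."""

    def __init__(self, max_len):
        self.observations = deque(maxlen=max_len)

    def reset(self):
        # Call at episode boundaries so histories do not leak across episodes.
        self.observations.clear()

    def push(self, observation):
        # Store the new observation and return the full input sequence,
        # with the CLS marker prepended.
        self.observations.append(observation)
        return [CLS] + list(self.observations)


def history_representation(encoder, sequence):
    """Encode the whole sequence and return the output at the CLS position."""
    outputs = encoder(sequence)
    return outputs[0]


def dummy_encoder(sequence):
    # Stand-in for BERT: the CLS output summarizes all observations by summing.
    observations = sequence[1:]
    return [sum(observations)] + list(observations)


buffer = HistoryBuffer(max_len=16)
representations = [
    history_representation(dummy_encoder, buffer.push(obs))
    for obs in (1, 2, 3)
]
```

Unlike the memory-register approach, this re-encodes the stored history at every step, so the cost per step grows with the history length up to `max_len`.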

Thank you for your clarification!