MarcoMeter / episodic-transformer-memory-ppo

Clean baseline implementation of PPO using an episodic TransformerXL memory

Questions about the implementation

hai-h-nguyen opened this issue

Hi Marco,

I understand the need for the memory at the first transformer block, but I would like some help understanding why the later blocks need memories as well.

Given that the first block's memory contains past episodes, you could sample a batch of episodes and use them as input to update the policy and the value function, without remembering intermediate results (such as the inputs to the second transformer block for those past episodes).

Another question: for the first memory block, why don't you store raw observations instead of embedded observations? A potential issue is that when the policy and the value function are updated, the embedding layers are updated as well (and can end up very different from those used during data collection). This might make the off-policyness worse.

Hi!

We need the external memory for all transformer blocks, because we utilize the transformer as a sequence-to-one model: only the current timestep is fed to the network, so each block has to retrieve its past context from its own memory.
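For intuition, here is a minimal sketch of that design, assuming hypothetical names and shapes (this is not the repository's actual `TransformerBlock`): only the current timestep is passed in as the query, and each block attends over its own cache of past activations.

```python
import torch
import torch.nn as nn

class CachedBlock(nn.Module):
    """One transformer block that attends over its own cached activations."""

    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
        # x:      (batch, 1, dim)       current timestep only (sequence-to-one)
        # memory: (batch, mem_len, dim) this block's activations from past steps
        kv = torch.cat([memory, x], dim=1)           # attend over past + present
        h, _ = self.attn(query=x, key=kv, value=kv)  # query is just the current step
        h = self.norm1(x + h)
        return self.norm2(h + self.mlp(h))

# Block i+1 consumes block i's output, so each block needs its own memory of
# past activations; they cannot be reconstructed from the first block's memory
# alone without re-running all earlier blocks over the entire history.
```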

We don't store raw observations because of the much larger memory overhead; this implementation is already costly as it is. By utilizing past activations, TrXL gets an effectively infinite context length, which would not be possible with raw observations.
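A rough back-of-the-envelope comparison, with hypothetical sizes (the actual dimensions depend on the config):

```python
# Per-timestep cache footprint under assumed sizes: 3 blocks, 384-dim embeddings,
# float32 everywhere, and an 84x84 RGB visual observation.
num_blocks, dim = 3, 384
cached_activations = num_blocks * dim * 4  # 4,608 bytes of activations per step
raw_visual_obs = 84 * 84 * 3 * 4           # 84,672 bytes for one raw frame
print(cached_activations, raw_visual_obs)
```

Beyond storage, raw frames would also have to be re-encoded and re-propagated through every block at each step to rebuild the deeper blocks' context.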

Hi, thanks for the answer. Without the external memory, I think the transformer is still a sequence-to-one model. Also, the TrXL authors seem to have needed relative positional encoding to train the model properly, yet you make positional encoding optional. Can you explain that part?

Also, you seem to constrain the keys and the values to be the same for the TransformerBlock, as in https://github.com/MarcoMeter/episodic-transformer-memory-ppo/blob/main/transformer.py#L249. Is there any reason for that?

Thank you!

Without the external memory, I think the transformer is still a sequence-to-one model

Then you would need a loop over timesteps inside the forward pass, which severely hurts performance.
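A sketch of what that loop amounts to, with hypothetical helper names:

```python
# Without per-block caches, producing the deeper blocks' context at step t
# means re-processing the whole history through the entire network.
def forward_without_cache(embed, blocks, raw_history):
    h = embed(raw_history)   # (t, dim): re-embed every past observation
    for block in blocks:     # re-run every block over the full history
        h = block(h)         # quadratic attention cost, repeated each env step
    return h[-1]             # only the final step's output is actually used

# With per-block caches, a single environment step instead costs one attention
# pass over (memory_length + 1) positions per block.
```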

Positional encoding is only needed if temporal order is relevant. In Minigrid Memory there is a single cue to be memorized, and it does not matter when that cue was observed.
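A minimal sketch of making the encoding optional (hypothetical names, not the repository's config):

```python
import torch

def maybe_add_positional(x: torch.Tensor, pos_table: torch.Tensor, use_pos: bool) -> torch.Tensor:
    # x: (batch, seq_len, dim); pos_table: (max_len, dim), learned or sinusoidal.
    # Enable use_pos only when the order of observations matters for the task;
    # for Minigrid Memory (a single cue), it can stay disabled.
    return x + pos_table[: x.size(1)] if use_pos else x
```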

Is there any reason for that?

That's the concept of self-attention: the queries, keys, and values are all derived from the same input sequence via separate learned projections, so the keys and values naturally come from the same tensor.
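A minimal PyTorch illustration (not taken from the repository): the same tensor is passed as queries, keys, and values, and the module applies distinct learned projections internally.

```python
import torch
import torch.nn as nn

attn = nn.MultiheadAttention(embed_dim=384, num_heads=4, batch_first=True)
x = torch.randn(2, 16, 384)  # one input tensor plays all three roles
out, _ = attn(x, x, x)       # separate Q/K/V projections are applied inside the module
```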

Hi Marco,

Out of curiosity, did you benchmark against RLlib's PPO-Attention implementation? https://github.com/ray-project/ray/blob/master/rllib/examples/attention_net.py

I struggled with RLlib. I got weird outputs when running it on Memory Gym environments.