facebookresearch / Pearl

A Production-ready Reinforcement Learning AI Agent Library brought by the Applied Reinforcement Learning team at Meta.

Hidden dim of LSTM history summarization module must be equal to observation dim

antoine-galataud opened this issue · comments

For now, it's impossible to configure LSTMHistorySummarizationModule with a hidden_dim other than the observation space dimension. Trying to do so leads to the following exception (PPO example):

Error
Traceback (most recent call last):
  File "/home/antoine/git/Pearl/pearl/test/integration/integration_tests.py", line 390, in test_ppo_lstm
    target_return_is_reached(
  File "/home/antoine/git/Pearl/pearl/utils/functional_utils/train_and_eval/online_learning.py", line 192, in target_return_is_reached
    episode_info, episode_total_steps = run_episode(
  File "/home/antoine/git/Pearl/pearl/utils/functional_utils/train_and_eval/online_learning.py", line 250, in run_episode
    action = agent.act(exploit=exploit)
  File "/home/antoine/git/Pearl/pearl/pearl_agent.py", line 154, in act
    action = self.policy_learner.act(
  File "/home/antoine/git/Pearl/pearl/policy_learners/sequential_decision_making/actor_critic_base.py", line 208, in act
    action_probabilities = self._actor.get_policy_distribution(
  File "/home/antoine/git/Pearl/pearl/neural_networks/sequential_decision_making/actor_networks.py", line 135, in get_policy_distribution
    policy_distribution = self.forward(
  File "/home/antoine/git/Pearl/pearl/neural_networks/sequential_decision_making/actor_networks.py", line 119, in forward
    return self._model(x)
  File "/home/antoine/git/Pearl/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/antoine/git/Pearl/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/antoine/git/Pearl/venv/lib/python3.10/site-packages/torch/nn/modules/container.py", line 215, in forward
    input = module(input)
  File "/home/antoine/git/Pearl/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/antoine/git/Pearl/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/antoine/git/Pearl/venv/lib/python3.10/site-packages/torch/nn/modules/container.py", line 215, in forward
    input = module(input)
  File "/home/antoine/git/Pearl/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/antoine/git/Pearl/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/antoine/git/Pearl/venv/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x16 and 4x64)
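For reference, the setup that triggers this looks roughly like the snippet below (import paths and constructor arguments are written from memory, so they may be slightly off; CartPole's observation dim is 4, and hidden_dim=16 is what produces the 1x16 vs 4x64 mismatch):

```python
# Sketch of the misconfiguration (paths/argument names approximate):
from pearl.pearl_agent import PearlAgent
from pearl.history_summarization_modules.lstm_history_summarization_module import (
    LSTMHistorySummarizationModule,
)
from pearl.policy_learners.sequential_decision_making.ppo import (
    ProximalPolicyOptimization,
)

observation_dim = 4   # e.g. CartPole
action_dim = 2
action_space = ...    # the environment's Pearl action space

agent = PearlAgent(
    policy_learner=ProximalPolicyOptimization(
        state_dim=observation_dim,     # actor/critic expect inputs of size 4 ...
        action_space=action_space,
        actor_hidden_dims=[64, 64],
        critic_hidden_dims=[64, 64],
    ),
    history_summarization_module=LSTMHistorySummarizationModule(
        observation_dim=observation_dim,
        action_dim=action_dim,
        hidden_dim=16,                 # ... but the LSTM emits size-16 summaries
        history_length=8,
    ),
)
# agent.act(exploit=False) then fails with the matmul shape error above.
```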

The hidden dim and observation dim could be decoupled by adding a linear layer, with output dim equal to the observation space dim, after the LSTM module.
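Concretely, the idea would be a wrapper along these lines (plain PyTorch sketch of the proposal, not Pearl's actual module):

```python
import torch
import torch.nn as nn

class ProjectedLSTMSummarizer(nn.Module):
    """Hypothetical: LSTM with an arbitrary hidden_dim, projected back to observation_dim."""

    def __init__(self, observation_dim: int, action_dim: int, hidden_dim: int) -> None:
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=observation_dim + action_dim,
            hidden_size=hidden_dim,
            batch_first=True,
        )
        # The extra indirection: map the LSTM state back to the dim the policy expects.
        self.projection = nn.Linear(hidden_dim, observation_dim)

    def forward(self, history: torch.Tensor) -> torch.Tensor:
        # history: (batch, history_length, observation_dim + action_dim)
        out, _ = self.lstm(history)
        return self.projection(out[:, -1, :])  # project the last hidden state
```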

Hi there, sorry about the late reply. We all just got back from our holiday break.

The hidden dimension is supposed to be the state representation dimension for the policy learner, so this is by design. With a history summarization module, the observation_dim is basically the state_dim generated by the history summarization module. Hope this helps.
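In code, the intended setup is the one where the policy learner's state_dim equals the LSTM's hidden_dim, roughly like this (reusing the names from the snippet in the first comment; argument names still approximate):

```python
hidden_dim = 16

agent = PearlAgent(
    policy_learner=ProximalPolicyOptimization(
        state_dim=hidden_dim,              # the policy consumes the summarized state
        action_space=action_space,
        actor_hidden_dims=[64, 64],
        critic_hidden_dims=[64, 64],
    ),
    history_summarization_module=LSTMHistorySummarizationModule(
        observation_dim=observation_dim,   # raw environment observation size
        action_dim=action_dim,
        hidden_dim=hidden_dim,             # becomes the state_dim seen by the policy learner
        history_length=8,
    ),
)
```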

It makes sense, thanks. I now believe that adding an indirection between the history summarization network and the policy network isn't necessary or wanted. A hint could help, though, to avoid configuration mistakes (i.e., assert that the LSTM hidden dim equals the policy state dim).
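For example, a check along these lines at agent construction time (hypothetical attribute names, not actual Pearl code) would turn the shape error into a readable message:

```python
# Hypothetical validation, e.g. in PearlAgent.__init__ (attribute names assumed):
if history_summarization_module is not None:
    summary_dim = history_summarization_module.hidden_dim
    state_dim = policy_learner.state_dim
    assert summary_dim == state_dim, (
        f"LSTM hidden_dim ({summary_dim}) must equal the policy learner's "
        f"state_dim ({state_dim}): the summarized history is the state the "
        "policy learner consumes."
    )
```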