Hidden dim of LSTM history summarization module must be equal to observation dim
antoine-galataud opened this issue
Currently, it's impossible to configure LSTMHistorySummarizationModule with a hidden_dim other than the observation space dimension. Attempting to do so leads to the following exception (PPO example):
Error
Traceback (most recent call last):
File "/home/antoine/git/Pearl/pearl/test/integration/integration_tests.py", line 390, in test_ppo_lstm
target_return_is_reached(
File "/home/antoine/git/Pearl/pearl/utils/functional_utils/train_and_eval/online_learning.py", line 192, in target_return_is_reached
episode_info, episode_total_steps = run_episode(
File "/home/antoine/git/Pearl/pearl/utils/functional_utils/train_and_eval/online_learning.py", line 250, in run_episode
action = agent.act(exploit=exploit)
File "/home/antoine/git/Pearl/pearl/pearl_agent.py", line 154, in act
action = self.policy_learner.act(
File "/home/antoine/git/Pearl/pearl/policy_learners/sequential_decision_making/actor_critic_base.py", line 208, in act
action_probabilities = self._actor.get_policy_distribution(
File "/home/antoine/git/Pearl/pearl/neural_networks/sequential_decision_making/actor_networks.py", line 135, in get_policy_distribution
policy_distribution = self.forward(
File "/home/antoine/git/Pearl/pearl/neural_networks/sequential_decision_making/actor_networks.py", line 119, in forward
return self._model(x)
File "/home/antoine/git/Pearl/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/antoine/git/Pearl/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/antoine/git/Pearl/venv/lib/python3.10/site-packages/torch/nn/modules/container.py", line 215, in forward
input = module(input)
File "/home/antoine/git/Pearl/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/antoine/git/Pearl/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/antoine/git/Pearl/venv/lib/python3.10/site-packages/torch/nn/modules/container.py", line 215, in forward
input = module(input)
File "/home/antoine/git/Pearl/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/antoine/git/Pearl/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/antoine/git/Pearl/venv/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x16 and 4x64)
The two could be decoupled by adding a linear layer with output dim equal to the observation space dim after the LSTM module, as in the sketch below.
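For illustration, a minimal PyTorch sketch of the idea (plain nn modules, not Pearl code; the dims and history length here are just example values):

```python
import torch
import torch.nn as nn

observation_dim = 4   # raw observation size the policy networks were built for
hidden_dim = 16       # LSTM hidden size, chosen independently

# For simplicity, actions are left out of the LSTM input in this sketch.
lstm = nn.LSTM(input_size=observation_dim, hidden_size=hidden_dim, batch_first=True)
projection = nn.Linear(hidden_dim, observation_dim)  # decouples hidden_dim from observation_dim

history = torch.randn(1, 8, observation_dim)   # (batch, history_length, obs_dim)
output, (h_n, _) = lstm(history)
summary = projection(h_n[-1])                  # shape: (1, observation_dim)
# `summary` can now be fed to actor/critic networks that expect observation_dim inputs.
```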
Hi there, sorry about the late reply. We all just got back from our holiday break.
The hidden dimension is supposed to be the state representation dimension for the policy learner, so this is by design. With a history summarization module, the policy learner's observation_dim is effectively the state_dim generated by the history summarization module. Hope this helps.
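For example, a configuration along these lines keeps the dimensions consistent (a sketch only; constructor argument names are assumed from the Pearl PPO example and may differ across versions):

```python
from pearl.pearl_agent import PearlAgent
from pearl.policy_learners.sequential_decision_making.ppo import ProximalPolicyOptimization
from pearl.history_summarization_modules.lstm_history_summarization_module import LSTMHistorySummarizationModule
from pearl.utils.instantiations.environments.gym_environment import GymEnvironment

env = GymEnvironment("CartPole-v1")
observation_dim = 4   # raw env observation size
hidden_dim = 16       # summarized state size, chosen freely

agent = PearlAgent(
    policy_learner=ProximalPolicyOptimization(
        state_dim=hidden_dim,            # NOT observation_dim: the actor/critic see the LSTM output
        action_space=env.action_space,
        hidden_dims=[64, 64],            # argument names assumed; may differ by version
        training_rounds=50,
        batch_size=64,
        epsilon=0.1,
    ),
    history_summarization_module=LSTMHistorySummarizationModule(
        observation_dim=observation_dim,
        action_dim=env.action_space.n,
        hidden_dim=hidden_dim,           # must equal the policy learner's state_dim
        history_length=8,
    ),
)
```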
That makes sense, thanks. I now believe that adding an indirection between the history summarization network and the policy network isn't necessary or desirable. A hint could still help to avoid configuration mistakes, though (i.e. assert that the LSTM hidden dim equals the policy learner's state dim), along the lines of the sketch below.
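Something like this hypothetical check (attribute names are placeholders, not Pearl's actual ones) would fail fast with a readable message instead of a shape-mismatch error deep in the actor network:

```python
def check_dims(policy_learner, history_summarization_module) -> None:
    # Hypothetical validation: attribute names below are assumptions for illustration.
    expected = getattr(policy_learner, "_state_dim", None)
    produced = getattr(history_summarization_module, "hidden_dim", None)
    if expected is not None and produced is not None and expected != produced:
        raise ValueError(
            f"Policy learner expects state_dim={expected}, but the history "
            f"summarization module produces hidden_dim={produced}. "
            "Set hidden_dim equal to the policy learner's state_dim."
        )
```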