LSTM weights should have separate orthogonal initializations for each gate
Jammf opened this issue
James Mochizuki-Freeman commented
Problem Description
The LSTM weight matrices in ppo_atari_lstm.py
seem to be initialized incorrectly, if the goal is to have a separate orthogonal matrix for each gate. Since `lstm.weight_ih_l0`
and `lstm.weight_hh_l0`
have the four gate matrices concatenated together, shouldn't each of the four parts of the fused weight matrix be separately initialized to an orthogonal matrix?
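For context, PyTorch fuses the four gate matrices (input, forget, cell, output) along dim 0 of each weight parameter, which is why a single `orthogonal_` call treats them as one tall matrix. A minimal sketch illustrating the fused shapes:

```python
import torch

# PyTorch stacks the four LSTM gate weight matrices along dim 0
# of a single fused parameter per input (ih) and hidden (hh) path.
lstm = torch.nn.LSTM(512, 128)
print(lstm.weight_ih_l0.shape)  # (4*128, 512) -> torch.Size([512, 512])
print(lstm.weight_hh_l0.shape)  # (4*128, 128) -> torch.Size([512, 128])
```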
Checklist
- I have installed dependencies via `poetry install` (see CleanRL's installation guideline).
- I have checked that there is no similar issue in the repo.
- I have checked the documentation site and found no relevant information in GitHub issues.
Current Behavior
As a minimal example, checking just the input-gate block `W_hi` of `weight_hh_l0`:

```python
import torch

lstm = torch.nn.LSTM(512, 128)
_ = torch.nn.init.orthogonal_(lstm.weight_hh_l0, 1.0)
W_hi = lstm.weight_hh_l0[:128]
torch.allclose(W_hi.T, torch.inverse(W_hi), atol=1e-05)  # check that W_hi is orthogonal
# -> False
```
Expected Behavior
```python
import torch

lstm = torch.nn.LSTM(512, 128)
_ = torch.nn.init.orthogonal_(lstm.weight_hh_l0[:128], 1.0)  # init a view with only W_hi
W_hi = lstm.weight_hh_l0[:128]
torch.allclose(W_hi.T, torch.inverse(W_hi), atol=1e-05)  # check that W_hi is orthogonal
# -> True
```
Possible Solution
```diff
 self.lstm = nn.LSTM(512, 128)
 for name, param in self.lstm.named_parameters():
     if "bias" in name:
         nn.init.constant_(param, 0)
     elif "weight" in name:
-        nn.init.orthogonal_(param, 1.0)
+        nn.init.orthogonal_(param[:128], 1.0)
+        nn.init.orthogonal_(param[128:128*2], 1.0)
+        nn.init.orthogonal_(param[128*2:128*3], 1.0)
+        nn.init.orthogonal_(param[128*3:], 1.0)
```
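The four hard-coded slices could also be generalized using the layer's own `hidden_size`, so the same fix works for any LSTM width. A sketch of that variant (not CleanRL's actual code), with a check that each square hidden-to-hidden block comes out orthogonal:

```python
import torch
from torch import nn

lstm = nn.LSTM(512, 128)
h = lstm.hidden_size
for name, param in lstm.named_parameters():
    if "bias" in name:
        nn.init.constant_(param, 0)
    elif "weight" in name:
        # weight_ih_l0 / weight_hh_l0 stack the four gate matrices along
        # dim 0; initialize each (h x in_features) block separately.
        for i in range(4):
            nn.init.orthogonal_(param[i * h:(i + 1) * h], 1.0)

# Each square block of weight_hh_l0 is now orthogonal:
W_hi = lstm.weight_hh_l0[:h]
print(torch.allclose(W_hi.T @ W_hi, torch.eye(h), atol=1e-5))  # True
```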