RNN input size question
OrangeC93 opened this issue
OrangeC93 commented
I'm new to PyTorch. Can anyone answer a question that has confused me a lot?
The images are reshaped into (batch, seq_len, input_size):
images = images.reshape(-1, sequence_length, input_size)
But what I learned is that the input dimensions should be (seq_len, batch, input_size). Why the difference?
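A minimal sketch of the two layouts, assuming MNIST-style 28x28 grayscale images read row by row, so each image becomes 28 time steps of 28 pixels. The values of sequence_length, input_size and batch_size here are assumptions for illustration, not taken from the issue. It also shows that converting from batch-first to sequence-first needs a transpose, not another reshape:

```python
import torch

# Assumed MNIST-like setup (not stated in the issue).
sequence_length = 28  # one time step per image row
input_size = 28       # 28 pixels per row
batch_size = 4

# Fake batch with shape (batch, channels, height, width).
images = torch.randn(batch_size, 1, 28, 28)

# Batch-first layout: (batch, seq_len, input_size).
batch_first_images = images.reshape(-1, sequence_length, input_size)
print(batch_first_images.shape)  # torch.Size([4, 28, 28])

# Sequence-first layout (what batch_first=False expects): (seq_len, batch, input_size).
# transpose swaps the batch and time axes, so element [t, b] is image b at step t.
seq_first_images = batch_first_images.transpose(0, 1)
print(seq_first_images.shape)  # torch.Size([28, 4, 28])

# A plain reshape to the same shape does NOT reorder axes; it just reinterprets
# the flat memory, mixing rows from different images into one time step.
wrong = batch_first_images.reshape(sequence_length, batch_size, input_size)
print(torch.equal(wrong, seq_first_images))  # False
```

Either layout works with nn.LSTM; what matters is that the tensor layout matches the module's batch_first setting.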
GatoY commented
Hi @OrangeC93,
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
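Note the batch_first=True argument. By default batch_first=False, and the LSTM expects input of shape (seq_len, batch, input_size). With batch_first=True, the input and output tensors are (batch, seq_len, input_size) instead, which is why the images are reshaped that way.
That's the reason. You can refer to the RNN source code here.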
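A small sketch comparing the two settings; the sizes below are hypothetical. It copies the weights from the batch-first LSTM into a default (sequence-first) LSTM to show that batch_first only changes the tensor layout, not the computation. Note that the returned hidden and cell states (h_n, c_n) keep the (num_layers, batch, hidden_size) layout regardless of batch_first:

```python
import torch
import torch.nn as nn

# Hypothetical sizes for illustration only.
input_size, hidden_size, num_layers = 28, 64, 2
batch, seq_len = 4, 28

# batch_first=True: input and output are (batch, seq_len, feature).
lstm_bf = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
# Default batch_first=False: input and output are (seq_len, batch, feature).
lstm_sf = nn.LSTM(input_size, hidden_size, num_layers)
# Same parameter names in both modules, so the weights can be shared.
lstm_sf.load_state_dict(lstm_bf.state_dict())

x_bf = torch.randn(batch, seq_len, input_size)
x_sf = x_bf.transpose(0, 1)  # same data, sequence-first layout

with torch.no_grad():
    out_bf, (h_n, c_n) = lstm_bf(x_bf)
    out_sf, _ = lstm_sf(x_sf)

print(out_bf.shape)  # torch.Size([4, 28, 64])
print(out_sf.shape)  # torch.Size([28, 4, 64])
print(h_n.shape)     # torch.Size([2, 4, 64]), hidden state is not batch-first
```

With identical weights, out_sf.transpose(0, 1) matches out_bf up to floating-point tolerance.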
OrangeC93 commented
wow~ I got it! Many thanks!