yunjey / pytorch-tutorial

PyTorch Tutorial for Deep Learning Researchers


RNN input size question

OrangeC93 opened this issue

I'm new to PyTorch. Can anyone answer a question that has confused me a lot?

In the RNN tutorial, images are reshaped into (batch, seq_len, input_size):

images = images.reshape(-1, sequence_length, input_size)

But from what I learned, shouldn't the input dimensions be (seq_len, batch, input_size)?
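To make the reshape concrete, here is a minimal sketch assuming MNIST-style 28x28 grayscale images where each of the 28 rows is fed to the RNN as one time step of 28 features (the batch size of 3 and the random tensor are hypothetical, chosen only for illustration). It shows that the reshape produces a batch-major (batch, seq_len, input_size) tensor.

```python
import torch

# Assumed sizes: 28x28 grayscale images, each row treated as one time step
batch_size = 3
sequence_length = 28
input_size = 28

# Hypothetical input batch with shape (batch, channels, height, width)
images = torch.randn(batch_size, 1, 28, 28)

# -1 lets PyTorch infer the batch dimension from the total number of elements
seqs = images.reshape(-1, sequence_length, input_size)
print(seqs.shape)  # torch.Size([3, 28, 28]) -> (batch, seq_len, input_size)

# Time step t of each sequence is simply row t of the corresponding image
assert torch.equal(seqs[:, 5, :], images[:, 0, 5, :])
```

Note that the batch dimension stays first here, so this layout only matches what the LSTM expects if the LSTM is configured for batch-first input.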


Hi @OrangeC93,

self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)

Note the "batch_first=True". By default, "batch_first=False".

That's the reason. You can refer to the source code of the RNN modules here.
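A minimal sketch comparing the two settings, using hypothetical sizes (batch of 3, 28 time steps of 28 features, hidden size 16, 2 layers). With batch_first=True the LSTM takes and returns (batch, seq_len, feature) tensors; with the default batch_first=False you would transpose the first two dimensions to get (seq_len, batch, feature). The hidden and cell states are (num_layers, batch, hidden_size) either way.

```python
import torch
import torch.nn as nn

# Hypothetical sizes for illustration
batch_size, seq_len, input_size = 3, 28, 28
hidden_size, num_layers = 16, 2

x = torch.randn(batch_size, seq_len, input_size)  # (batch, seq_len, input_size)

# batch_first=True: input and output are (batch, seq_len, feature)
lstm_bf = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
out_bf, (h_bf, c_bf) = lstm_bf(x)
print(out_bf.shape)  # torch.Size([3, 28, 16])

# batch_first=False (the default): input and output are (seq_len, batch, feature)
lstm_tm = nn.LSTM(input_size, hidden_size, num_layers)
out_tm, (h_tm, c_tm) = lstm_tm(x.transpose(0, 1))
print(out_tm.shape)  # torch.Size([28, 3, 16])

# batch_first does not affect h_n / c_n: both are (num_layers, batch, hidden_size)
print(h_bf.shape, h_tm.shape)  # torch.Size([2, 3, 16]) torch.Size([2, 3, 16])
```

One thing to watch for: feeding a (batch, seq_len, input_size) tensor to an LSTM with the default batch_first=False raises no error as long as the last dimension matches input_size. The LSTM simply treats the batch dimension as time, so the mismatch fails silently.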


wow~ I got it! Many thanks!