TylerYep / torchinfo

View model summaries in PyTorch!

Using batch_first=True for torch native transformer blocks seems to result in incomplete summaries

cwestergren opened this issue

Describe the bug
Using batch_first=True with torch native transformer blocks seems to result in incomplete summaries. When setting batch_first=False, the summary looks good.

To Reproduce
Steps to reproduce the behavior:

import torch.nn as nn
from torchinfo import summary
encoder_layer = nn.TransformerEncoderLayer(d_model=768, nhead=12, dim_feedforward=3072, activation='gelu', batch_first=True, norm_first=True) # For some weird reason, the transformer block with batch_first=True fails.
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)
summary(transformer_encoder, input_size=(1,197,768), col_names=['input_size', 'output_size', 'num_params', 'trainable'])

With batch_first=False, this correctly prints the summary below.

Layer (type:depth-idx)              Input Shape       Output Shape      Param #     Trainable
TransformerEncoder                  [1, 197, 768]     [1, 197, 768]     --          True
├─ModuleList: 1-1                   --                --                --          True
│ └─TransformerEncoderLayer: 2-1    [1, 197, 768]     [1, 197, 768]     --          True
│ │ └─LayerNorm: 3-1                [1, 197, 768]     [1, 197, 768]     1,536       True
│ │ └─MultiheadAttention: 3-2       [1, 197, 768]     [1, 197, 768]     2,362,368   True
│ │ └─Dropout: 3-3                  [1, 197, 768]     [1, 197, 768]     --          --
│ │ └─LayerNorm: 3-4                [1, 197, 768]     [1, 197, 768]     1,536       True
│ │ └─Linear: 3-5                   [1, 197, 768]     [1, 197, 3072]    2,362,368   True
│ │ └─Dropout: 3-6                  [1, 197, 3072]    [1, 197, 3072]    --          --
│ │ └─Linear: 3-7                   [1, 197, 3072]    [1, 197, 768]     2,360,064   True
│ │ └─Dropout: 3-8                  [1, 197, 768]     [1, 197, 768]     --          --

Using batch_first=True, I get different behaviour:

Layer (type:depth-idx)              Input Shape       Output Shape      Param #     Trainable
TransformerEncoder                  [1, 197, 768]     [1, 197, 768]     --          True
├─ModuleList: 1-1                   --                --                --          True
│ └─TransformerEncoderLayer: 2-1    [1, 197, 768]     [1, 197, 768]     7,087,872   True
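
As a sanity check, the 7,087,872 parameters reported on the single TransformerEncoderLayer row can be reproduced by adding up the per-sublayer counts from the full summary. The sketch below is plain arithmetic and does not call torchinfo. It assumes the usual PyTorch parameter layout: weight plus bias for each Linear and LayerNorm, and a packed Q/K/V input projection plus an output projection (both with bias) for MultiheadAttention. The variable names are only for illustration.

```python
# Parameter counts derived from d_model=768 and dim_feedforward=3072.
d_model, d_ff = 768, 3072

# LayerNorm: one weight vector and one bias vector of size d_model.
layer_norm = 2 * d_model

# MultiheadAttention (assumed layout): packed Q/K/V projection with bias,
# plus an output projection with bias.
attention = (3 * d_model * d_model + 3 * d_model) + (d_model * d_model + d_model)

# Feed-forward block: Linear(d_model -> d_ff) and Linear(d_ff -> d_model).
linear1 = d_model * d_ff + d_ff
linear2 = d_ff * d_model + d_model

# The layer has two LayerNorms, one attention block, and two Linear layers.
total = 2 * layer_norm + attention + linear1 + linear2
print(f"{total:,}")  # 7,087,872
```

The total matches the collapsed row, which suggests that no parameters are lost with batch_first=True; only the per-sublayer rows are missing from the printout.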

Expected behavior
I would expect the printout to be the same in both cases.

Desktop (please complete the following information):

  • Google Colab on Edge.