TylerYep / torchinfo

View model summaries in PyTorch!



nn.ParameterList omitted again in v1.7.1

github-0-searcher opened this issue

Hi there.
I was trying to inspect the MMOE model from mmoe, which has an nn.ParameterList.

import torch
import torch.nn as nn
import torchinfo


class Expert(nn.Module):
    def __init__(self, input_size, output_size, hidden_size):
        super(Expert, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.dropout(out)
        out = self.fc2(out)
        return out


class Tower(nn.Module):
    def __init__(self, input_size, output_size, hidden_size):
        super(Tower, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.4)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.dropout(out)
        out = self.fc2(out)
        out = self.sigmoid(out)
        return out

class MMOE(nn.Module):
    def __init__(self, input_size, num_experts, experts_out, experts_hidden, towers_hidden, tasks):
        super(MMOE, self).__init__()
        self.input_size = input_size
        self.num_experts = num_experts
        self.experts_out = experts_out
        self.experts_hidden = experts_hidden
        self.towers_hidden = towers_hidden
        self.tasks = tasks

        self.softmax = nn.Softmax(dim=1)

        self.experts = nn.ModuleList([Expert(self.input_size, self.experts_out, self.experts_hidden) for i in range(self.num_experts)])
        self.w_gates = nn.ParameterList([nn.Parameter(torch.randn(input_size, num_experts), requires_grad=True) for i in range(self.tasks)])
        self.towers = nn.ModuleList([Tower(self.experts_out, 1, self.towers_hidden) for i in range(self.tasks)])

    def forward(self, x):
        experts_o = [e(x) for e in self.experts]
        experts_o_tensor = torch.stack(experts_o)

        gates_o = [self.softmax(x @ g) for g in self.w_gates]

        tower_input = [g.t().unsqueeze(2).expand(-1, -1, self.experts_out) * experts_o_tensor for g in gates_o]
        tower_input = [torch.sum(ti, dim=0) for ti in tower_input]

        final_output = [t(ti) for t, ti in zip(self.towers, tower_input)]
        return final_output

model = MMOE(input_size=499, num_experts=6, experts_out=16, experts_hidden=32, towers_hidden=8, tasks=2)

torchinfo.summary(model, input_size=(1024, 499),
                  col_names=[
                      "kernel_size", 
                      "input_size",
                      "output_size", 
                      "num_params", 
                      "trainable",
                      "mult_adds"
                      ],
                  col_width=16,
                  row_settings=["var_names", "depth"],
                  )

On v1.7.1, I got the following output:

========================================================================================================================================
Layer (type (var_name):depth-idx)        Kernel Shape     Input Shape      Output Shape     Param #          Trainable        Mult-Adds
========================================================================================================================================
MMOE (MMOE)                              --               [1024, 499]      [1024, 1]        5,988            True             --
├─ModuleList (experts): 1-1              --               --               --               --               True             --
│    └─Expert (0): 2-1                   --               [1024, 499]      [1024, 16]       --               True             --
│    │    └─Linear (fc1): 3-1            --               [1024, 499]      [1024, 32]       16,000           True             16,384,000
│    │    └─ReLU (relu): 3-2             --               [1024, 32]       [1024, 32]       --               --               --
│    │    └─Dropout (dropout): 3-3       --               [1024, 32]       [1024, 32]       --               --               --
│    │    └─Linear (fc2): 3-4            --               [1024, 32]       [1024, 16]       528              True             540,672
│    └─Expert (1): 2-2                   --               [1024, 499]      [1024, 16]       --               True             --
│    │    └─Linear (fc1): 3-5            --               [1024, 499]      [1024, 32]       16,000           True             16,384,000
│    │    └─ReLU (relu): 3-6             --               [1024, 32]       [1024, 32]       --               --               --
│    │    └─Dropout (dropout): 3-7       --               [1024, 32]       [1024, 32]       --               --               --
│    │    └─Linear (fc2): 3-8            --               [1024, 32]       [1024, 16]       528              True             540,672
│    └─Expert (2): 2-3                   --               [1024, 499]      [1024, 16]       --               True             --
│    │    └─Linear (fc1): 3-9            --               [1024, 499]      [1024, 32]       16,000           True             16,384,000
│    │    └─ReLU (relu): 3-10            --               [1024, 32]       [1024, 32]       --               --               --
│    │    └─Dropout (dropout): 3-11      --               [1024, 32]       [1024, 32]       --               --               --
│    │    └─Linear (fc2): 3-12           --               [1024, 32]       [1024, 16]       528              True             540,672
│    └─Expert (3): 2-4                   --               [1024, 499]      [1024, 16]       --               True             --
│    │    └─Linear (fc1): 3-13           --               [1024, 499]      [1024, 32]       16,000           True             16,384,000
│    │    └─ReLU (relu): 3-14            --               [1024, 32]       [1024, 32]       --               --               --
│    │    └─Dropout (dropout): 3-15      --               [1024, 32]       [1024, 32]       --               --               --
│    │    └─Linear (fc2): 3-16           --               [1024, 32]       [1024, 16]       528              True             540,672
│    └─Expert (4): 2-5                   --               [1024, 499]      [1024, 16]       --               True             --
│    │    └─Linear (fc1): 3-17           --               [1024, 499]      [1024, 32]       16,000           True             16,384,000
│    │    └─ReLU (relu): 3-18            --               [1024, 32]       [1024, 32]       --               --               --
│    │    └─Dropout (dropout): 3-19      --               [1024, 32]       [1024, 32]       --               --               --
│    │    └─Linear (fc2): 3-20           --               [1024, 32]       [1024, 16]       528              True             540,672
│    └─Expert (5): 2-6                   --               [1024, 499]      [1024, 16]       --               True             --
│    │    └─Linear (fc1): 3-21           --               [1024, 499]      [1024, 32]       16,000           True             16,384,000
│    │    └─ReLU (relu): 3-22            --               [1024, 32]       [1024, 32]       --               --               --
│    │    └─Dropout (dropout): 3-23      --               [1024, 32]       [1024, 32]       --               --               --
│    │    └─Linear (fc2): 3-24           --               [1024, 32]       [1024, 16]       528              True             540,672
├─Softmax (softmax): 1-2                 --               [1024, 6]        [1024, 6]        --               --               --
├─Softmax (softmax): 1-3                 --               [1024, 6]        [1024, 6]        --               --               --
├─ModuleList (towers): 1-4               --               --               --               --               True             --
│    └─Tower (0): 2-7                    --               [1024, 16]       [1024, 1]        --               True             --
│    │    └─Linear (fc1): 3-25           --               [1024, 16]       [1024, 8]        136              True             139,264
│    │    └─ReLU (relu): 3-26            --               [1024, 8]        [1024, 8]        --               --               --
│    │    └─Dropout (dropout): 3-27      --               [1024, 8]        [1024, 8]        --               --               --
│    │    └─Linear (fc2): 3-28           --               [1024, 8]        [1024, 1]        9                True             9,216
│    │    └─Sigmoid (sigmoid): 3-29      --               [1024, 1]        [1024, 1]        --               --               --
│    └─Tower (1): 2-8                    --               [1024, 16]       [1024, 1]        --               True             --
│    │    └─Linear (fc1): 3-30           --               [1024, 16]       [1024, 8]        136              True             139,264
│    │    └─ReLU (relu): 3-31            --               [1024, 8]        [1024, 8]        --               --               --
│    │    └─Dropout (dropout): 3-32      --               [1024, 8]        [1024, 8]        --               --               --
│    │    └─Linear (fc2): 3-33           --               [1024, 8]        [1024, 1]        9                True             9,216
│    │    └─Sigmoid (sigmoid): 3-34      --               [1024, 1]        [1024, 1]        --               --               --
========================================================================================================================================
Total params: 105,446
Trainable params: 105,446
Non-trainable params: 0
Total mult-adds (M): 101.84
========================================================================================================================================
Input size (MB): 2.04
Forward/backward pass size (MB): 2.51
Params size (MB): 0.40
Estimated Total Size (MB): 4.95
========================================================================================================================================

This looks great; nearly everything is included. But the nn.ParameterList (w_gates) is omitted.
I went through #54 and #84.
It seems this was mentioned before, so I downgraded to v1.7.0.
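For reference, the 5,988 shown on the top MMOE row in the v1.7.1 table is exactly the size of w_gates (2 tasks × 499 × 6), so the gate parameters do appear to be counted in the totals even though the ParameterList gets no row of its own. A quick hand check in plain Python, using the constructor arguments from the snippet above (the helper `linear_params` is just an illustration, not part of torchinfo):

```python
# Reproduce the v1.7.1 parameter totals for the MMOE example by hand.
# A bias-enabled nn.Linear(in_features, out_features) has in*out weights + out biases.

def linear_params(in_features: int, out_features: int) -> int:
    """Number of weights plus biases in a bias-enabled linear layer."""
    return in_features * out_features + out_features

input_size, num_experts, experts_out, experts_hidden, towers_hidden, tasks = 499, 6, 16, 32, 8, 2

expert = linear_params(input_size, experts_hidden) + linear_params(experts_hidden, experts_out)
tower = linear_params(experts_out, towers_hidden) + linear_params(towers_hidden, 1)
w_gates = tasks * input_size * num_experts  # one (input_size x num_experts) matrix per task

total = num_experts * expert + tasks * tower + w_gates
print(f"expert={expert:,} tower={tower:,} w_gates={w_gates:,} total={total:,}")
# expert=16,528 tower=145 w_gates=5,988 total=105,446
```

The total matches the "Total params: 105,446" line, which includes the 5,988 gate parameters.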

With v1.7.0 I got the following result. It includes the nn.ParameterList, but the result itself looks incorrect: rows are duplicated and out of order, and the mult-adds total differs from v1.7.1.

========================================================================================================================================
Layer (type (var_name):depth-idx)        Kernel Shape     Input Shape      Output Shape     Param #          Trainable        Mult-Adds
========================================================================================================================================
MMOE (MMOE)                              --               [1024, 499]      [1024, 1]        --               True             --
├─Softmax (softmax): 1-6                 --               [1024, 6]        [1024, 6]        --               --               --
├─ModuleList (experts): 1-2              --               --               --               16,528           True             --
│    └─Expert (0): 2-1                   --               [1024, 499]      [1024, 16]       16,528           True             --
│    │    └─Linear (fc1): 3-2            --               [1024, 499]      [1024, 32]       (recursive)      True             16,384,000
│    │    └─Linear (fc1): 3-2            --               [1024, 499]      [1024, 32]       (recursive)      True             16,384,000
│    │    └─ReLU (relu): 3-3             --               [1024, 32]       [1024, 32]       --               --               --
│    │    └─Dropout (dropout): 3-4       --               [1024, 32]       [1024, 32]       --               --               --
│    └─Expert (1): 2-3                   --               [1024, 499]      [1024, 16]       (recursive)      True             --
│    └─Expert (0): 2                     --               --               --               --               --               --
│    │    └─Linear (fc2): 3-5            --               [1024, 32]       [1024, 16]       528              True             540,672
│    └─Expert (1): 2-3                   --               [1024, 499]      [1024, 16]       (recursive)      True             --
│    │    └─Linear (fc1): 3-6            --               [1024, 499]      [1024, 32]       16,000           True             16,384,000
│    │    └─ReLU (relu): 3-7             --               [1024, 32]       [1024, 32]       --               --               --
│    └─Expert (2): 2-5                   --               [1024, 499]      [1024, 16]       (recursive)      True             --
│    └─Expert (1): 2                     --               --               --               --               --               --
│    │    └─Dropout (dropout): 3-8       --               [1024, 32]       [1024, 32]       --               --               --
│    │    └─Linear (fc2): 3-9            --               [1024, 32]       [1024, 16]       528              True             540,672
│    └─Expert (2): 2-5                   --               [1024, 499]      [1024, 16]       (recursive)      True             --
│    │    └─Linear (fc1): 3-10           --               [1024, 499]      [1024, 32]       16,000           True             16,384,000
│    └─Expert (3): 2-7                   --               [1024, 499]      [1024, 16]       (recursive)      True             --
│    └─Expert (2): 2                     --               --               --               --               --               --
│    │    └─ReLU (relu): 3-11            --               [1024, 32]       [1024, 32]       --               --               --
│    │    └─Dropout (dropout): 3-12      --               [1024, 32]       [1024, 32]       --               --               --
│    │    └─Linear (fc2): 3-13           --               [1024, 32]       [1024, 16]       528              True             540,672
│    └─Expert (3): 2-7                   --               [1024, 499]      [1024, 16]       (recursive)      True             --
│    └─Expert (4): 2-10                  --               [1024, 499]      [1024, 16]       (recursive)      True             --
│    └─Expert (3): 2                     --               --               --               --               --               --
│    │    └─Linear (fc1): 3-14           --               [1024, 499]      [1024, 32]       16,000           True             16,384,000
│    │    └─ReLU (relu): 3-15            --               [1024, 32]       [1024, 32]       --               --               --
│    │    └─Dropout (dropout): 3-16      --               [1024, 32]       [1024, 32]       --               --               --
│    │    └─Linear (fc2): 3-17           --               [1024, 32]       [1024, 16]       528              True             540,672
│    └─Expert (5): 2-12                  --               [1024, 499]      [1024, 16]       (recursive)      True             --
│    └─Expert (4): 2-10                  --               [1024, 499]      [1024, 16]       (recursive)      True             --
│    │    └─Linear (fc1): 3-18           --               [1024, 499]      [1024, 32]       16,000           True             16,384,000
│    │    └─ReLU (relu): 3-19            --               [1024, 32]       [1024, 32]       --               --               --
│    │    └─Dropout (dropout): 3-20      --               [1024, 32]       [1024, 32]       --               --               --
├─ParameterList (w_gates): 1-3           --               --               --               5,988            True             --
├─ModuleList (towers): 1-4               --               --               --               --               True             --
│    └─Tower (0): 2-13                   --               [1024, 16]       [1024, 1]        (recursive)      True             --
├─ModuleList (experts): 1-2              --               --               --               16,528           True             --
│    └─Expert (4): 2                     --               --               --               --               --               --
│    │    └─Linear (fc2): 3-21           --               [1024, 32]       [1024, 16]       528              True             540,672
│    └─Expert (5): 2-12                  --               [1024, 499]      [1024, 16]       (recursive)      True             --
│    │    └─Linear (fc1): 3-22           --               [1024, 499]      [1024, 32]       16,000           True             16,384,000
│    │    └─ReLU (relu): 3-23            --               [1024, 32]       [1024, 32]       --               --               --
│    │    └─Dropout (dropout): 3-24      --               [1024, 32]       [1024, 32]       --               --               --
│    │    └─Linear (fc2): 3-25           --               [1024, 32]       [1024, 16]       528              True             540,672
├─Softmax (softmax): 1-5                 --               [1024, 6]        [1024, 6]        --               --               --
├─ModuleList (towers): 1-4               --               --               --               --               True             --
│    └─Tower (1): 2                      --               --               --               --               --               --
│    │    └─Linear (fc2): 3-35           --               [1024, 8]        [1024, 1]        (recursive)      True             9,216
├─Softmax (softmax): 1-6                 --               [1024, 6]        [1024, 6]        --               --               --
├─ModuleList (towers): 1-4               --               --               --               --               True             --
│    └─Tower (0): 2-13                   --               [1024, 16]       [1024, 1]        (recursive)      True             --
│    │    └─Linear (fc1): 3-27           --               [1024, 16]       [1024, 8]        136              True             139,264
│    │    └─ReLU (relu): 3-28            --               [1024, 8]        [1024, 8]        --               --               --
│    │    └─Dropout (dropout): 3-29      --               [1024, 8]        [1024, 8]        --               --               --
│    │    └─Linear (fc2): 3-30           --               [1024, 8]        [1024, 1]        9                True             9,216
│    │    └─Sigmoid (sigmoid): 3-31      --               [1024, 1]        [1024, 1]        --               --               --
│    └─Tower (1): 2-14                   --               [1024, 16]       [1024, 1]        9                True             --
│    │    └─Linear (fc1): 3-32           --               [1024, 16]       [1024, 8]        136              True             139,264
│    │    └─ReLU (relu): 3-33            --               [1024, 8]        [1024, 8]        --               --               --
│    │    └─Dropout (dropout): 3-34      --               [1024, 8]        [1024, 8]        --               --               --
│    │    └─Linear (fc2): 3-35           --               [1024, 8]        [1024, 1]        (recursive)      True             9,216
│    │    └─Sigmoid (sigmoid): 3-36      --               [1024, 1]        [1024, 1]        --               --               --
========================================================================================================================================
Total params: 105,446
Trainable params: 105,446
Non-trainable params: 0
Total mult-adds (M): 118.24
========================================================================================================================================
Input size (MB): 2.04
Forward/backward pass size (MB): 2.24
Params size (MB): 0.36
Estimated Total Size (MB): 4.64
========================================================================================================================================
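The mult-adds difference can be traced to the duplicated rows. In the v1.7.0 table, the expert fc1 value (16,384,000) appears seven times instead of six and the tower fc2 value (9,216) three times instead of two. Judging from the rows in both tables, each Linear's Mult-Adds equals the batch size (1024) times its parameter count. Under that assumption, summing the column as listed is consistent with the 118.24M total, while counting each layer once gives the 101.84M from v1.7.1:

```python
# Assumption (read off the tables): a Linear layer's Mult-Adds = batch_size * its parameter count.
BATCH = 1024

def linear_mult_adds(in_features: int, out_features: int, batch: int = BATCH) -> int:
    """Batch size times the weight + bias count of a bias-enabled linear layer."""
    return batch * (in_features * out_features + out_features)

expert_fc1 = linear_mult_adds(499, 32)  # 16,384,000
expert_fc2 = linear_mult_adds(32, 16)   # 540,672
tower_fc1 = linear_mult_adds(16, 8)     # 139,264
tower_fc2 = linear_mult_adds(8, 1)      # 9,216

# Each of the 6 experts and 2 towers counted once (matches v1.7.1).
expected = 6 * (expert_fc1 + expert_fc2) + 2 * (tower_fc1 + tower_fc2)

# Row counts as they appear in the v1.7.0 table: expert fc1 x7, tower fc2 x3.
v170 = 7 * expert_fc1 + 6 * expert_fc2 + 2 * tower_fc1 + 3 * tower_fc2

print(f"expected {expected / 1e6:.2f}M, v1.7.0 sum {v170 / 1e6:.2f}M")
# expected 101.84M, v1.7.0 sum 118.24M
```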

Hi, there are two reasons for this change.

First, a recent change altered how module layers are recorded for the result: torchinfo now records only modules that are actually called (i.e., whose forward function runs).

Second, torchinfo currently does not record torch functions or operators (such as + or @).

In your case, none of the parameters inside the ParameterList is called through a forward function; they are only used with @, which torchinfo does not record.
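A minimal sketch of why the gates are invisible, assuming the per-layer rows come from module forward hooks. This uses plain PyTorch `register_forward_hook`, not torchinfo itself, and `Gated` is a hypothetical cut-down stand-in for MMOE with one called Linear and a ParameterList used only via `@`:

```python
import torch
import torch.nn as nn


class Gated(nn.Module):
    """Hypothetical toy model: `fc` is called, `w_gates` is only used with `@`."""

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(4, 3)
        self.w_gates = nn.ParameterList([nn.Parameter(torch.randn(4, 3)) for _ in range(2)])

    def forward(self, x: torch.Tensor) -> list:
        h = self.fc(x)  # goes through fc.__call__, so fc's forward hooks fire
        # Iterating a ParameterList never calls its forward(), and `@` is a
        # tensor operator, so no module hook fires for the gates.
        return [h + x @ g for g in self.w_gates]


model = Gated()
called = []
for name, module in model.named_modules():  # yields "", "fc", "w_gates"
    # Append the module's name each time its forward() completes.
    module.register_forward_hook(lambda m, inp, out, name=name: called.append(name or "<root>"))

model(torch.randn(2, 4))
print(called)  # ['fc', '<root>']
```

The hook on `w_gates` is registered but never fires, so anything built on module hooks has no call to record for it.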

I have opened a discussion about a new tracing mechanism that I think would resolve this problem. If you are interested in helping implement it, please see #192.

Thanks for letting me know. I will check that out.

Hello. It still seems to be omitted.
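Until torch functions are traced, one possible workaround (an untested sketch, not something torchinfo documents) is to move each gate matrix into a small module whose forward() performs the matrix product, so the call goes through a module and can be recorded like any other layer. `Gate` is a hypothetical name:

```python
import torch
import torch.nn as nn


class Gate(nn.Module):
    """Hypothetical wrapper: holds one gate matrix and applies it in forward()."""

    def __init__(self, input_size: int, num_experts: int) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.randn(input_size, num_experts))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.weight  # same computation as `x @ g` in MMOE.forward


# In MMOE.__init__, the ParameterList would become a ModuleList of Gate modules:
#     self.w_gates = nn.ModuleList([Gate(input_size, num_experts) for _ in range(tasks)])
# and in MMOE.forward, `x @ g` would become `g(x)`:
#     gates_o = [self.softmax(g(x)) for g in self.w_gates]

gates = nn.ModuleList([Gate(499, 6) for _ in range(2)])
x = torch.randn(1024, 499)
outs = [torch.softmax(g(x), dim=1) for g in gates]
print([tuple(o.shape) for o in outs], sum(p.numel() for p in gates.parameters()))
# [(1024, 6), (1024, 6)] 5988
```

Note that this changes the state_dict keys (for example `w_gates.0` becomes `w_gates.0.weight`), so existing checkpoints would need their keys renamed before loading.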