nn.ParameterList omitted again in v1.7.1
github-0-searcher opened this issue
Hi there.
I was trying to inspect the MMoE model from mmoe, which contains an nn.ParameterList.
import torch
import torch.nn as nn
import torchinfo

class Expert(nn.Module):
    def __init__(self, input_size, output_size, hidden_size):
        super(Expert, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.dropout(out)
        out = self.fc2(out)
        return out

class Tower(nn.Module):
    def __init__(self, input_size, output_size, hidden_size):
        super(Tower, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.4)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.dropout(out)
        out = self.fc2(out)
        out = self.sigmoid(out)
        return out

class MMOE(nn.Module):
    def __init__(self, input_size, num_experts, experts_out, experts_hidden, towers_hidden, tasks):
        super(MMOE, self).__init__()
        self.input_size = input_size
        self.num_experts = num_experts
        self.experts_out = experts_out
        self.experts_hidden = experts_hidden
        self.towers_hidden = towers_hidden
        self.tasks = tasks
        self.softmax = nn.Softmax(dim=1)
        self.experts = nn.ModuleList([Expert(self.input_size, self.experts_out, self.experts_hidden) for i in range(self.num_experts)])
        # The gate weights live in an nn.ParameterList and are only ever used
        # via the @ operator in forward, never through a module call.
        self.w_gates = nn.ParameterList([nn.Parameter(torch.randn(input_size, num_experts), requires_grad=True) for i in range(self.tasks)])
        self.towers = nn.ModuleList([Tower(self.experts_out, 1, self.towers_hidden) for i in range(self.tasks)])

    def forward(self, x):
        experts_o = [e(x) for e in self.experts]
        experts_o_tensor = torch.stack(experts_o)
        gates_o = [self.softmax(x @ g) for g in self.w_gates]
        tower_input = [g.t().unsqueeze(2).expand(-1, -1, self.experts_out) * experts_o_tensor for g in gates_o]
        tower_input = [torch.sum(ti, dim=0) for ti in tower_input]
        final_output = [t(ti) for t, ti in zip(self.towers, tower_input)]
        return final_output

model = MMOE(input_size=499, num_experts=6, experts_out=16, experts_hidden=32, towers_hidden=8, tasks=2)
torchinfo.summary(model, input_size=(1024, 499),
    col_names=[
        "kernel_size",
        "input_size",
        "output_size",
        "num_params",
        "trainable",
        "mult_adds",
    ],
    col_width=16,
    row_settings=["var_names", "depth"],
)
I was on v1.7.1 and got the following output.
========================================================================================================================================
Layer (type (var_name):depth-idx) Kernel Shape Input Shape Output Shape Param # Trainable Mult-Adds
========================================================================================================================================
MMOE (MMOE) -- [1024, 499] [1024, 1] 5,988 True --
├─ModuleList (experts): 1-1 -- -- -- -- True --
│ └─Expert (0): 2-1 -- [1024, 499] [1024, 16] -- True --
│ │ └─Linear (fc1): 3-1 -- [1024, 499] [1024, 32] 16,000 True 16,384,000
│ │ └─ReLU (relu): 3-2 -- [1024, 32] [1024, 32] -- -- --
│ │ └─Dropout (dropout): 3-3 -- [1024, 32] [1024, 32] -- -- --
│ │ └─Linear (fc2): 3-4 -- [1024, 32] [1024, 16] 528 True 540,672
│ └─Expert (1): 2-2 -- [1024, 499] [1024, 16] -- True --
│ │ └─Linear (fc1): 3-5 -- [1024, 499] [1024, 32] 16,000 True 16,384,000
│ │ └─ReLU (relu): 3-6 -- [1024, 32] [1024, 32] -- -- --
│ │ └─Dropout (dropout): 3-7 -- [1024, 32] [1024, 32] -- -- --
│ │ └─Linear (fc2): 3-8 -- [1024, 32] [1024, 16] 528 True 540,672
│ └─Expert (2): 2-3 -- [1024, 499] [1024, 16] -- True --
│ │ └─Linear (fc1): 3-9 -- [1024, 499] [1024, 32] 16,000 True 16,384,000
│ │ └─ReLU (relu): 3-10 -- [1024, 32] [1024, 32] -- -- --
│ │ └─Dropout (dropout): 3-11 -- [1024, 32] [1024, 32] -- -- --
│ │ └─Linear (fc2): 3-12 -- [1024, 32] [1024, 16] 528 True 540,672
│ └─Expert (3): 2-4 -- [1024, 499] [1024, 16] -- True --
│ │ └─Linear (fc1): 3-13 -- [1024, 499] [1024, 32] 16,000 True 16,384,000
│ │ └─ReLU (relu): 3-14 -- [1024, 32] [1024, 32] -- -- --
│ │ └─Dropout (dropout): 3-15 -- [1024, 32] [1024, 32] -- -- --
│ │ └─Linear (fc2): 3-16 -- [1024, 32] [1024, 16] 528 True 540,672
│ └─Expert (4): 2-5 -- [1024, 499] [1024, 16] -- True --
│ │ └─Linear (fc1): 3-17 -- [1024, 499] [1024, 32] 16,000 True 16,384,000
│ │ └─ReLU (relu): 3-18 -- [1024, 32] [1024, 32] -- -- --
│ │ └─Dropout (dropout): 3-19 -- [1024, 32] [1024, 32] -- -- --
│ │ └─Linear (fc2): 3-20 -- [1024, 32] [1024, 16] 528 True 540,672
│ └─Expert (5): 2-6 -- [1024, 499] [1024, 16] -- True --
│ │ └─Linear (fc1): 3-21 -- [1024, 499] [1024, 32] 16,000 True 16,384,000
│ │ └─ReLU (relu): 3-22 -- [1024, 32] [1024, 32] -- -- --
│ │ └─Dropout (dropout): 3-23 -- [1024, 32] [1024, 32] -- -- --
│ │ └─Linear (fc2): 3-24 -- [1024, 32] [1024, 16] 528 True 540,672
├─Softmax (softmax): 1-2 -- [1024, 6] [1024, 6] -- -- --
├─Softmax (softmax): 1-3 -- [1024, 6] [1024, 6] -- -- --
├─ModuleList (towers): 1-4 -- -- -- -- True --
│ └─Tower (0): 2-7 -- [1024, 16] [1024, 1] -- True --
│ │ └─Linear (fc1): 3-25 -- [1024, 16] [1024, 8] 136 True 139,264
│ │ └─ReLU (relu): 3-26 -- [1024, 8] [1024, 8] -- -- --
│ │ └─Dropout (dropout): 3-27 -- [1024, 8] [1024, 8] -- -- --
│ │ └─Linear (fc2): 3-28 -- [1024, 8] [1024, 1] 9 True 9,216
│ │ └─Sigmoid (sigmoid): 3-29 -- [1024, 1] [1024, 1] -- -- --
│ └─Tower (1): 2-8 -- [1024, 16] [1024, 1] -- True --
│ │ └─Linear (fc1): 3-30 -- [1024, 16] [1024, 8] 136 True 139,264
│ │ └─ReLU (relu): 3-31 -- [1024, 8] [1024, 8] -- -- --
│ │ └─Dropout (dropout): 3-32 -- [1024, 8] [1024, 8] -- -- --
│ │ └─Linear (fc2): 3-33 -- [1024, 8] [1024, 1] 9 True 9,216
│ │ └─Sigmoid (sigmoid): 3-34 -- [1024, 1] [1024, 1] -- -- --
========================================================================================================================================
Total params: 105,446
Trainable params: 105,446
Non-trainable params: 0
Total mult-adds (M): 101.84
========================================================================================================================================
Input size (MB): 2.04
Forward/backward pass size (MB): 2.51
Params size (MB): 0.40
Estimated Total Size (MB): 4.95
========================================================================================================================================
This looks great; nearly everything is included, but the nn.ParameterList (w_gates) is omitted. Note that the 5,988 params on the MMOE row are exactly the two 499×6 gate matrices (2 × 499 × 6 = 5,988), so the gates are counted in the total but never listed as their own row.
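As a quick sanity check (a sketch, assuming the model instance defined above), the gate parameters account for exactly that count:

# Assumes the `model` instance from the snippet above.
sum(p.numel() for p in model.w_gates)  # 2 * (499 * 6) = 5988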
I went through #54 and #84; this seems to have been mentioned before, so I downgraded to v1.7.0.
That gave me the result below, which does include the nn.ParameterList, but the output itself seems to be incorrect?
========================================================================================================================================
Layer (type (var_name):depth-idx) Kernel Shape Input Shape Output Shape Param # Trainable Mult-Adds
========================================================================================================================================
MMOE (MMOE) -- [1024, 499] [1024, 1] -- True --
├─Softmax (softmax): 1-6 -- [1024, 6] [1024, 6] -- -- --
├─ModuleList (experts): 1-2 -- -- -- 16,528 True --
│ └─Expert (0): 2-1 -- [1024, 499] [1024, 16] 16,528 True --
│ │ └─Linear (fc1): 3-2 -- [1024, 499] [1024, 32] (recursive) True 16,384,000
│ │ └─Linear (fc1): 3-2 -- [1024, 499] [1024, 32] (recursive) True 16,384,000
│ │ └─ReLU (relu): 3-3 -- [1024, 32] [1024, 32] -- -- --
│ │ └─Dropout (dropout): 3-4 -- [1024, 32] [1024, 32] -- -- --
│ └─Expert (1): 2-3 -- [1024, 499] [1024, 16] (recursive) True --
│ └─Expert (0): 2 -- -- -- -- -- --
│ │ └─Linear (fc2): 3-5 -- [1024, 32] [1024, 16] 528 True 540,672
│ └─Expert (1): 2-3 -- [1024, 499] [1024, 16] (recursive) True --
│ │ └─Linear (fc1): 3-6 -- [1024, 499] [1024, 32] 16,000 True 16,384,000
│ │ └─ReLU (relu): 3-7 -- [1024, 32] [1024, 32] -- -- --
│ └─Expert (2): 2-5 -- [1024, 499] [1024, 16] (recursive) True --
│ └─Expert (1): 2 -- -- -- -- -- --
│ │ └─Dropout (dropout): 3-8 -- [1024, 32] [1024, 32] -- -- --
│ │ └─Linear (fc2): 3-9 -- [1024, 32] [1024, 16] 528 True 540,672
│ └─Expert (2): 2-5 -- [1024, 499] [1024, 16] (recursive) True --
│ │ └─Linear (fc1): 3-10 -- [1024, 499] [1024, 32] 16,000 True 16,384,000
│ └─Expert (3): 2-7 -- [1024, 499] [1024, 16] (recursive) True --
│ └─Expert (2): 2 -- -- -- -- -- --
│ │ └─ReLU (relu): 3-11 -- [1024, 32] [1024, 32] -- -- --
│ │ └─Dropout (dropout): 3-12 -- [1024, 32] [1024, 32] -- -- --
│ │ └─Linear (fc2): 3-13 -- [1024, 32] [1024, 16] 528 True 540,672
│ └─Expert (3): 2-7 -- [1024, 499] [1024, 16] (recursive) True --
│ └─Expert (4): 2-10 -- [1024, 499] [1024, 16] (recursive) True --
│ └─Expert (3): 2 -- -- -- -- -- --
│ │ └─Linear (fc1): 3-14 -- [1024, 499] [1024, 32] 16,000 True 16,384,000
│ │ └─ReLU (relu): 3-15 -- [1024, 32] [1024, 32] -- -- --
│ │ └─Dropout (dropout): 3-16 -- [1024, 32] [1024, 32] -- -- --
│ │ └─Linear (fc2): 3-17 -- [1024, 32] [1024, 16] 528 True 540,672
│ └─Expert (5): 2-12 -- [1024, 499] [1024, 16] (recursive) True --
│ └─Expert (4): 2-10 -- [1024, 499] [1024, 16] (recursive) True --
│ │ └─Linear (fc1): 3-18 -- [1024, 499] [1024, 32] 16,000 True 16,384,000
│ │ └─ReLU (relu): 3-19 -- [1024, 32] [1024, 32] -- -- --
│ │ └─Dropout (dropout): 3-20 -- [1024, 32] [1024, 32] -- -- --
├─ParameterList (w_gates): 1-3 -- -- -- 5,988 True --
├─ModuleList (towers): 1-4 -- -- -- -- True --
│ └─Tower (0): 2-13 -- [1024, 16] [1024, 1] (recursive) True --
├─ModuleList (experts): 1-2 -- -- -- 16,528 True --
│ └─Expert (4): 2 -- -- -- -- -- --
│ │ └─Linear (fc2): 3-21 -- [1024, 32] [1024, 16] 528 True 540,672
│ └─Expert (5): 2-12 -- [1024, 499] [1024, 16] (recursive) True --
│ │ └─Linear (fc1): 3-22 -- [1024, 499] [1024, 32] 16,000 True 16,384,000
│ │ └─ReLU (relu): 3-23 -- [1024, 32] [1024, 32] -- -- --
│ │ └─Dropout (dropout): 3-24 -- [1024, 32] [1024, 32] -- -- --
│ │ └─Linear (fc2): 3-25 -- [1024, 32] [1024, 16] 528 True 540,672
├─Softmax (softmax): 1-5 -- [1024, 6] [1024, 6] -- -- --
├─ModuleList (towers): 1-4 -- -- -- -- True --
│ └─Tower (1): 2 -- -- -- -- -- --
│ │ └─Linear (fc2): 3-35 -- [1024, 8] [1024, 1] (recursive) True 9,216
├─Softmax (softmax): 1-6 -- [1024, 6] [1024, 6] -- -- --
├─ModuleList (towers): 1-4 -- -- -- -- True --
│ └─Tower (0): 2-13 -- [1024, 16] [1024, 1] (recursive) True --
│ │ └─Linear (fc1): 3-27 -- [1024, 16] [1024, 8] 136 True 139,264
│ │ └─ReLU (relu): 3-28 -- [1024, 8] [1024, 8] -- -- --
│ │ └─Dropout (dropout): 3-29 -- [1024, 8] [1024, 8] -- -- --
│ │ └─Linear (fc2): 3-30 -- [1024, 8] [1024, 1] 9 True 9,216
│ │ └─Sigmoid (sigmoid): 3-31 -- [1024, 1] [1024, 1] -- -- --
│ └─Tower (1): 2-14 -- [1024, 16] [1024, 1] 9 True --
│ │ └─Linear (fc1): 3-32 -- [1024, 16] [1024, 8] 136 True 139,264
│ │ └─ReLU (relu): 3-33 -- [1024, 8] [1024, 8] -- -- --
│ │ └─Dropout (dropout): 3-34 -- [1024, 8] [1024, 8] -- -- --
│ │ └─Linear (fc2): 3-35 -- [1024, 8] [1024, 1] (recursive) True 9,216
│ │ └─Sigmoid (sigmoid): 3-36 -- [1024, 1] [1024, 1] -- -- --
========================================================================================================================================
Total params: 105,446
Trainable params: 105,446
Non-trainable params: 0
Total mult-adds (M): 118.24
========================================================================================================================================
Input size (MB): 2.04
Forward/backward pass size (MB): 2.24
Params size (MB): 0.36
Estimated Total Size (MB): 4.64
========================================================================================================================================
Hi, the reason for this change is two-fold.
First, there was a recent change in how module layers are recorded for the result. After this change, torchinfo only records modules that are actually called (via their respective forward functions).
Second, torchinfo as of now does not record torch functions (such as + or @).
In your case, none of the parameters inside the ParameterList is called through a forward function; they are only used with the @ operator, which is not recorded by torchinfo.
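In the meantime, a possible workaround (my suggestion, not something torchinfo provides) is to route each gate through a bias-free nn.Linear, so the matmul happens inside a module's forward call and gets recorded. A minimal sketch (the Gates class name is illustrative only):

import torch
import torch.nn as nn

class Gates(nn.Module):
    def __init__(self, input_size, num_experts, tasks):
        super().__init__()
        # nn.Linear(input_size, num_experts, bias=False) stores a
        # (num_experts, input_size) weight, the same parameter count as
        # nn.Parameter(torch.randn(input_size, num_experts)).
        self.gates = nn.ModuleList(
            [nn.Linear(input_size, num_experts, bias=False) for _ in range(tasks)]
        )
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        # gate(x) computes x @ weight.T, equivalent to x @ w_gate above,
        # but as a module call, so torchinfo records it in the summary.
        return [self.softmax(gate(x)) for gate in self.gates]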
More generally, there is a new discussion (opened by me) about a new tracing mechanism that I think would resolve this problem. If you are interested in helping implement it, please see #192.
Thanks for the notification. I will check that out.
Hello. It seems to still be omitted.