TylerYep / torchinfo

View model summaries in PyTorch!


Nested module not printing the output shape

sup3rgiu opened this issue

Describe the bug
In a nested-module scenario, the nn.Module output shape is not printed if the depth is not high enough.

(screenshot: summary output in which the nested nn.Module rows show no output shape)

Increasing the depth shows the output shape of the nested layers, but this could be quite inconvenient if the model is deeply nested and I'm only interested in the output shape of the "main" module.

(screenshot: summary output at a higher depth, where the nested layers' output shapes appear)

To Reproduce

import torch
import torch.nn as nn
from torchinfo import summary

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.module = nn.ModuleList()
        for i in range(4):
            block = nn.ModuleList()
            for j in range(5):
                block.append(nn.Linear(10, 10))
            module = nn.Module()  # plain container with no forward()
            module.block = block
            self.module.append(module)
        
    def forward(self, x):
        for i in range(4):
            for j in range(5):
                x = self.module[i].block[j](x)
        return x

module = MyModule()

summary(module, input_size=[(3, 10)], dtypes=[torch.float32], depth=2)

Expected behavior
Always print the output shape of the most nested layer/module with respect to the selected depth.

I think the problem is that the nn.Module submodules of that object don't have forward methods and are never actually called. This variant using nn.Sequential does what you want it to:

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.module = nn.ModuleList()
        for i in range(4):
            block = nn.ModuleList()
            for j in range(5):
                block.append(nn.Linear(10, 10))
            self.module.append(nn.Sequential(*block))
        
    def forward(self, x):
        for i in range(4):
            x = self.module[i](x)
        return x

as does defining a custom submodule class with a forward method and using that class for the submodules. Does either of those resolve the issue?
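
For concreteness, here is a minimal sketch of that second alternative (the Block class name is just illustrative, not from the original code):

import torch.nn as nn

class Block(nn.Module):
    # A container with an explicit forward(), so torchinfo can hook it
    # and record an output shape.
    def __init__(self):
        super().__init__()
        self.block = nn.ModuleList(nn.Linear(10, 10) for _ in range(5))

    def forward(self, x):
        for layer in self.block:
            x = layer(x)
        return x

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.module = nn.ModuleList(Block() for _ in range(4))

    def forward(self, x):
        for block in self.module:
            x = block(x)
        return x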

Actually, the problem arose when trying to get a nice summary table for the VAE of Stable Diffusion, where, as you can see, they use this approach: https://github.com/Stability-AI/generative-models/blob/main/sgm/modules/diffusionmodules/model.py#L542

However, I see your point about the missing forward method, and indeed it would make no sense to try to compute an exact output shape for such an nn.Module.

Still, I think it might be useful to find a workaround that prints an "expected" output shape. Maybe we could set it to the output shape of the last-called submodule of such an nn.Module.

For instance, in the following example:

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.module = nn.ModuleList()
        for i in range(4):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            for j in range(5):
                block.append(nn.Linear(10, 10))
                attn.append(nn.Linear(10, 10))
            module = nn.Module()
            module.block = block
            module.attn = attn
            self.module.append(module)
        
    def forward(self, x):
        for i in range(4):
            for j in range(5):
                x = self.module[i].block[j](x)
                x = self.module[i].attn[j](x)
        return x

the output shape of the first nn.Module would be the output shape of self.module[0].attn[4](x), since .attn[4] is the last submodule called inside self.module[0]; the output shape of the second nn.Module would be that of self.module[1].attn[4](x); and so on.

But I don't know if torchinfo could be easily extended to keep track of the last used "submodule" of a nn.Module.
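
For what it's worth, here is a rough sketch of that bookkeeping done outside torchinfo with plain forward hooks, reusing the MyModule defined above (the last_output_shapes dict and the hook logic are entirely my own illustration, not torchinfo API):

import torch

model = MyModule()

# Map each container's dotted name to the output shape of the most recently
# called submodule beneath it -- the "expected" output shape proposed above.
last_output_shapes = {}

def make_hook(name):
    def hook(module, inputs, output):
        # Propagate this output shape up to every ancestor container, so each
        # container ends up holding the shape of the last submodule run inside it.
        parts = name.split(".")
        for depth in range(1, len(parts) + 1):
            last_output_shapes[".".join(parts[:depth])] = tuple(output.shape)
    return hook

handles = [m.register_forward_hook(make_hook(name))
           for name, m in model.named_modules() if name]

model(torch.randn(3, 10))
print(last_output_shapes["module.0"])  # (3, 10), produced by self.module[0].attn[4]

for h in handles:
    h.remove()

Since torchinfo itself gathers its data through forward hooks, the information is in principle available there too; whether it should be surfaced is a separate design question.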

Personally, I'm not sure that this would be desirable behavior. If an nn.Module is used just as a container for submodules, as it is here, then one can imagine the submodules being used in lots of different configurations, with different kinds of "output" being extracted from inside the nn.Module. Picking one of these as the "output shape" could be more confusing than the current behavior, which is at least clear and unambiguous (the module has no output and thus no "output shape").

But the actual maintainers should probably weigh in -- I'm just a guy looking for issues to work on :)

Yes, I agree with you that it's probably not the most desirable behavior. Even if we somehow specify in the output summary that these shapes are "estimated", it could probably be confusing.

I recognize that this is probably a non-issue, so we should close it as "Won't fix".

Agreed that this behavior is ambiguous and hard to maintain. Thanks for the discussion, hopefully this thread helps answer future issues like this.