TylerYep / torchinfo

View model summaries in PyTorch!

Inaccurate Mult-Adds Estimation for Transformers

Yiming-M opened this issue

Describe the bug

For ViT (torchvision's vit_b_16), the total mult-adds reported by torchinfo.summary is far smaller than the figures published elsewhere.

To Reproduce

Code snippet:

from torchinfo import summary
from torchvision.models import vit_b_16

vit = vit_b_16()
input_size = (1, 3, 224, 224)  # a batch of one 224x224 RGB image
summary(vit, input_size)

Output:

...
Total params: 86,567,656
Trainable params: 86,567,656
Non-trainable params: 0
Total mult-adds (M): 173.23
===============================================================================================
Input size (MB): 0.60
Forward/backward pass size (MB): 104.09
Params size (MB): 232.27
Estimated Total Size (MB): 336.96

Expected behavior

According to other sources such as MMClassification and PapersWithCode, the model requires 33.03G FLOPs. I understand that mult-adds and FLOPs are not the same metric (one mult-add is often counted as two FLOPs), but for transformers, where matrix multiplication accounts for most of the computation, the two numbers should be of the same order of magnitude, not 33.03G versus 173.23M!
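To check the order of magnitude independently, here is a rough back-of-the-envelope count of the matrix-multiply mult-adds in ViT-B/16. It is a sketch under stated assumptions: it uses the standard ViT-B/16 configuration (16x16 patches, hidden size 768, 12 encoder layers, MLP size 3072, one class token, 1000 classes) and ignores LayerNorm, softmax, GELU, bias additions and residual adds. The function name `vit_macs` is hypothetical.

```python
# Rough analytic mult-add (MAC) count for ViT-B/16 on one 224x224 image.
# Assumption: standard ViT-B/16 hyperparameters; only matmul/conv work is
# counted (LayerNorm, softmax, GELU, biases and residual adds are ignored).

def vit_macs(image_size=224, patch=16, dim=768, layers=12, mlp_dim=3072, classes=1000):
    patches = (image_size // patch) ** 2        # 14 * 14 = 196 patches
    tokens = patches + 1                        # plus the class token -> 197
    patch_embed = patches * dim * (3 * patch * patch)  # 16x16, stride-16 conv
    qkv = tokens * dim * (3 * dim)              # packed Q, K, V projections
    scores = tokens * tokens * dim              # Q @ K^T, summed over all heads
    attn_v = tokens * tokens * dim              # softmax(scores) @ V
    out_proj = tokens * dim * dim               # attention output projection
    mlp = 2 * tokens * dim * mlp_dim            # the two Linear layers of the MLP
    per_layer = qkv + scores + attn_v + out_proj + mlp
    head = dim * classes                        # classifier on the class token only
    return patch_embed + layers * per_layer + head


total = vit_macs()
print(f"{total / 1e9:.2f} G mult-adds")  # 17.56 G mult-adds
```

That estimate is roughly 100 times the 173.23M reported by torchinfo above. Under the two-FLOPs-per-mult-add convention it corresponds to about 35 GFLOPs, which is in the same range as the 33.03G figure.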

Environment:

  • OS: macOS Ventura 13.2 (M1 Pro)
  • Python: 3.10.9
  • Package Version (torchinfo): 1.7.2

I ran into the same problem and hope the developers can look into it. Thanks a lot.

I encountered a similar bug: the MACs of the MultiheadAttention module don't get counted.

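The problem is that torchinfo currently traces only nn.Modules, not functions. Transformer modules often do their work through functional calls instead of submodules (for example, nn.MultiheadAttention.forward delegates to torch.nn.functional.multi_head_attention_forward and uses the out_proj weights directly), so that computation never gets traced.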
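The effect can be seen with plain PyTorch forward hooks, which is the kind of module-level tracing described above. This is a minimal sketch, assuming a module-hook tracer behaves like the hypothetical `make_hook` helper below: a hook is registered on every submodule of an nn.MultiheadAttention, and after a forward pass only the top-level module has fired, because out_proj's weights are consumed by functional code rather than by calling out_proj itself.

```python
import torch
from torch import nn

mha = nn.MultiheadAttention(embed_dim=768, num_heads=12, batch_first=True)

called = []  # names of modules whose forward() actually ran


def make_hook(name):
    # Hypothetical helper: record the module name whenever its forward runs.
    def hook(module, inputs, output):
        called.append(name)
    return hook


# named_modules() yields the root module ("") and its out_proj submodule.
for name, module in mha.named_modules():
    module.register_forward_hook(make_hook(name or "<root>"))

x = torch.randn(1, 197, 768)  # (batch, tokens, embed_dim), as in ViT-B/16
with torch.no_grad():
    mha(x, x, x)

print(called)  # out_proj never appears: its matmul ran inside functional code
```

A tracer that only attributes work to modules whose hooks fire has no place to record the projections and attention matmuls that happen inside that functional call.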

Discussion #192 proposes a tracing mechanism that would fix this issue, but it is a big change. If anyone is up for implementing it, I think @TylerYep would be happy to see it.