thop calculates torch.nn module params incorrectly
qsimeon opened this issue
```python
# In a notebook, install first with: !pip install -q thop
import torch
from thop import profile, clever_format

# DEVICE normally comes from a local utils module; defined inline so the snippet runs standalone
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"DEVICE: {DEVICE}")

# Example input: a single length-10 sequence
input = torch.randn(1, 10).to(DEVICE)
# Example model
model = torch.nn.Linear(10, 11).to(DEVICE)
# model.eval()

# Measure MACs (multiply-accumulates) and the parameter count with thop
macs, params = profile(model, inputs=(input,), verbose=False)
print(f"MACs: {macs}, params: {params}")
macs, params = clever_format([macs, params], "%.3f")
print(f"MACs: {macs}, params: {params}")

# Reference value: count the trainable parameters directly
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Parameters: {num_params}")
```
Output:
```
DEVICE: cpu
MACs: 11000.0, params: 0
MACs: 11.000K, params: 0.000B
Parameters: 1012
```
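As a cross-check independent of thop, the minimal sketch below counts multiply-accumulates for `nn.Linear` layers with forward hooks and counts parameters with `model.parameters()`. The helper name `count_linear_macs_and_params` is hypothetical, it only handles `nn.Linear` (other layer types contribute parameters but no MACs), and the MAC formula (input features times output elements) is an assumption of this sketch rather than thop's documented behavior. Because `model.modules()` includes the root module itself, a bare `nn.Linear` is hooked just like a nested layer.

```python
import torch
import torch.nn as nn

def count_linear_macs_and_params(model: nn.Module, x: torch.Tensor):
    """Hypothetical helper: Linear-layer MACs and total parameter count."""
    macs = 0

    def linear_hook(module, inputs, output):
        nonlocal macs
        # Assumed cost: in_features multiply-accumulates per output element.
        macs += module.in_features * output.numel()

    # modules() yields the root first, so a bare nn.Linear is hooked as well.
    handles = [m.register_forward_hook(linear_hook)
               for m in model.modules() if isinstance(m, nn.Linear)]
    try:
        with torch.no_grad():
            model(x)
    finally:
        for h in handles:
            h.remove()  # leave the model without extra hooks

    # parameters() recurses from the root, so the root's own weights are counted.
    params = sum(p.numel() for p in model.parameters())
    return macs, params

macs, params = count_linear_macs_and_params(nn.Linear(10, 11), torch.randn(1, 10))
print(f"MACs: {macs}, params: {params}")
```

For `nn.Linear(10, 11)` with a `(1, 10)` input, this reports 110 MACs and 121 parameters (110 weights plus 11 biases).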
For some reason, the ResNet example from the README seems to work fine:
```python
import torch
from torchvision.models import resnet50
from thop import profile, clever_format

model = resnet50()
input = torch.randn(1, 3, 224, 224)
macs, params = profile(model, inputs=(input,))
macs, params = clever_format([macs, params], "%.3f")
print(f"MACs: {macs}, params: {params}")
```
Output:
```
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
MACs: 4.134G, params: 25.557M
```
Why doesn't this work for basic `torch.nn` modules such as `Linear` and `LSTM` when it works for more complicated models?
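One possible explanation, which is an assumption and not confirmed in this thread: if the profiler collects parameter totals only from the model's child modules, a bare `nn.Linear` (which is its own root and has no children) would report 0, while every layer in ResNet-50 sits below the root and gets counted. The sketch below uses only standard PyTorch and a hypothetical helper, `direct_param_counts`, to show where the parameters live in each layout. If that assumption holds, wrapping the layer, e.g. `nn.Sequential(nn.Linear(10, 11))`, would be a simple workaround to try.

```python
import torch.nn as nn

def direct_param_counts(model: nn.Module) -> dict:
    """Hypothetical helper: parameters owned directly by each module.

    The root module appears under the empty name "".
    """
    return {
        name: sum(p.numel() for p in module.parameters(recurse=False))
        for name, module in model.named_modules()
    }

bare = nn.Linear(10, 11)                     # the layer is the root itself
wrapped = nn.Sequential(nn.Linear(10, 11))   # the layer is a child of the root

print(len(list(bare.children())), direct_param_counts(bare))        # 0 {'': 121}
print(len(list(wrapped.children())), direct_param_counts(wrapped))  # 1 {'': 0, '0': 121}
```

In the bare case all 121 parameters belong to the root, which has no children; in the wrapped case the same parameters belong to child `"0"`.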
@qsimeon Try torch-operation-counter instead: it wraps the native PyTorch operations, so complex models should work as well.