Lyken17 / pytorch-OpCounter

Count the MACs / FLOPs of your PyTorch model.

thop calculates torch.nn module params incorrectly

qsimeon opened this issue

!pip install -q thop

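import torch
from thop import profile, clever_format

# DEVICE originally came from a local utils module (not shown); this stand-in picks CUDA when available
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")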

# Display the DEVICE
print(f"DEVICE: {DEVICE}")

# Assuming 'model' is your PyTorch model and 'input' is a tensor representing the input to the model
input = torch.randn(1, 10)  # example input of shape (1, 10): 10 features
input = input.to(DEVICE)

model = torch.nn.LSTM(10, 11)  # example model (the outputs below match this LSTM); torch.nn.Linear(10, 11) shows the same params: 0 problem
model = model.to(DEVICE)
# model.eval()

# Measure MACs and parameters (thop reports multiply-accumulates, not FLOPs)
macs, params = profile(model, inputs=(input,), verbose=False)

print(f"MACs: {macs}, params: {params}")

macs, params = clever_format([macs, params], "%.3f")

print(f"MACs: {macs}, params: {params}")

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Parameters: {num_params}")

Outputs:

DEVICE: cpu
MACs: 11000.0, params: 0
MACs: 11.000K, params: 0.000B
Parameters: 1012
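As a sanity check on the `Parameters` line, the parameter counts of these modules follow directly from their weight shapes in PyTorch: Linear has a weight of shape (out, in) plus an optional bias of shape (out,), and a single-layer, unidirectional LSTM has `weight_ih` (4h, in), `weight_hh` (4h, h) and, with biases, `bias_ih` and `bias_hh` of shape (4h,) each. The sketch below (the helper names `linear_params` and `lstm_params` are hypothetical, not part of thop or PyTorch) computes them in closed form:

```python
# Closed-form parameter counts, assuming PyTorch's standard layouts:
#   Linear: weight (out, in) + optional bias (out,)
#   LSTM (1 layer, unidirectional): weight_ih (4h, in), weight_hh (4h, h),
#   and with bias=True, bias_ih (4h,) and bias_hh (4h,)

def linear_params(in_features: int, out_features: int, bias: bool = True) -> int:
    return in_features * out_features + (out_features if bias else 0)

def lstm_params(input_size: int, hidden_size: int, bias: bool = True) -> int:
    gates = 4 * hidden_size  # input, forget, cell and output gates stacked
    weights = gates * input_size + gates * hidden_size
    biases = 2 * gates if bias else 0  # bias_ih and bias_hh
    return weights + biases

print(linear_params(10, 11))  # 121
print(lstm_params(10, 11))    # 1012
```

The `Parameters: 1012` reported above is exactly `lstm_params(10, 11)`, so the manual count is correct and thop's `params: 0` is the wrong figure. Comparing a profiler's `params` against `sum(p.numel() for p in model.parameters())` is a quick way to catch this.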


By contrast, the ResNet-50 example from the README appears to work correctly:

import torch
from torchvision.models import resnet50
from thop import profile, clever_format

model = resnet50()
input = torch.randn(1, 3, 224, 224)
macs, params = profile(model, inputs=(input,))

macs, params = clever_format([macs, params], "%.3f")
print(f"MACs: {macs}, params: {params}")

Output:

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
MACs: 4.134G, params: 25.557M
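One hypothesis consistent with both runs (an assumption, not confirmed in this thread): the profiler records totals on every module, but builds the final parameter sum only from the root's children, starting the root's own parameter count at 0. A model that is itself a single leaf module (Linear, LSTM) would then report its MACs but zero parameters, while ResNet-50, whose layers are all children of the root, is counted normally. The toy below, built on hypothetical `ToyModule` and `aggregate` names and reusing the numbers from the first run, illustrates that failure mode; it is not thop's code.

```python
# Hypothetical, simplified model of the suspected aggregation bug (not thop's code):
# ops start from the root's own count, params start from 0, and params are only
# picked up from children.
from dataclasses import dataclass, field
from typing import List, Tuple

@dataclass
class ToyModule:
    name: str
    own_ops: int = 0      # ops recorded by this module's own hook
    own_params: int = 0   # parameters owned directly by this module
    children: List["ToyModule"] = field(default_factory=list)

def aggregate(module: ToyModule) -> Tuple[int, int]:
    total_ops, total_params = module.own_ops, 0  # note: own_params is ignored here
    for child in module.children:
        if not child.children:
            # leaf child: use its recorded totals directly
            total_ops += child.own_ops
            total_params += child.own_params
        else:
            # container child: recurse into it
            ops, params = aggregate(child)
            total_ops += ops
            total_params += params
    return total_ops, total_params

# A bare leaf as the root: ops survive, params are lost.
leaf = ToyModule("lstm", own_ops=11000, own_params=1012)
print(aggregate(leaf))     # (11000, 0)

# The same leaf wrapped in a container: params now come through.
wrapped = ToyModule("seq", children=[ToyModule("lstm", own_ops=11000, own_params=1012)])
print(aggregate(wrapped))  # (11000, 1012)
```

If this hypothesis holds for the installed thop version, wrapping the layer in a container such as `torch.nn.Sequential(torch.nn.LSTM(10, 11))` before calling `profile` may be worth trying as a workaround; the result should still be checked against `sum(p.numel() for p in model.parameters())`.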

Why does this fail for basic torch.nn modules such as Linear and LSTM, yet work for more complicated models?

@qsimeon Try torch-operation-counter instead: it wraps the native PyTorch operations, so complex models should work as well.