Lyken17 / pytorch-OpCounter

Count the MACs / FLOPs of your PyTorch model.

question about the MACs of nn.BatchNorm2d

DaMiBear opened this issue · comments

Hi, happy new year!
I'm confused about how the MACs of nn.BatchNorm2d are calculated.

```python
import torch
import torch.nn as nn

def count_normalization(m: nn.modules.batchnorm._BatchNorm, x, y):
    # TODO: add test cases
    # https://github.com/Lyken17/pytorch-OpCounter/issues/124
    # y = (x - mean) / sqrt(eps + var) * weight + bias
    x = x[0]
    # bn is by default fused in inference
    flops = calculate_norm(x.numel())
    if getattr(m, "affine", False) or getattr(m, "elementwise_affine", False):
        flops *= 2
    m.total_ops += flops

def calculate_norm(input_size):
    """input is a number, not an array or tensor"""
    return torch.DoubleTensor([2 * input_size])
```

In my opinion, the 2 * input_size in calculate_norm(input_size) already accounts for the MACs of the subtract (mean), divide (var), multiply (weight), and add (bias). So why are the flops (w.r.t. MACs) multiplied by 2 again in count_normalization?
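For illustration, here is a minimal standalone sketch (the function name is mine, not from thop) of one possible reading of the hook's arithmetic: the base count of 2 * numel covers only the subtract-mean and divide-by-std steps, and the extra *2 for affine layers adds the weight multiply and bias add, giving 4 ops per element rather than 2 MACs total.

```python
import torch

def count_bn_flops(x: torch.Tensor, affine: bool) -> int:
    # Base count: per element, subtract the mean (1 op) and
    # divide by sqrt(eps + var) (1 op) -> 2 * numel.
    flops = 2 * x.numel()
    # Affine case: multiply by weight (1 op) and add bias (1 op)
    # per element, which the hook expresses as flops *= 2.
    if affine:
        flops *= 2
    return flops

x = torch.randn(1, 3, 4, 4)  # 48 elements
print(count_bn_flops(x, affine=False))  # 2 * 48 = 96
print(count_bn_flops(x, affine=True))   # 4 * 48 = 192
```

Under this reading the doubling is deliberate, but it counts individual adds/multiplies rather than fused MACs, which may be the source of the confusion.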