Lyken17 / pytorch-OpCounter

Count the MACs / FLOPs of your PyTorch model.


How to define the counting rule for a 3rd-party module?

moshicaixi opened this issue · comments

Hi, thank you for your excellent work!

My research area is 3D point clouds, including shape classification and semantic segmentation. My PyTorch project uses some 3rd-party libraries, such as the ball query algorithm from PointNet++, which is implemented as a custom CUDA function. In this situation, how can I define the rule for calculating the MACs and params? I would appreciate any advice.


For newly defined modules (whether implemented in CUDA, C, or Python), there should be a corresponding Python class wrapper. You can register a counting function for that wrapper.

For reference, you can check how THOP counts the built-in PyTorch modules:

https://github.com/Lyken17/pytorch-OpCounter/blob/master/thop/profile.py#L21-L65
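As a rough illustration of the advice above, here is a minimal sketch of a counting function for a ball query wrapper. Several things here are assumptions, not taken from the thread: the wrapper class (passed in as `ball_query_cls`) is hypothetical, its `forward` is assumed to take `xyz` of shape `(B, N, 3)` and `new_xyz` of shape `(B, M, 3)`, and the cost model (one multiply-accumulate per coordinate for every center/point distance) is just one reasonable choice. The hook signature `fn(module, inputs, output)` and the `custom_ops` argument follow the usage shown in the THOP README.

```python
def count_ball_query(module, inputs, output):
    """THOP-style counting hook, called as fn(module, inputs, output).

    Assumed cost model (not from the thread): each of the M query centers
    is compared against each of the N points, and one squared distance
    costs `dim` multiply-accumulates (3 for xyz coordinates).
    """
    xyz, new_xyz = inputs[0], inputs[1]   # assumed shapes (B, N, 3), (B, M, 3)
    batch, n_points, dim = xyz.shape
    n_centers = new_xyz.shape[1]
    macs = batch * n_centers * n_points * dim
    # THOP attaches a `total_ops` accumulator to every profiled module.
    module.total_ops += macs


def profile_with_ball_query(model, inputs, ball_query_cls):
    """Profile `model`, using the rule above for the hypothetical wrapper
    class `ball_query_cls` (the nn.Module that calls the CUDA kernel)."""
    from thop import profile  # imported here so the rule itself has no dependencies
    return profile(model, inputs=inputs,
                   custom_ops={ball_query_cls: count_ball_query})
```

Calling `macs, params = profile_with_ball_query(model, (points,), BallQuery)` would then include the ball query cost in the total, where `BallQuery` stands for whatever wrapper class your project defines. A ball query has no learnable weights, so it should not add anything to the parameter count.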