pytorch / torcheval

A library that contains a rich collection of performant PyTorch model metrics, a simple interface to create new metrics, a toolkit to facilitate metric computation in distributed training, and tools for PyTorch model evaluation.

Home Page: https://pytorch.org/torcheval

get_module_summary assertion failure

quancs opened this issue · comments

πŸ› Describe the bug

I have a sub-module that is shared by different parent modules, like l in the example below. get_module_summary raises an AssertionError in this situation.

import torch
from torch import nn
from torcheval.tools.module_summary import get_module_summary

# The same Linear instance is shared by two parent Sequentials (s1 and s2).
l = nn.Linear(10, 10)
s1 = nn.Sequential(l)
s2 = nn.Sequential(l)
s = nn.Sequential(s1, s2)

ms = get_module_summary(s, module_args=(torch.randn(10, 100, 10, dtype=torch.float32),))
flops_forward_eval, flops_back_eval = ms.flops_forward, ms.flops_backward
params_eval = ms.num_parameters

print(flops_forward_eval, flops_back_eval, params_eval)

The error:

Traceback (most recent call last):
  File "/mnt/home/x/projects/Y/test_x.py", line 10, in <module>
    ms = get_module_summary(s, module_args=(torch.randn(10, 100, 10, dtype=torch.float32),))
  File "/mnt/home/x/miniconda3/lib/python3.9/site-packages/torcheval/tools/module_summary.py", line 348, in get_module_summary
    module_summary_data = _get_module_flops_and_activation_sizes(
  File "/mnt/home/x/miniconda3/lib/python3.9/site-packages/torcheval/tools/module_summary.py", line 258, in _get_module_flops_and_activation_sizes
    res = module(*module_args, **module_kwargs)
  File "/mnt/home/x/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1212, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/mnt/home/x/miniconda3/lib/python3.9/site-packages/torch/nn/modules/container.py", line 204, in forward
    input = module(input)
  File "/mnt/home/x/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1212, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/mnt/home/x/miniconda3/lib/python3.9/site-packages/torch/nn/modules/container.py", line 204, in forward
    input = module(input)
  File "/mnt/home/x/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1215, in _call_impl
    hook_result = hook(self, input, result)
  File "/mnt/home/x/miniconda3/lib/python3.9/site-packages/torcheval/tools/flops.py", line 306, in f
    assert parents[-1] == name
AssertionError
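
For context, my guess from the traceback (not a statement about torcheval's actual implementation) is that the FLOP-counting hooks keep a stack of qualified module names and expect the name popped after each forward call to match the top of that stack. When a single instance is reachable under two qualified names, both sets of hooks fire on every call through it, so the parents[-1] == name check trips. The standalone sketch below reproduces that failure mode with plain forward hooks; the parents stack and hook helpers are illustrative only:

# Minimal sketch of a name-stack hook scheme breaking on a shared module.
# This is NOT torcheval's code; it only mimics the `parents[-1] == name` check.
import torch
from torch import nn

parents = ["root"]  # stack of currently active qualified module names

def make_pre_hook(name):
    def pre_hook(module, inputs):
        parents.append(name)  # push this module's qualified name
    return pre_hook

def make_post_hook(name):
    def post_hook(module, inputs, outputs):
        # Assumes the top of the stack is this module's own name.
        assert parents[-1] == name, f"expected {name!r}, top of stack is {parents[-1]!r}"
        parents.pop()
    return post_hook

l = nn.Linear(10, 10)
s = nn.Sequential(nn.Sequential(l), nn.Sequential(l))

# With remove_duplicate=False the shared Linear is visited twice ("0.0" and
# "1.0"), so the same instance gets two pre-hooks and two post-hooks.
for qualified_name, module in s.named_modules(remove_duplicate=False):
    module.register_forward_pre_hook(make_pre_hook(qualified_name))
    module.register_forward_hook(make_post_hook(qualified_name))

try:
    s(torch.randn(2, 10))
except AssertionError as e:
    print("AssertionError:", e)  # expected '0.0', top of stack is '1.0'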

Versions

Collecting environment information...
PyTorch version: 1.13.1
Is debug build: False
CUDA used to build PyTorch: 11.6
ROCM used to build PyTorch: N/A

OS: CentOS Linux 7 (Core) (x86_64)
GCC version: (GCC) 8.3.1 20190311 (Red Hat 8.3.1-3)
Clang version: Could not collect
CMake version: version 3.24.1
Libc version: glibc-2.17

Python version: 3.9.7 (default, Sep 16 2021, 13:09:58) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-3.10.0-1160.el7.x86_64-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: 11.6.124
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-40GB
GPU 1: NVIDIA A100-SXM4-40GB
GPU 2: NVIDIA A100-SXM4-40GB
GPU 3: NVIDIA A100-SXM4-40GB
GPU 4: NVIDIA A100-SXM4-40GB
GPU 5: NVIDIA A100-SXM4-40GB
GPU 6: NVIDIA A100-SXM4-40GB
GPU 7: NVIDIA A100-SXM4-40GB

Nvidia driver version: 530.30.02
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 128
On-line CPU(s) list: 0-127
Thread(s) per core: 1
Core(s) per socket: 64
Socket(s): 2
NUMA node(s): 8
Vendor ID: AuthenticAMD
CPU family: 23
Model: 49
Model name: AMD EPYC 7742 64-Core Processor
Stepping: 0
CPU MHz: 2250.000
CPU max MHz: 2250.0000
CPU min MHz: 1500.0000
BogoMIPS: 4491.63
Virtualization: AMD-V
L1d cache: 32K
L1i cache: 32K
L2 cache: 512K
L3 cache: 16384K
NUMA node0 CPU(s): 0-15
NUMA node1 CPU(s): 16-31
NUMA node2 CPU(s): 32-47
NUMA node3 CPU(s): 48-63
NUMA node4 CPU(s): 64-79
NUMA node5 CPU(s): 80-95
NUMA node6 CPU(s): 96-111
NUMA node7 CPU(s): 112-127
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc art rep_good nopl nonstop_tsc extd_apicid aperfmperf eagerfpu pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_l2 cpb cat_l3 cdp_l3 hw_pstate sme retpoline_amd ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif umip overflow_recov succor smca

Versions of relevant libraries:
[pip3] flake8==3.7.9
[pip3] mypy==0.971
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.23.5
[pip3] pytorch-lightning==2.0.0
[pip3] pytorch-ranger==0.1.1
[pip3] torch==1.13.1
[pip3] torch-complex==0.4.3
[pip3] torch-optimizer==0.3.0
[pip3] torch-stoi==0.1.2
[pip3] torch-tb-profiler==0.4.0
[pip3] torchaudio==0.13.1
[pip3] torchdata==0.4.1
[pip3] torcheval==0.0.6
[pip3] torchinfo==1.7.2
[pip3] torchmetrics==0.11.4
[pip3] torchtnt==0.0.7
[pip3] torchvision==0.14.1
[conda] blas 1.0 mkl
[conda] cudatoolkit 11.6.0 hecad31d_10 conda-forge
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py39h7f8727e_0
[conda] mkl_fft 1.3.1 py39hd3c417c_0
[conda] mkl_random 1.2.2 py39h51133e4_0
[conda] numpy 1.23.5 py39h14f4228_0
[conda] numpy-base 1.23.5 py39h31eccc5_0
[conda] pytorch 1.13.1 py3.9_cuda11.6_cudnn8.3.2_0 pytorch
[conda] pytorch-cuda 11.6 h867d48c_1 pytorch
[conda] pytorch-lightning 2.0.0 pypi_0 pypi
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] pytorch-ranger 0.1.1 pypi_0 pypi
[conda] torch-complex 0.4.3 pypi_0 pypi
[conda] torch-optimizer 0.3.0 pypi_0 pypi
[conda] torch-stoi 0.1.2 pypi_0 pypi
[conda] torch-tb-profiler 0.4.0 pypi_0 pypi
[conda] torchaudio 0.13.1 py39_cu116 pytorch
[conda] torchdata 0.4.1 pypi_0 pypi
[conda] torcheval 0.0.6 pypi_0 pypi
[conda] torchinfo 1.7.2 pypi_0 pypi
[conda] torchmetrics 0.11.4 pypi_0 pypi
[conda] torchtnt 0.0.7 pypi_0 pypi
[conda] torchvision 0.14.1 py39_cu116 pytorch

commented

Thanks for reporting, quancs. We'll take a look; I'm not sure if we can support these kinds of circular cases.

@JKSenthil, since you're moving this to tnt, you might want to consider this issue.
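
One possible workaround (a suggestion based only on the repro above, not an official fix): summarize a measurement-only copy of the model in which the shared layer is duplicated, so every qualified name maps to a distinct instance. The copy's parameter count then includes the shared weights twice, so only the FLOP numbers carry over directly. A sketch:

# Workaround sketch: break the sharing with deep copies before summarizing.
import copy

import torch
from torch import nn
from torcheval.tools.module_summary import get_module_summary

l = nn.Linear(10, 10)
s = nn.Sequential(nn.Sequential(l), nn.Sequential(l))  # original, shared model

# Measurement-only copy: each branch gets its own Linear instance, so the
# name-stack assertion in the FLOP hooks is never hit.
unshared = nn.Sequential(
    nn.Sequential(copy.deepcopy(l)),
    nn.Sequential(copy.deepcopy(l)),
)

ms = get_module_summary(unshared, module_args=(torch.randn(10, 100, 10, dtype=torch.float32),))
# ms.num_parameters counts the duplicated weights twice (220 instead of 110).
print(ms.flops_forward, ms.flops_backward, ms.num_parameters)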