Cannot correctly parse some import paths when using config file for LightningCLI
zengchang233 opened this issue
Bug description
I would like to use LightningCLI with a config.yaml file to manage my project, but I found that some TorchAudio classes cannot be parsed correctly. The failure involves passing a PyTorch window function as an argument. Classes that do not take a window function as an explicit argument, such as MFCC or LFCC, are parsed and instantiated successfully. Classes that do, such as MelSpectrogram, either fail to parse or end up with the window function's import path stored as a plain string. I know this may be caused by the jsonargparse library, but I think the Lightning community should also be aware of it.
What version are you seeing the problem on?
v2.2
How to reproduce the bug
from lightning.pytorch.cli import LightningCLI
cli = LightningCLI(run=True, save_config_kwargs={"overwrite": True})
model:
  class_path: models.AudioSystem
  init_args:
    model:
      class_path: models.AudioModel
      init_args:
        feat_dim: 80 # n_feat
        trunc_len: 750
        num_classes: 1000
    feature_extractor:
      class_path: torchaudio.transforms.MelSpectrogram
      init_args:
        sample_rate: 16000
        n_mels: 80
        n_fft: 512
        win_length: 512
        hop_length: 160
        window_fn: torch.hann_window # This is the default value. The error occurs whether or not it is set explicitly.
    delta_order: 2
Error messages and logs
# Configure the window function explicitly
ValueError: Only possible to serialize an importable object, given <built-in method hann_window of type object at 0x7feed842f840>: module 'torch' has no attribute '_VariableFunctionsClass'
# Window function not set in the config file (the default argument is used)
Traceback (most recent call last):
File "/workspace/zengchang/work/antispoofing/train.py", line 128, in <module>
main()
File "/workspace/zengchang/work/antispoofing/train.py", line 125, in main
cli_main()
File "/workspace/zengchang/work/antispoofing/train.py", line 68, in cli_main
cli = CustomCLI(
File "/opt/conda/envs/asvspoof5/lib/python3.10/site-packages/lightning/pytorch/cli.py", line 385, in __init__
self.instantiate_classes()
File "/opt/conda/envs/asvspoof5/lib/python3.10/site-packages/lightning/pytorch/cli.py", line 535, in instantiate_classes
self.config_init = self.parser.instantiate_classes(self.config)
File "/opt/conda/envs/asvspoof5/lib/python3.10/site-packages/jsonargparse/_deprecated.py", line 141, in patched_instantiate_classes
cfg = self._unpatched_instantiate_classes(cfg, **kwargs)
File "/opt/conda/envs/asvspoof5/lib/python3.10/site-packages/jsonargparse/_core.py", line 1196, in instantiate_classes
cfg[subcommand] = subparser.instantiate_classes(cfg[subcommand], instantiate_groups=instantiate_groups)
File "/opt/conda/envs/asvspoof5/lib/python3.10/site-packages/jsonargparse/_deprecated.py", line 141, in patched_instantiate_classes
cfg = self._unpatched_instantiate_classes(cfg, **kwargs)
File "/opt/conda/envs/asvspoof5/lib/python3.10/site-packages/jsonargparse/_core.py", line 1187, in instantiate_classes
parent[key] = component.instantiate_classes(value)
File "/opt/conda/envs/asvspoof5/lib/python3.10/site-packages/jsonargparse/_typehints.py", line 565, in instantiate_classes
value[num] = adapt_typehints(
File "/opt/conda/envs/asvspoof5/lib/python3.10/site-packages/jsonargparse/_typehints.py", line 952, in adapt_typehints
val = adapt_class_type(val, serialize, instantiate_classes, sub_add_kwargs, prev_val=prev_val)
File "/opt/conda/envs/asvspoof5/lib/python3.10/site-packages/jsonargparse/_typehints.py", line 1173, in adapt_class_type
init_args = parser.instantiate_classes(init_args)
File "/opt/conda/envs/asvspoof5/lib/python3.10/site-packages/jsonargparse/_deprecated.py", line 141, in patched_instantiate_classes
cfg = self._unpatched_instantiate_classes(cfg, **kwargs)
File "/opt/conda/envs/asvspoof5/lib/python3.10/site-packages/jsonargparse/_core.py", line 1187, in instantiate_classes
parent[key] = component.instantiate_classes(value)
File "/opt/conda/envs/asvspoof5/lib/python3.10/site-packages/jsonargparse/_typehints.py", line 565, in instantiate_classes
value[num] = adapt_typehints(
File "/opt/conda/envs/asvspoof5/lib/python3.10/site-packages/jsonargparse/_typehints.py", line 952, in adapt_typehints
val = adapt_class_type(val, serialize, instantiate_classes, sub_add_kwargs, prev_val=prev_val)
File "/opt/conda/envs/asvspoof5/lib/python3.10/site-packages/jsonargparse/_typehints.py", line 1187, in adapt_class_type
return instantiator_fn(val_class, **{**init_args, **dict_kwargs})
File "/opt/conda/envs/asvspoof5/lib/python3.10/site-packages/jsonargparse/_common.py", line 148, in default_class_instantiator
return class_type(*args, **kwargs)
File "/opt/conda/envs/asvspoof5/lib/python3.10/site-packages/torchaudio/transforms/_transforms.py", line 593, in __init__
self.spectrogram = Spectrogram(
File "/opt/conda/envs/asvspoof5/lib/python3.10/site-packages/torchaudio/transforms/_transforms.py", line 85, in __init__
window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
TypeError: 'str' object is not callable
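Both failures come down to converting a callable into an import path and back. The ValueError above shows that serializing torch.hann_window goes through the path torch._VariableFunctionsClass.hann_window, which cannot be resolved because the torch module has no attribute _VariableFunctionsClass. The sketch below illustrates that round trip under stated assumptions; it is not jsonargparse's actual code or fix. The helper importable_path and the demo_pkg module are hypothetical, and only the standard library (pydoc.locate) is used so it runs without torch. It builds the naive `__module__`.`__qualname__` path, accepts it only if it resolves back to the same object, and otherwise looks for a public alias in the top-level package, which mirrors the torch.hann_window situation.

```python
import os
import pydoc
import sys
import types
from typing import Any, Optional


def importable_path(obj: Any) -> Optional[str]:
    """Return a dotted path that resolves back to exactly `obj`, or None.

    Hypothetical helper for illustration; not jsonargparse's implementation.
    """
    module = getattr(obj, "__module__", None)
    qualname = getattr(obj, "__qualname__", None)
    if not module or not qualname:
        return None

    # Naive candidate "<__module__>.<__qualname__>": accept it only if it
    # round-trips. pydoc.locate returns None when a component is missing.
    path = f"{module}.{qualname}"
    if pydoc.locate(path) is obj:
        return path

    # Fallback: search the already-imported top-level package for a public
    # attribute bound to the very same object ("pkg.fn" for "pkg._Cls.fn").
    top = sys.modules.get(module.split(".")[0])
    if top is None:
        return None
    for name, value in vars(top).items():
        if value is obj and not name.startswith("_"):
            return f"{top.__name__}.{name}"
    return None


# Demo: a fake package whose function reports a qualname that is not
# reachable as an attribute path (assumed analogue of torch.hann_window).
demo_pkg = types.ModuleType("demo_pkg")


def window(n: int) -> list:
    return [1.0] * n


window.__module__ = "demo_pkg"
window.__qualname__ = "_Hidden.window"
demo_pkg.window = window
sys.modules["demo_pkg"] = demo_pkg

naive = f"{window.__module__}.{window.__qualname__}"
print(naive, "->", pydoc.locate(naive))  # the naive path does not resolve
print(importable_path(window))           # falls back to the public alias
print(importable_path(os.path.join))     # posixpath.join or ntpath.join
```

Searching the top-level package is only one possible design; a real fix may work differently, and scanning a large namespace such as torch on every serialization could be slow.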
I traced the error in the default-argument case and found that window_fn is the string torch._VariableFunctionsClass.hann_window. It seems the import path is kept as a plain string instead of being resolved to a function. Even if this path were imported, it would still fail, because the window function has to be imported as from torch import hann_window (or accessed as torch.hann_window after import torch), not as torch._VariableFunctionsClass.hann_window.
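To make the second failure concrete: torchaudio calls whatever is passed as window_fn directly (window_fn(self.win_length) in the traceback), so a dotted path left as a string fails with exactly this TypeError. The sketch below uses a hypothetical helper, resolve_callable, built on the standard library's pydoc.locate; it is not part of torchaudio or jsonargparse. A resolvable path is turned into the function it names, while a path through an attribute the module does not expose (the made-up os._NoSuchClass.join, standing in for torch._VariableFunctionsClass.hann_window) is rejected.

```python
import pydoc
from typing import Any, Callable


def resolve_callable(value: Any) -> Callable:
    """Turn a dotted import path into the callable it names.

    Hypothetical helper: callables pass through unchanged; strings are
    resolved with pydoc.locate, which imports the longest importable module
    prefix and then follows the remaining names as attributes.
    """
    if callable(value):
        return value
    if isinstance(value, str):
        obj = pydoc.locate(value)
        if callable(obj):
            return obj
        raise ValueError(f"{value!r} does not resolve to a callable")
    raise TypeError(f"expected a callable or a dotted path, got {type(value).__name__}")


# Calling an unresolved path string reproduces the reported failure mode.
window_fn = "os.path.join"
try:
    window_fn("a", "b")
except TypeError as err:
    print(err)  # 'str' object is not callable

join = resolve_callable(window_fn)
print(join("a", "b"))  # the resolved function joins the path parts

# A path through an attribute the module does not expose cannot be resolved.
try:
    resolve_callable("os._NoSuchClass.join")
except ValueError as err:
    print(err)
```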
Environment
Current environment
- CUDA:
- GPU:
- NVIDIA GeForce RTX 4090
- NVIDIA GeForce RTX 4090
- NVIDIA GeForce RTX 4090
- NVIDIA GeForce RTX 4090
- NVIDIA GeForce RTX 4090
- NVIDIA GeForce RTX 4090
- NVIDIA GeForce RTX 4090
- NVIDIA GeForce RTX 4090
- available: True
- version: 12.1
- Lightning:
- lightning: 2.2.5
- lightning-bolts: 0.7.0
- lightning-utilities: 0.11.2
- pytorch-lightning: 1.9.5
- pytorch-warmup: 0.1.1
- torch: 2.2.2+cu121
- torchaudio: 2.2.2+cu121
- torchmetrics: 1.4.0.post0
- torchvision: 0.17.2+cu121
- Packages:
- absl-py: 2.1.0
- aiohttp: 3.9.5
- aiosignal: 1.3.1
- antlr4-python3-runtime: 4.9.3
- anyio: 4.4.0
- argon2-cffi: 23.1.0
- argon2-cffi-bindings: 21.2.0
- arrow: 1.3.0
- asttokens: 2.4.1
- async-lru: 2.0.4
- async-timeout: 4.0.3
- attrs: 23.2.0
- audioread: 3.0.1
- babel: 2.15.0
- beautifulsoup4: 4.12.3
- bleach: 6.1.0
- certifi: 2024.6.2
- cffi: 1.16.0
- charset-normalizer: 3.3.2
- click: 8.1.7
- comm: 0.2.2
- contourpy: 1.2.1
- cosine-annealing-warmup: 2.0
- cycler: 0.12.1
- cython: 3.0.10
- debugpy: 1.8.1
- decorator: 5.1.1
- defusedxml: 0.7.1
- docker-pycreds: 0.4.0
- docstring-parser: 0.16
- exceptiongroup: 1.2.1
- executing: 2.0.1
- fastjsonschema: 2.20.0
- filelock: 3.15.1
- fonttools: 4.53.0
- fqdn: 1.5.1
- frozenlist: 1.4.1
- fsspec: 2024.2.0
- gitdb: 4.0.11
- gitpython: 3.1.43
- grpcio: 1.64.1
- h11: 0.14.0
- httpcore: 1.0.5
- httpx: 0.27.0
- idna: 3.7
- importlib-resources: 6.4.0
- ipdb: 0.13.13
- ipykernel: 6.29.4
- ipython: 8.25.0
- ipywidgets: 8.1.3
- isoduration: 20.11.0
- jedi: 0.19.1
- jinja2: 3.1.4
- joblib: 1.4.2
- json5: 0.9.25
- jsonargparse: 4.29.0
- jsonpointer: 3.0.0
- jsonschema: 4.22.0
- jsonschema-specifications: 2023.12.1
- jupyter: 1.0.0
- jupyter-client: 8.6.2
- jupyter-console: 6.6.3
- jupyter-core: 5.7.2
- jupyter-events: 0.10.0
- jupyter-lsp: 2.2.5
- jupyter-server: 2.14.1
- jupyter-server-terminals: 0.5.3
- jupyterlab: 4.2.2
- jupyterlab-pygments: 0.3.0
- jupyterlab-server: 2.27.2
- jupyterlab-widgets: 3.0.11
- kiwisolver: 1.4.5
- lazy-loader: 0.4
- librosa: 0.10.2.post1
- lightning: 2.2.5
- lightning-bolts: 0.7.0
- lightning-utilities: 0.11.2
- llvmlite: 0.43.0
- markdown: 3.6
- markupsafe: 2.1.5
- matplotlib: 3.9.0
- matplotlib-inline: 0.1.7
- mistune: 3.0.2
- mpmath: 1.3.0
- msgpack: 1.0.8
- multidict: 6.0.5
- nbclient: 0.10.0
- nbconvert: 7.16.4
- nbformat: 5.10.4
- nest-asyncio: 1.6.0
- networkx: 3.2.1
- notebook: 7.2.1
- notebook-shim: 0.2.4
- numba: 0.60.0
- numpy: 1.23.1
- nvidia-cublas-cu11: 11.11.3.6
- nvidia-cublas-cu12: 12.1.3.1
- nvidia-cuda-cupti-cu11: 11.8.87
- nvidia-cuda-cupti-cu12: 12.1.105
- nvidia-cuda-nvrtc-cu11: 11.8.89
- nvidia-cuda-nvrtc-cu12: 12.1.105
- nvidia-cuda-runtime-cu11: 11.8.89
- nvidia-cuda-runtime-cu12: 12.1.105
- nvidia-cudnn-cu11: 8.7.0.84
- nvidia-cudnn-cu12: 8.9.2.26
- nvidia-cufft-cu11: 10.9.0.58
- nvidia-cufft-cu12: 11.0.2.54
- nvidia-curand-cu11: 10.3.0.86
- nvidia-curand-cu12: 10.3.2.106
- nvidia-cusolver-cu11: 11.4.1.48
- nvidia-cusolver-cu12: 11.4.5.107
- nvidia-cusparse-cu11: 11.7.5.86
- nvidia-cusparse-cu12: 12.1.0.106
- nvidia-nccl-cu11: 2.19.3
- nvidia-nccl-cu12: 2.19.3
- nvidia-nvjitlink-cu12: 12.5.40
- nvidia-nvtx-cu11: 11.8.86
- nvidia-nvtx-cu12: 12.1.105
- omegaconf: 2.3.0
- overrides: 7.7.0
- packaging: 24.1
- pandas: 2.2.2
- pandocfilters: 1.5.1
- parso: 0.8.4
- pexpect: 4.9.0
- pillow: 10.3.0
- pip: 24.0
- platformdirs: 4.2.2
- pooch: 1.8.2
- praat-parselmouth: 0.4.3
- prometheus-client: 0.20.0
- prompt-toolkit: 3.0.47
- protobuf: 4.25.3
- psutil: 5.9.8
- ptyprocess: 0.7.0
- pure-eval: 0.2.2
- pycparser: 2.22
- pygments: 2.18.0
- pyparsing: 3.1.2
- python-dateutil: 2.9.0.post0
- python-json-logger: 2.0.7
- pytorch-lightning: 1.9.5
- pytorch-warmup: 0.1.1
- pytz: 2024.1
- pyworld: 0.3.4
- pyyaml: 6.0.1
- pyzmq: 26.0.3
- qtconsole: 5.5.2
- qtpy: 2.4.1
- referencing: 0.35.1
- requests: 2.32.3
- rfc3339-validator: 0.1.4
- rfc3986-validator: 0.1.1
- rpds-py: 0.18.1
- scikit-learn: 1.5.0
- scipy: 1.13.1
- send2trash: 1.8.3
- sentry-sdk: 2.5.1
- setproctitle: 1.3.3
- setuptools: 70.0.0
- six: 1.16.0
- smmap: 5.0.1
- sniffio: 1.3.1
- soundfile: 0.12.1
- soupsieve: 2.5
- soxr: 0.3.7
- stack-data: 0.6.3
- sympy: 1.12
- tensorboard: 2.17.0
- tensorboard-data-server: 0.7.2
- terminado: 0.18.1
- threadpoolctl: 3.5.0
- tinycss2: 1.3.0
- tomli: 2.0.1
- torch: 2.2.2+cu121
- torchaudio: 2.2.2+cu121
- torchmetrics: 1.4.0.post0
- torchvision: 0.17.2+cu121
- tornado: 6.4.1
- tqdm: 4.66.4
- traitlets: 5.14.3
- triton: 2.2.0
- types-python-dateutil: 2.9.0.20240316
- typeshed-client: 2.5.1
- typing-extensions: 4.12.2
- tzdata: 2024.1
- uri-template: 1.3.0
- urllib3: 2.2.1
- wandb: 0.17.1
- wcwidth: 0.2.13
- webcolors: 24.6.0
- webencodings: 0.5.1
- websocket-client: 1.8.0
- werkzeug: 3.0.3
- wheel: 0.43.0
- widgetsnbextension: 4.0.11
- yarl: 1.9.4
- System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.10.14
- release: 5.10.101-1.el8.ssai.x86_64
- version: #1 SMP Mon Apr 25 12:35:34 UTC 2022
More info
No response
Hi @zengchang233
Could you report this to jsonargparse please? It's the library that does the parsing and instantiation.
It has been fixed in jsonargparse.