Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.

Home Page:https://lightning.ai

Cannot correctly parse some import paths when using a config file with LightningCLI

zengchang233 opened this issue

Bug description

I would like to use LightningCLI with a config.yaml file to manage my project, but some TorchAudio classes cannot be parsed correctly. The problem involves passing a PyTorch window function as an argument. Classes whose signatures do not take a window function directly, such as MFCC or LFCC, are parsed and instantiated successfully. Classes that do take one, such as MelSpectrogram (window_fn), either fail to parse or end up with the import path as a plain string. I know this may be caused by the jsonargparse library, but I think the Lightning community should be aware of it too.

What version are you seeing the problem on?

v2.2

How to reproduce the bug

Script:

from lightning.pytorch.cli import LightningCLI

# With run=True a subcommand is required, e.g.: python train.py fit --config config.yaml
cli = LightningCLI(run=True, save_config_kwargs={"overwrite": True})

Config file (config.yaml):

model:
  class_path: models.AudioSystem
  init_args:
    model:
      class_path: models.AudioModel
      init_args:
        feat_dim: 80 # n_feat
        trunc_len: 750
        num_classes: 1000
    feature_extractor:
      class_path: torchaudio.transforms.MelSpectrogram
      init_args:
        sample_rate: 16000
        n_mels: 80
        n_fft: 512
        win_length: 512
        hop_length: 160
        window_fn: torch.hann_window # This is the default value; the error occurs whether or not it is set explicitly.
    delta_order: 2
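The class_path / init_args entries in the config above amount to resolving a dotted import path and calling the result with the given arguments. The sketch below is a simplified, hypothetical model of that step: the helpers resolve_import_path and instantiate are illustrative names, not jsonargparse or Lightning APIs, and a standard-library class is used so the example runs without torch.

```python
import importlib
from typing import Any


def resolve_import_path(path: str) -> Any:
    """Resolve a dotted path such as 'fractions.Fraction' to the object it names.

    Simplified sketch: import the longest importable module prefix, then
    walk the remaining parts as attributes of that module.
    """
    parts = path.split(".")
    for i in range(len(parts) - 1, 0, -1):
        module_name = ".".join(parts[:i])
        try:
            obj = importlib.import_module(module_name)
        except ImportError:  # includes ModuleNotFoundError
            continue
        for attr in parts[i:]:
            obj = getattr(obj, attr)  # raises AttributeError if the path is wrong
        return obj
    raise ImportError(f"cannot import {path!r}")


def instantiate(spec: dict) -> Any:
    """Instantiate a {'class_path': ..., 'init_args': {...}} mapping (hypothetical helper)."""
    cls = resolve_import_path(spec["class_path"])
    return cls(**spec.get("init_args", {}))


fraction = instantiate(
    {"class_path": "fractions.Fraction", "init_args": {"numerator": 3, "denominator": 4}}
)
print(fraction)  # 3/4
```

Resolving "torch.hann_window" with a helper like this would work. Judging by the first error message below, the failure is in the reverse direction: converting the function object back into an import path.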

Error messages and logs

# With window_fn set explicitly in the config file
ValueError: Only possible to serialize an importable object, given <built-in method hann_window of type object at 0x7feed842f840>: module 'torch' has no attribute '_VariableFunctionsClass'

# With window_fn omitted from the config file (the default argument is used)
Traceback (most recent call last):
  File "/workspace/zengchang/work/antispoofing/train.py", line 128, in <module>
    main()
  File "/workspace/zengchang/work/antispoofing/train.py", line 125, in main
    cli_main()
  File "/workspace/zengchang/work/antispoofing/train.py", line 68, in cli_main
    cli = CustomCLI(
  File "/opt/conda/envs/asvspoof5/lib/python3.10/site-packages/lightning/pytorch/cli.py", line 385, in __init__
    self.instantiate_classes()
  File "/opt/conda/envs/asvspoof5/lib/python3.10/site-packages/lightning/pytorch/cli.py", line 535, in instantiate_classes
    self.config_init = self.parser.instantiate_classes(self.config)
  File "/opt/conda/envs/asvspoof5/lib/python3.10/site-packages/jsonargparse/_deprecated.py", line 141, in patched_instantiate_classes
    cfg = self._unpatched_instantiate_classes(cfg, **kwargs)
  File "/opt/conda/envs/asvspoof5/lib/python3.10/site-packages/jsonargparse/_core.py", line 1196, in instantiate_classes
    cfg[subcommand] = subparser.instantiate_classes(cfg[subcommand], instantiate_groups=instantiate_groups)
  File "/opt/conda/envs/asvspoof5/lib/python3.10/site-packages/jsonargparse/_deprecated.py", line 141, in patched_instantiate_classes
    cfg = self._unpatched_instantiate_classes(cfg, **kwargs)
  File "/opt/conda/envs/asvspoof5/lib/python3.10/site-packages/jsonargparse/_core.py", line 1187, in instantiate_classes
    parent[key] = component.instantiate_classes(value)
  File "/opt/conda/envs/asvspoof5/lib/python3.10/site-packages/jsonargparse/_typehints.py", line 565, in instantiate_classes
    value[num] = adapt_typehints(
  File "/opt/conda/envs/asvspoof5/lib/python3.10/site-packages/jsonargparse/_typehints.py", line 952, in adapt_typehints
    val = adapt_class_type(val, serialize, instantiate_classes, sub_add_kwargs, prev_val=prev_val)
  File "/opt/conda/envs/asvspoof5/lib/python3.10/site-packages/jsonargparse/_typehints.py", line 1173, in adapt_class_type
    init_args = parser.instantiate_classes(init_args)
  File "/opt/conda/envs/asvspoof5/lib/python3.10/site-packages/jsonargparse/_deprecated.py", line 141, in patched_instantiate_classes
    cfg = self._unpatched_instantiate_classes(cfg, **kwargs)
  File "/opt/conda/envs/asvspoof5/lib/python3.10/site-packages/jsonargparse/_core.py", line 1187, in instantiate_classes
    parent[key] = component.instantiate_classes(value)
  File "/opt/conda/envs/asvspoof5/lib/python3.10/site-packages/jsonargparse/_typehints.py", line 565, in instantiate_classes
    value[num] = adapt_typehints(
  File "/opt/conda/envs/asvspoof5/lib/python3.10/site-packages/jsonargparse/_typehints.py", line 952, in adapt_typehints
    val = adapt_class_type(val, serialize, instantiate_classes, sub_add_kwargs, prev_val=prev_val)
  File "/opt/conda/envs/asvspoof5/lib/python3.10/site-packages/jsonargparse/_typehints.py", line 1187, in adapt_class_type
    return instantiator_fn(val_class, **{**init_args, **dict_kwargs})
  File "/opt/conda/envs/asvspoof5/lib/python3.10/site-packages/jsonargparse/_common.py", line 148, in default_class_instantiator
    return class_type(*args, **kwargs)
  File "/opt/conda/envs/asvspoof5/lib/python3.10/site-packages/torchaudio/transforms/_transforms.py", line 593, in __init__
    self.spectrogram = Spectrogram(
  File "/opt/conda/envs/asvspoof5/lib/python3.10/site-packages/torchaudio/transforms/_transforms.py", line 85, in __init__
    window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
TypeError: 'str' object is not callable

I traced the error in the default-argument case and found that window_fn is the string "torch._VariableFunctionsClass.hann_window": the import path is kept as a string instead of being resolved to a function. Even if this path were imported, it would fail, because the window function is importable as torch.hann_window (from torch import hann_window), not as torch._VariableFunctionsClass.hann_window.
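The path mismatch described above can be reproduced without torch. This sketch assumes, consistent with the reported string but not checked against jsonargparse's source, that the path is built as f"{obj.__module__}.{obj.__qualname__}". The stand-in module fake_torch and class _HiddenFunctions are hypothetical: the function is exported at module level, but the class it is defined in is not, much like torch.hann_window, whose qualified name per the report is _VariableFunctionsClass.hann_window even though torch has no attribute _VariableFunctionsClass.

```python
import math
import sys
import types


class _HiddenFunctions:
    """Stand-in for torch's internal class; deliberately NOT exported by the module."""

    @staticmethod
    def hann_window(window_length: int) -> list:
        # Periodic Hann window, returned as a plain list instead of a tensor.
        return [0.5 - 0.5 * math.cos(2 * math.pi * n / window_length) for n in range(window_length)]


# Hypothetical module that re-exports the function at top level only.
fake_torch = types.ModuleType("fake_torch")
fake_torch.hann_window = _HiddenFunctions.hann_window
# The reported path begins with "torch.", so give the stand-in the public module name.
fake_torch.hann_window.__module__ = "fake_torch"
sys.modules["fake_torch"] = fake_torch


def naive_import_path(obj) -> str:
    # Assumed path construction: "<module>.<qualname>".
    return f"{obj.__module__}.{obj.__qualname__}"


path = naive_import_path(fake_torch.hann_window)
print(path)  # fake_torch._HiddenFunctions.hann_window
print(hasattr(fake_torch, "_HiddenFunctions"))  # False: the path cannot be imported back

# If the unresolved string is passed on as window_fn, calling it fails as in the traceback.
window_fn = path
try:
    window_fn(512)
except TypeError as exc:
    print(exc)  # 'str' object is not callable
```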

Environment

Current environment
  • CUDA:
    - GPU:
    - NVIDIA GeForce RTX 4090
    - NVIDIA GeForce RTX 4090
    - NVIDIA GeForce RTX 4090
    - NVIDIA GeForce RTX 4090
    - NVIDIA GeForce RTX 4090
    - NVIDIA GeForce RTX 4090
    - NVIDIA GeForce RTX 4090
    - NVIDIA GeForce RTX 4090
    - available: True
    - version: 12.1
  • Lightning:
    - lightning: 2.2.5
    - lightning-bolts: 0.7.0
    - lightning-utilities: 0.11.2
    - pytorch-lightning: 1.9.5
    - pytorch-warmup: 0.1.1
    - torch: 2.2.2+cu121
    - torchaudio: 2.2.2+cu121
    - torchmetrics: 1.4.0.post0
    - torchvision: 0.17.2+cu121
  • Packages:
    - absl-py: 2.1.0
    - aiohttp: 3.9.5
    - aiosignal: 1.3.1
    - antlr4-python3-runtime: 4.9.3
    - anyio: 4.4.0
    - argon2-cffi: 23.1.0
    - argon2-cffi-bindings: 21.2.0
    - arrow: 1.3.0
    - asttokens: 2.4.1
    - async-lru: 2.0.4
    - async-timeout: 4.0.3
    - attrs: 23.2.0
    - audioread: 3.0.1
    - babel: 2.15.0
    - beautifulsoup4: 4.12.3
    - bleach: 6.1.0
    - certifi: 2024.6.2
    - cffi: 1.16.0
    - charset-normalizer: 3.3.2
    - click: 8.1.7
    - comm: 0.2.2
    - contourpy: 1.2.1
    - cosine-annealing-warmup: 2.0
    - cycler: 0.12.1
    - cython: 3.0.10
    - debugpy: 1.8.1
    - decorator: 5.1.1
    - defusedxml: 0.7.1
    - docker-pycreds: 0.4.0
    - docstring-parser: 0.16
    - exceptiongroup: 1.2.1
    - executing: 2.0.1
    - fastjsonschema: 2.20.0
    - filelock: 3.15.1
    - fonttools: 4.53.0
    - fqdn: 1.5.1
    - frozenlist: 1.4.1
    - fsspec: 2024.2.0
    - gitdb: 4.0.11
    - gitpython: 3.1.43
    - grpcio: 1.64.1
    - h11: 0.14.0
    - httpcore: 1.0.5
    - httpx: 0.27.0
    - idna: 3.7
    - importlib-resources: 6.4.0
    - ipdb: 0.13.13
    - ipykernel: 6.29.4
    - ipython: 8.25.0
    - ipywidgets: 8.1.3
    - isoduration: 20.11.0
    - jedi: 0.19.1
    - jinja2: 3.1.4
    - joblib: 1.4.2
    - json5: 0.9.25
    - jsonargparse: 4.29.0
    - jsonpointer: 3.0.0
    - jsonschema: 4.22.0
    - jsonschema-specifications: 2023.12.1
    - jupyter: 1.0.0
    - jupyter-client: 8.6.2
    - jupyter-console: 6.6.3
    - jupyter-core: 5.7.2
    - jupyter-events: 0.10.0
    - jupyter-lsp: 2.2.5
    - jupyter-server: 2.14.1
    - jupyter-server-terminals: 0.5.3
    - jupyterlab: 4.2.2
    - jupyterlab-pygments: 0.3.0
    - jupyterlab-server: 2.27.2
    - jupyterlab-widgets: 3.0.11
    - kiwisolver: 1.4.5
    - lazy-loader: 0.4
    - librosa: 0.10.2.post1
    - lightning: 2.2.5
    - lightning-bolts: 0.7.0
    - lightning-utilities: 0.11.2
    - llvmlite: 0.43.0
    - markdown: 3.6
    - markupsafe: 2.1.5
    - matplotlib: 3.9.0
    - matplotlib-inline: 0.1.7
    - mistune: 3.0.2
    - mpmath: 1.3.0
    - msgpack: 1.0.8
    - multidict: 6.0.5
    - nbclient: 0.10.0
    - nbconvert: 7.16.4
    - nbformat: 5.10.4
    - nest-asyncio: 1.6.0
    - networkx: 3.2.1
    - notebook: 7.2.1
    - notebook-shim: 0.2.4
    - numba: 0.60.0
    - numpy: 1.23.1
    - nvidia-cublas-cu11: 11.11.3.6
    - nvidia-cublas-cu12: 12.1.3.1
    - nvidia-cuda-cupti-cu11: 11.8.87
    - nvidia-cuda-cupti-cu12: 12.1.105
    - nvidia-cuda-nvrtc-cu11: 11.8.89
    - nvidia-cuda-nvrtc-cu12: 12.1.105
    - nvidia-cuda-runtime-cu11: 11.8.89
    - nvidia-cuda-runtime-cu12: 12.1.105
    - nvidia-cudnn-cu11: 8.7.0.84
    - nvidia-cudnn-cu12: 8.9.2.26
    - nvidia-cufft-cu11: 10.9.0.58
    - nvidia-cufft-cu12: 11.0.2.54
    - nvidia-curand-cu11: 10.3.0.86
    - nvidia-curand-cu12: 10.3.2.106
    - nvidia-cusolver-cu11: 11.4.1.48
    - nvidia-cusolver-cu12: 11.4.5.107
    - nvidia-cusparse-cu11: 11.7.5.86
    - nvidia-cusparse-cu12: 12.1.0.106
    - nvidia-nccl-cu11: 2.19.3
    - nvidia-nccl-cu12: 2.19.3
    - nvidia-nvjitlink-cu12: 12.5.40
    - nvidia-nvtx-cu11: 11.8.86
    - nvidia-nvtx-cu12: 12.1.105
    - omegaconf: 2.3.0
    - overrides: 7.7.0
    - packaging: 24.1
    - pandas: 2.2.2
    - pandocfilters: 1.5.1
    - parso: 0.8.4
    - pexpect: 4.9.0
    - pillow: 10.3.0
    - pip: 24.0
    - platformdirs: 4.2.2
    - pooch: 1.8.2
    - praat-parselmouth: 0.4.3
    - prometheus-client: 0.20.0
    - prompt-toolkit: 3.0.47
    - protobuf: 4.25.3
    - psutil: 5.9.8
    - ptyprocess: 0.7.0
    - pure-eval: 0.2.2
    - pycparser: 2.22
    - pygments: 2.18.0
    - pyparsing: 3.1.2
    - python-dateutil: 2.9.0.post0
    - python-json-logger: 2.0.7
    - pytorch-lightning: 1.9.5
    - pytorch-warmup: 0.1.1
    - pytz: 2024.1
    - pyworld: 0.3.4
    - pyyaml: 6.0.1
    - pyzmq: 26.0.3
    - qtconsole: 5.5.2
    - qtpy: 2.4.1
    - referencing: 0.35.1
    - requests: 2.32.3
    - rfc3339-validator: 0.1.4
    - rfc3986-validator: 0.1.1
    - rpds-py: 0.18.1
    - scikit-learn: 1.5.0
    - scipy: 1.13.1
    - send2trash: 1.8.3
    - sentry-sdk: 2.5.1
    - setproctitle: 1.3.3
    - setuptools: 70.0.0
    - six: 1.16.0
    - smmap: 5.0.1
    - sniffio: 1.3.1
    - soundfile: 0.12.1
    - soupsieve: 2.5
    - soxr: 0.3.7
    - stack-data: 0.6.3
    - sympy: 1.12
    - tensorboard: 2.17.0
    - tensorboard-data-server: 0.7.2
    - terminado: 0.18.1
    - threadpoolctl: 3.5.0
    - tinycss2: 1.3.0
    - tomli: 2.0.1
    - torch: 2.2.2+cu121
    - torchaudio: 2.2.2+cu121
    - torchmetrics: 1.4.0.post0
    - torchvision: 0.17.2+cu121
    - tornado: 6.4.1
    - tqdm: 4.66.4
    - traitlets: 5.14.3
    - triton: 2.2.0
    - types-python-dateutil: 2.9.0.20240316
    - typeshed-client: 2.5.1
    - typing-extensions: 4.12.2
    - tzdata: 2024.1
    - uri-template: 1.3.0
    - urllib3: 2.2.1
    - wandb: 0.17.1
    - wcwidth: 0.2.13
    - webcolors: 24.6.0
    - webencodings: 0.5.1
    - websocket-client: 1.8.0
    - werkzeug: 3.0.3
    - wheel: 0.43.0
    - widgetsnbextension: 4.0.11
    - yarl: 1.9.4
  • System:
    - OS: Linux
    - architecture:
    - 64bit
    - ELF
    - processor: x86_64
    - python: 3.10.14
    - release: 5.10.101-1.el8.ssai.x86_64
    - version: #1 SMP Mon Apr 25 12:35:34 UTC 2022

More info

No response

Hi @zengchang233
Could you report this to jsonargparse please? It's the library that does the parsing and instantiation.

It has been fixed in jsonargparse.
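The thread does not say how jsonargparse fixed this. One possible strategy, shown only as a hypothetical sketch and not as jsonargparse's actual change, is to check that "<module>.<qualname>" imports back to the same object and, if it does not, fall back to a module-level attribute with the object's own name. The helper robust_import_path and the stand-in module demo_torch are illustrative names.

```python
import fractions
import importlib
import sys
import types

_MISSING = object()


def robust_import_path(obj) -> str:
    """Return a dotted path that imports back to exactly `obj` (sketch)."""
    module = importlib.import_module(obj.__module__)
    # 1) Try "<module>.<qualname>" by walking attributes from the module.
    target = module
    for part in obj.__qualname__.split("."):
        target = getattr(target, part, _MISSING)
        if target is _MISSING:
            break
    if target is obj:
        return f"{obj.__module__}.{obj.__qualname__}"
    # 2) Fall back to a module-level attribute with the object's own name.
    if getattr(module, obj.__name__, _MISSING) is obj:
        return f"{obj.__module__}.{obj.__name__}"
    raise ValueError(f"no importable path found for {obj!r}")


class _Internal:
    @staticmethod
    def hann_window(window_length: int) -> list:
        return [1.0] * window_length  # placeholder body; only the import path matters here


# Hypothetical module exporting the function at top level but not the class.
demo_torch = types.ModuleType("demo_torch")
demo_torch.hann_window = _Internal.hann_window
demo_torch.hann_window.__module__ = "demo_torch"
sys.modules["demo_torch"] = demo_torch

print(robust_import_path(demo_torch.hann_window))  # demo_torch.hann_window
print(robust_import_path(fractions.Fraction))  # fractions.Fraction
```

Applied to torch.hann_window, step 1 would fail because torch has no attribute _VariableFunctionsClass, and step 2 would return torch.hann_window, the path the report says should be used.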