Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.

Home page: https://lightning.ai

Possible bug in recognizing `mps` accelerator even though PyTorch seems to register the `mps` device?

adam2392 opened this issue.

Bug description

Hi, I'm a new user of PyTorch Lightning (pl), so apologies if this turns out not to be a bug.

I am running PyTorch and pytorch-lightning inside a Jupyter notebook on an Apple M1 Mac running macOS. My conda environment was created with Miniforge, which provides native arm64 builds. As expected, the following checks show that MPS is available and CUDA is not:

import torch

print(torch.backends.mps.is_available())  # True
print(torch.cuda.is_available())          # False
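Because the Trainer is created inside a Jupyter kernel, one thing worth ruling out is that the kernel runs a different interpreter from the terminal. The sketch below collects the values that usually matter: the CPU architecture reported by the standard `platform` module, and PyTorch's own `torch.backends.mps.is_built()` and `torch.backends.mps.is_available()` flags. The helper name `mps_diagnostics` is hypothetical and not part of PyTorch or Lightning. Running it both in the notebook and from a shell makes a mismatch easy to spot; for example, an x86_64 Python running under Rosetta 2 reports `platform.machine()` as `x86_64` rather than `arm64`.

```python
import platform
import sys


def mps_diagnostics() -> dict:
    """Hypothetical helper: gather the facts that decide whether MPS can be used."""
    info = {
        "python": sys.version.split()[0],
        "executable": sys.executable,  # which interpreter this kernel is actually using
        "platform_machine": platform.machine(),
        "platform_processor": platform.processor(),
    }
    try:
        import torch
    except ImportError:
        # torch is not installed in this interpreter
        info["torch"] = None
        return info
    info["torch"] = torch.__version__
    # is_built(): this torch binary was compiled with MPS support
    info["mps_built"] = torch.backends.mps.is_built()
    # is_available(): MPS is built and the running machine/OS supports it
    info["mps_available"] = torch.backends.mps.is_available()
    return info


if __name__ == "__main__":
    for key, value in mps_diagnostics().items():
        print(f"{key}: {value}")
```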

However, when I then initialize the Trainer:

trainer = pl.Trainer(
    max_epochs=max_epochs,
    devices=1,
    accelerator='mps',
)

I get the error shown below.

What version are you seeing the problem on?

v2.2

How to reproduce the bug

import pytorch_lightning as pl
import torch

trainer = pl.Trainer(
    max_epochs=2,
    devices=1,
    accelerator='mps',
)
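While the detection problem is being investigated, a notebook can keep running by asking Lightning whether an accelerator is usable before passing it to the Trainer. This is a workaround sketch, not a fix: the helper `pick_accelerator` is hypothetical, and it relies only on the `is_available()` static methods of `MPSAccelerator` and `CUDAAccelerator` from `pytorch_lightning.accelerators`, which is the same check the traceback below shows Lightning performing.

```python
def pick_accelerator(preferred: str = "mps") -> str:
    """Return `preferred` if Lightning reports it usable, otherwise "cpu".

    Hypothetical helper for illustration; it is not part of Lightning.
    """
    try:
        from pytorch_lightning.accelerators import CUDAAccelerator, MPSAccelerator
    except ImportError:
        # Without Lightning installed there is nothing to check; use the CPU
        return "cpu"
    checks = {"mps": MPSAccelerator.is_available, "cuda": CUDAAccelerator.is_available}
    check = checks.get(preferred)
    if check is not None and check():
        return preferred
    # Unknown accelerator name, or Lightning reports it as unavailable
    return "cpu"
```

In the reproduction above, the argument would become `accelerator=pick_accelerator("mps")`. Leaving `accelerator` at its default value, `"auto"`, lets Lightning make the choice on its own instead.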

Error messages and logs

MisconfigurationException                 Traceback (most recent call last)
Cell In[13], line 1
----> 1 trainer = pl.Trainer(
      2     max_epochs=max_epochs,
      3     logger=logger,
      4     devices=1,
      5     callbacks=[checkpoint_callback] if wandb else [],
      6     check_val_every_n_epoch=check_val_every_n_epoch,
      7     accelerator=accelerator,
      8 )

File ~/miniforge3/envs/cdrl/lib/python3.10/site-packages/pytorch_lightning/utilities/argparse.py:70, in _defaults_from_env_vars.<locals>.insert_env_defaults(self, *args, **kwargs)
     67 kwargs = dict(list(env_variables.items()) + list(kwargs.items()))
     69 # all args were already moved to kwargs
---> 70 return fn(self, **kwargs)

File ~/miniforge3/envs/cdrl/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:401, in Trainer.__init__(self, accelerator, strategy, devices, num_nodes, precision, logger, callbacks, fast_dev_run, max_epochs, min_epochs, max_steps, min_steps, max_time, limit_train_batches, limit_val_batches, limit_test_batches, limit_predict_batches, overfit_batches, val_check_interval, check_val_every_n_epoch, num_sanity_val_steps, log_every_n_steps, enable_checkpointing, enable_progress_bar, enable_model_summary, accumulate_grad_batches, gradient_clip_val, gradient_clip_algorithm, deterministic, benchmark, inference_mode, use_distributed_sampler, profiler, detect_anomaly, barebones, plugins, sync_batchnorm, reload_dataloaders_every_n_epochs, default_root_dir)
    398 # init connectors
    399 self._data_connector = _DataConnector(self)
--> 401 self._accelerator_connector = _AcceleratorConnector(
    402     devices=devices,
    403     accelerator=accelerator,
    404     strategy=strategy,
    405     num_nodes=num_nodes,
    406     sync_batchnorm=sync_batchnorm,
    407     benchmark=benchmark,
    408     use_distributed_sampler=use_distributed_sampler,
    409     deterministic=deterministic,
    410     precision=precision,
    411     plugins=plugins,
    412 )
    413 self._logger_connector = _LoggerConnector(self)
    414 self._callback_connector = _CallbackConnector(self)

File ~/miniforge3/envs/cdrl/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:150, in _AcceleratorConnector.__init__(self, devices, num_nodes, accelerator, strategy, plugins, precision, sync_batchnorm, benchmark, use_distributed_sampler, deterministic)
    147     self._accelerator_flag = self._choose_gpu_accelerator_backend()
    149 self._check_device_config_and_set_final_flags(devices=devices, num_nodes=num_nodes)
--> 150 self._set_parallel_devices_and_init_accelerator()
    152 # 3. Instantiate ClusterEnvironment
    153 self.cluster_environment: ClusterEnvironment = self._choose_and_init_cluster_environment()

File ~/miniforge3/envs/cdrl/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:382, in _AcceleratorConnector._set_parallel_devices_and_init_accelerator(self)
    376 if not accelerator_cls.is_available():
    377     available_accelerator = [
    378         acc_str
    379         for acc_str in self._accelerator_types
    380         if AcceleratorRegistry[acc_str]["accelerator"].is_available()
    381     ]
--> 382     raise MisconfigurationException(
    383         f"`{accelerator_cls.__qualname__}` can not run on your system"
    384         " since the accelerator is not available. The following accelerator(s)"
    385         " is available and can be passed into `accelerator` argument of"
    386         f" `Trainer`: {available_accelerator}."
    387     )
    389 self._set_devices_flag_if_auto_passed()
    390 self._devices_flag = accelerator_cls.parse_devices(self._devices_flag)

MisconfigurationException: `MPSAccelerator` can not run on your system since the accelerator is not available. The following accelerator(s) is available and can be passed into `accelerator` argument of `Trainer`: ['cpu'].
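The last frame of the traceback shows how Lightning builds the list in this message: it keeps every registered accelerator whose `is_available()` returns True. The sketch below approximates that check for three classes exported by `pytorch_lightning.accelerators` (it does not read the registry itself, and both helper names are hypothetical), then prints the result next to PyTorch's own answer so the disagreement can be reproduced outside the Trainer.

```python
def lightning_available_accelerators() -> list:
    """Approximate the traceback's list comprehension for a fixed set of accelerators."""
    try:
        from pytorch_lightning.accelerators import (
            CPUAccelerator,
            CUDAAccelerator,
            MPSAccelerator,
        )
    except ImportError:
        # pytorch_lightning is not installed in this interpreter
        return []
    candidates = {"cpu": CPUAccelerator, "cuda": CUDAAccelerator, "mps": MPSAccelerator}
    # Keep the accelerators whose is_available() check passes, as in the traceback
    return [name for name, cls in candidates.items() if cls.is_available()]


def torch_mps_available():
    """PyTorch's own MPS answer, or None if torch is not installed."""
    try:
        import torch
    except ImportError:
        return None
    return torch.backends.mps.is_available()


if __name__ == "__main__":
    print("torch MPS available:   ", torch_mps_available())
    print("Lightning accelerators:", lightning_available_accelerators())
```

If PyTorch reports True while `mps` is missing from Lightning's list, the difference presumably comes from any extra conditions inside `MPSAccelerator.is_available()`, which would be the next place to look.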

Environment

Current environment
(cdrl) adam2392@adams-mbp-6 causal-component-analysis % python collect_env_details.py
/Users/adam2392/Documents/causal-component-analysis/collect_env_details.py:24: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
  import pkg_resources
<details>
  <summary>Current environment</summary>

* CUDA:
	- GPU:               None
	- available:         False
	- version:           None
* Lightning:
	- lightning-utilities: 0.11.2
	- pytorch-lightning: 2.2.4
	- torch:             2.3.0
	- torchmetrics:      1.4.0
* Packages:
	- aiohttp:           3.9.5
	- aiosignal:         1.3.1
	- appnope:           0.1.4
	- asttokens:         2.4.1
	- async-timeout:     4.0.3
	- attrs:             23.2.0
	- causal-component-analysis: 0.0.0
	- certifi:           2024.2.2
	- charset-normalizer: 3.3.2
	- click:             8.1.7
	- colorama:          0.4.6
	- comm:              0.2.2
	- contourpy:         1.2.1
	- cycler:            0.12.1
	- debugpy:           1.8.1
	- decorator:         5.1.1
	- docker-pycreds:    0.4.0
	- exceptiongroup:    1.2.1
	- executing:         2.0.1
	- filelock:          3.14.0
	- fonttools:         4.51.0
	- frozenlist:        1.4.1
	- fsspec:            2024.3.1
	- gitdb:             4.0.11
	- gitpython:         3.1.43
	- idna:              3.7
	- ipykernel:         6.29.4
	- ipython:           8.24.0
	- jedi:              0.19.1
	- jinja2:            3.1.4
	- joblib:            1.4.2
	- jupyter-client:    8.6.1
	- jupyter-core:      5.7.2
	- kiwisolver:        1.4.5
	- lightning-utilities: 0.11.2
	- markupsafe:        2.1.5
	- matplotlib:        3.8.4
	- matplotlib-inline: 0.1.7
	- mpmath:            1.3.0
	- multidict:         6.0.5
	- nest-asyncio:      1.6.0
	- networkx:          3.3
	- normflows:         1.7.3
	- numpy:             1.26.4
	- packaging:         24.0
	- pandas:            2.2.2
	- parso:             0.8.4
	- pexpect:           4.9.0
	- pillow:            10.3.0
	- pip:               24.0
	- platformdirs:      4.2.1
	- pretty-errors:     1.2.25
	- prompt-toolkit:    3.0.43
	- protobuf:          4.25.3
	- psutil:            5.9.8
	- ptyprocess:        0.7.0
	- pure-eval:         0.2.2
	- pygments:          2.18.0
	- pyparsing:         3.1.2
	- python-dateutil:   2.9.0.post0
	- pytorch-lightning: 2.2.4
	- pytz:              2024.1
	- pyyaml:            6.0.1
	- pyzmq:             26.0.3
	- requests:          2.31.0
	- scikit-learn:      1.4.2
	- scipy:             1.13.0
	- seaborn:           0.13.2
	- sentry-sdk:        2.1.1
	- setproctitle:      1.3.3
	- setuptools:        69.5.1
	- six:               1.16.0
	- smmap:             5.0.1
	- stack-data:        0.6.3
	- sympy:             1.12
	- threadpoolctl:     3.5.0
	- torch:             2.3.0
	- torchmetrics:      1.4.0
	- tornado:           6.4
	- tqdm:              4.66.4
	- traitlets:         5.14.3
	- typing-extensions: 4.11.0
	- tzdata:            2024.1
	- urllib3:           2.2.1
	- wandb:             0.17.0
	- wcwidth:           0.2.13
	- wheel:             0.43.0
	- yarl:              1.9.4
* System:
	- OS:                Darwin
	- architecture:
		- 64bit
		- 
	- processor:         arm
	- python:            3.10.14
	- release:           23.4.0
	- version:           Darwin Kernel Version 23.4.0: Fri Mar 15 00:12:41 PDT 2024; root:xnu-10063.101.17~1/RELEASE_ARM64_T8103

</details>

More info

I looked at the following link, but I believe I am in the correct conda environment, so it was a dead end for me.