Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.

[fabric.example.rl] Not support torch.float64 for MPS device

swyo opened this issue

Bug description

I found an error when running the example pytorch-lightning/examples/fabric/reinforcement_learning on an M2 Mac (device type = mps).

Reproducing the error

reinforcement_learning git:(master) ✗ fabric run
W0617 12:53:22.541000 8107367488 torch/distributed/elastic/multiprocessing/] NOTE: Redirects are currently not supported in Windows or MacOs.
[rank: 0] Seed set to 42
Missing logger folder: logs/fabric_logs/2024-06-17_12-53-24/CartPole-v1_default_42_1718596404
set default torch dtype as torch.float32
Traceback (most recent call last):
  File "/Users/user/workspace/pytorch-lightning/examples/fabric/reinforcement_learning/", line 215, in <module>
  File "/Users/user/workspace/pytorch-lightning/examples/fabric/reinforcement_learning/", line 154, in main
    rewards[step] = torch.tensor(reward, device=device).view(-1)
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
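
The failure can be reproduced without an MPS device. This sketch assumes, as the error message suggests, that reward comes back from envs.step as a float64 NumPy array; np.ones(4) is a hypothetical stand-in for the rewards of four environments, and the torch.set_default_dtype call is presumably what the "set default torch dtype" log line corresponds to.

```python
import numpy as np
import torch

# Hypothetical stand-in for the per-environment rewards (NumPy defaults to float64).
reward = np.ones(4)

# Presumably equivalent to the example's "set default torch dtype as torch.float32" step.
torch.set_default_dtype(torch.float32)

rewards_tensor = torch.tensor(reward).view(-1)

# torch.tensor keeps the NumPy array's dtype; the default dtype only applies
# when building tensors from Python floats.
print(rewards_tensor.dtype)          # torch.float64
print(torch.tensor([1.0]).dtype)     # torch.float32
```

On MPS the same call fails because the float64 tensor would have to be allocated on a device that has no float64 support.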

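This bug can be fixed by checking device.type and casting the reward to torch.float32 when running on MPS. The reward returned by the environments is presumably a float64 NumPy array, and torch.tensor keeps that dtype even though the default dtype is set to torch.float32, so the cast has to be explicit: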

@@ -146,7 +146,7 @@ def main(args: argparse.Namespace):
             # Single environment step
             next_obs, reward, done, truncated, info = envs.step(action.cpu().numpy())
             done = torch.logical_or(torch.tensor(done), torch.tensor(truncated))
-            rewards[step] = torch.tensor(reward, device=device).view(-1)
+            rewards[step] = torch.tensor(reward, device=device, dtype=torch.float32 if device.type == 'mps' else None).view(-1)
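The patch above can be factored into a small helper. This is a minimal sketch, not code from the Lightning repository: reward_dtype and reward_to_tensor are hypothetical names, and it assumes reward is a NumPy array as returned by the vectorized environments.

```python
from typing import Optional

import numpy as np
import torch


def reward_dtype(reward: np.ndarray, device: torch.device) -> Optional[torch.dtype]:
    """Pick a dtype for `reward` that `device` can hold.

    MPS cannot store float64, so float64 input is downcast to float32 there.
    Returning None lets torch.tensor keep the input's own dtype elsewhere.
    """
    if device.type == "mps" and reward.dtype == np.float64:
        return torch.float32
    return None


def reward_to_tensor(reward: np.ndarray, device: torch.device) -> torch.Tensor:
    # Same idea as the proposed one-line fix, flattened with .view(-1) like the example.
    return torch.tensor(reward, device=device, dtype=reward_dtype(reward, device)).view(-1)


cpu = torch.device("cpu")
print(reward_to_tensor(np.ones(2), cpu).dtype)        # torch.float64 (CPU keeps float64)
# Only a device object is needed to choose the dtype, so this runs without MPS hardware.
print(reward_dtype(np.ones(2), torch.device("mps")))  # torch.float32
```

Returning None keeps the CPU and CUDA behaviour unchanged, which matches the proposed patch. A simpler alternative would be to always pass dtype=torch.float32 (or the dtype of the preallocated rewards buffer), since assigning into rewards[step] converts to the buffer's dtype anyway; that trades the device check for a fixed precision on every backend.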


Current environment
#- PyTorch Lightning version: 2.3.0
#- Lightning App version: 2.3.0
#- PyTorch version: 2.3.1
#- Python version: 3.12.3
#- OS: macOS
#- CUDA/cuDNN version: not applicable (MPS backend)
#- GPU models and configuration: Apple M2