[fabric.example.rl] torch.float64 not supported on MPS device
swyo opened this issue
Bug description
I got an error when running the example pytorch-lightning/examples/fabric/reinforcement_learning on an M2 Mac (device type = mps).
Reproducing the error
reinforcement_learning git:(master) ✗ fabric run train_fabric.py
W0617 12:53:22.541000 8107367488 torch/distributed/elastic/multiprocessing/redirects.py:27] NOTE: Redirects are currently not supported in Windows or MacOs.
[rank: 0] Seed set to 42
Missing logger folder: logs/fabric_logs/2024-06-17_12-53-24/CartPole-v1_default_42_1718596404
set default torch dtype as torch.float32
Traceback (most recent call last):
File "/Users/user/workspace/pytorch-lightning/examples/fabric/reinforcement_learning/train_fabric.py", line 215, in <module>
main(args)
File "/Users/user/workspace/pytorch-lightning/examples/fabric/reinforcement_learning/train_fabric.py", line 154, in main
rewards[step] = torch.tensor(reward, device=device).view(-1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
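The error occurs because the reward returned by envs.step is a float64 NumPy array, and torch.tensor keeps a NumPy array's dtype instead of applying the default dtype (torch.float32 here). As a result, it tries to create a float64 tensor directly on the MPS device. The bug is fixed by checking device.type and casting reward to torch.float32 on MPS: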
@@ -146,7 +146,7 @@ def main(args: argparse.Namespace):
# Single environment step
next_obs, reward, done, truncated, info = envs.step(action.cpu().numpy())
done = torch.logical_or(torch.tensor(done), torch.tensor(truncated))
- rewards[step] = torch.tensor(reward, device=device).view(-1)
+ rewards[step] = torch.tensor(reward, device=device, dtype=torch.float32 if device.type == 'mps' else None).view(-1)
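A minimal standalone sketch of the patched line follows. It assumes a vectorized environment that returns rewards as a float64 NumPy array of shape (num_envs,). The helper name reward_to_tensor is hypothetical and does not exist in train_fabric.py. On a machine without MPS, the sketch falls back to CPU, where the inferred float64 dtype is kept.

```python
import numpy as np
import torch


def reward_to_tensor(reward: np.ndarray, device: torch.device) -> torch.Tensor:
    """Hypothetical helper mirroring the patched line in train_fabric.py.

    torch.tensor() keeps the NumPy array's dtype (float64 here), so on MPS we
    request float32 explicitly; elsewhere dtype=None keeps the inferred dtype.
    """
    dtype = torch.float32 if device.type == "mps" else None
    return torch.tensor(reward, device=device, dtype=dtype).view(-1)


# Assumed input: per-environment rewards as a float64 NumPy array.
reward = np.array([1.0, 1.0], dtype=np.float64)

# Use MPS when it is available, otherwise fall back to CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
out = reward_to_tensor(reward, device)
print(out.dtype, out.device)
```

A simpler variant would pass dtype=torch.float32 on every device. The conditional form only changes behavior on MPS.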
What version are you seeing the problem on?
master
Environment
Current environment
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0): 2.3.0
#- Lightning App Version (e.g., 0.5.2): 2.3.0
#- PyTorch Version (e.g., 2.0): 2.3.1
#- Python version (e.g., 3.9): 3.12.3
#- OS (e.g., Linux): Mac
#- CUDA/cuDNN version: MPS
#- GPU models and configuration: M2
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):