The PyTorch version is incorrect.
Doraemonzzz opened this issue · comments
Thank you for your work, this is a great project. However, I encountered some environment issues while running it. I have tried 2.4.0.dev20240419+cu121, 2.4.0.dev20240612+cu121, and 2.5.0.dev20240617+cu121, but all of these resulted in errors. Could you please provide the correct torch version that works with the main branch? Thank you.
I didn't get the code up and running either. The test model using debug_model.toml ran successfully. But when trying to train llama3 using llama3_8b.toml, I had several ImportErrors:
cannot import name 'Partial' from 'torch.distributed._tensor'
cannot import name 'CheckpointPolicy' from 'torch.utils.checkpoint'
etc.
I tried torch-2.4.0.dev20240412 (which torchtitan is verified on according to README) from https://download.pytorch.org/whl/nightly/cu118/torch-2.4.0.dev20240412%2Bcu118-cp310-cp310-linux_x86_64.whl and several other pytorch nightly builds with no luck.
It seems we need a specific version of pytorch nightly.
This is because a PyTorch PR was reverted:
pytorch/pytorch#125795
You can manually patch your code based on the changes in these two PRs: #397 #401
Thank you for your response. I am currently encountering the following issue:
ModuleNotFoundError: No module named 'torch.distributed.pipelining'
The torch version is "2.4.0.dev20240419+cu121".
As with the previous issue, this is caused by APIs actively changing on the PyTorch side during the release cycle. Since PyTorch is preparing its next minor release, the APIs have been changing frequently. One suggestion is to wait for those changes to land in a nightly build and use nightly PyTorch, or to compile PyTorch from the latest GitHub branch. We understand the inconvenience, but right now we couldn't come up with a better solution... :-(
update: I believe 2.5.0.dev20240617+cu121 should be new enough to include the API change.
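For anyone unsure whether their installed build is new enough, here is a quick standalone probe (a sketch; the module and attribute names are the ones quoted in the errors in this thread) that reports which of the required APIs are present:

```python
import importlib


def has_api(module_name, attr=None):
    """Return True if module_name imports and (optionally) exposes attr."""
    try:
        mod = importlib.import_module(module_name)
    except ImportError:  # also catches ModuleNotFoundError
        return False
    return attr is None or hasattr(mod, attr)


if __name__ == "__main__":
    # APIs mentioned in this thread; all should report OK on a
    # new-enough nightly (e.g. 2.5.0.dev20240617+cu121).
    checks = [
        ("torch.distributed.pipelining", None),
        ("torch.distributed._tensor", "Partial"),
        ("torch.utils.checkpoint", "CheckpointPolicy"),
    ]
    for name, attr in checks:
        label = f"{name}.{attr}" if attr else name
        print(f"{label}: {'OK' if has_api(name, attr) else 'MISSING'}")
```

If any line prints MISSING, the installed build predates the relevant API change and a newer nightly is needed.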
Thank you for the version you provided. We have successfully run the code on version 2.5.0.dev20240617+cu121. By the way, could you please explain the purpose of the following code in norms.py? We couldn't find any corresponding documentation for it.
@partial(
local_map,
out_placements=[Shard(1)],
in_placements=(None, [Shard(1)], [Replicate()], None),
)
@Doraemonzzz Thanks for your interest in our new experimental feature local_map (this is why there's no official doc about it). In short, this decorator allows the user to call the decorated function on DTensors with a user-specified sharding specification (i.e., placements). See pytorch/pytorch#123676 for details.
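For readers puzzled by the syntax itself: `@partial(local_map, ...)` is ordinary functools.partial pre-binding the placement keyword arguments to local_map, which then acts as a plain one-argument decorator. A toy sketch of the pattern (the `tagged` decorator and all names here are illustrative, not torchtitan or PyTorch code):

```python
from functools import partial


# Toy stand-in for local_map: a decorator whose first positional
# argument is the function and whose behavior is configured via kwargs.
def tagged(func, *, tag):
    def wrapper(*args, **kwargs):
        return f"[{tag}] {func(*args, **kwargs)}"
    return wrapper


# functools.partial pre-binds the keyword arguments, leaving a
# one-argument callable usable as a plain decorator -- the same shape as
# @partial(local_map, out_placements=..., in_placements=...).
@partial(tagged, tag="Shard(1)")
def rmsnorm_fwd(x):
    return f"normed({x})"


print(rmsnorm_fwd("h"))  # → [Shard(1)] normed(h)
```

In norms.py the pre-bound kwargs are the placements, so the decorated function can be called directly on DTensors while its body sees local shards.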
Since pytorch/pytorch#125795 is re-landed, this issue should be gone in recent PyTorch nightly soon.
Hi, I also encountered this issue. I installed all the dependencies as recommended:
git clone https://github.com/pytorch/torchtitan
cd torchtitan
pip install -r requirements.txt
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 # or cu118
pip3 install --pre torchdata --index-url https://download.pytorch.org/whl/nightly
and my resulting torch version is 2.3.1+cu121. Any suggestions? How do I switch to 2.5.0.dev20240617+cu121? I cannot find it online.
Update: problem solved by using
pip3 install --pre torch==2.5.0.dev20240617 --index-url https://download.pytorch.org/whl/nightly/cu121
Encountering the same issue.
[rank0]: File "/opt/ml/code/torchtitan/parallelisms/parallelize_llama.py", line 415, in apply_ac
[rank0]: transformer_block = checkpoint_wrapper(transformer_block, ac_config)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/ml/code/torchtitan/parallelisms/parallelize_llama.py", line 50, in checkpoint_wrapper
[rank0]: from torch.utils.checkpoint import (
[rank0]: ImportError: cannot import name 'CheckpointPolicy' from 'torch.utils.checkpoint' (/opt/conda/lib/python3.11/site-packages/torch/utils/checkpoint.py)
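Until you are on a new-enough nightly, one possible local workaround (a sketch only, not the actual torchtitan fix; the flag name is made up here) is to guard the import and fall back when the selective-checkpointing API is absent:

```python
# Sketch: detect whether the installed torch exposes CheckpointPolicy.
# ModuleNotFoundError is a subclass of ImportError, so this also handles
# the case where torch (or the submodule) is missing entirely.
try:
    from torch.utils.checkpoint import CheckpointPolicy  # newer nightly API
    HAS_CHECKPOINT_POLICY = True
except ImportError:
    CheckpointPolicy = None  # hypothetical fallback; disable selective AC
    HAS_CHECKPOINT_POLICY = False

print(f"selective activation checkpointing available: {HAS_CHECKPOINT_POLICY}")
```

Code that needs selective activation checkpointing could then check `HAS_CHECKPOINT_POLICY` and fall back to full checkpointing (or raise a clearer error) instead of crashing on import.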