google / gemma_pytorch

The official PyTorch implementation of Google's Gemma models

Home Page: https://ai.google.dev/gemma

Cannot run on v4-16 worker 0 TPU VM: "Failed to get global TPU topology"

markusheimerl opened this issue

markusheimerl@t1v-n-a16d1e4e-w-0:~/gimli$ cd ~/gemma_cktp/ && curl -o archive.tar.gz "https://storage.googleapis.com/kaggle-models-data/5305/11357/bundle/archive.tar.gz?X-Goog-Algorithm=GOOG4-RSA-SHA256..." && tar -xf archive.tar.gz && cd ~/gimli
markusheimerl@t1v-n-a16d1e4e-w-0:~/gimli$ cd ../gemma_pytorch/
markusheimerl@t1v-n-a16d1e4e-w-0:~/gemma_pytorch$ VARIANT=2b
markusheimerl@t1v-n-a16d1e4e-w-0:~/gemma_pytorch$ CKPT_PATH=/home/markusheimerl/gemma_ckpt/
markusheimerl@t1v-n-a16d1e4e-w-0:~/gemma_pytorch$ sudo usermod -aG docker $USER
markusheimerl@t1v-n-a16d1e4e-w-0:~/gemma_pytorch$ newgrp docker
markusheimerl@t1v-n-a16d1e4e-w-0:~/gemma_pytorch$ DOCKER_URI=gemma_xla:${USER}
markusheimerl@t1v-n-a16d1e4e-w-0:~/gemma_pytorch$ docker build -f docker/xla.Dockerfile ./ -t ${DOCKER_URI}
[+] Building 109.0s (19/19) FINISHED                                                                                                          
 => [internal] load build definition from xla.Dockerfile                                                                                 0.0s
 => => transferring dockerfile: 1.36kB                                                                                                   0.0s
 => [internal] load .dockerignore                                                                                                        0.0s
 => => transferring context: 2B                                                                                                          0.0s
 => [internal] load metadata for us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_tpuvm_20231128                   0.5s
 => [internal] load build context                                                                                                        0.1s
 => => transferring context: 6.49MB                                                                                                      0.1s
 => [ 1/14] FROM us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_tpuvm_20231128@sha256:5851322d5728f4b43f6f068f  45.8s
 => => resolve us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_tpuvm_20231128@sha256:5851322d5728f4b43f6f068fa5c  0.0s
 => => sha256:577ff23cfe55ac8872bc433ce99971a34011e7a15f7c8afa3d6492c78d6d23e5 15.76MB / 15.76MB                                         0.5s
 => => sha256:5851322d5728f4b43f6f068fa5c69444db370f2cac8222183036666971f41846 4.12kB / 4.12kB                                           0.0s
 => => sha256:d1da99c2f14827498c4a9bb3623ae909b44564bdabad1802f064169069df81fb 55.06MB / 55.06MB                                         0.9s
 => => sha256:986e2cf4d9a25b7f49a2703932cad3cda01a6382bd6d38d902ad163bcc40af66 12.37kB / 12.37kB                                         0.0s
 => => sha256:c7b1e60e9d5a0f16eb1f998245666f7a64a44f8b1f2317bd31e8a658150c23d3 54.60MB / 54.60MB                                         1.3s
 => => sha256:beefab36cbfedf8896b5f9f0bc33336fa13c0f01a8cb2333128dd247895a5f3b 196.88MB / 196.88MB                                       3.3s
 => => extracting sha256:d1da99c2f14827498c4a9bb3623ae909b44564bdabad1802f064169069df81fb                                                1.1s
 => => sha256:de3224efe7269100000f1d5f451a8a6e5320b18160642c38bb97326907ecddea 6.29MB / 6.29MB                                           1.6s
 => => sha256:610099c6791eacee1e5a6d88d9ffc593485db3cc1e9bbdb9f7d112fa8fdd2725 17.54MB / 17.54MB                                         1.7s
 => => sha256:2c692cd1c1ae41798a701af5700224eea39e19736623388748e3a9a47782ba85 243B / 243B                                               1.8s
 => => sha256:67440d657e4fe10bca66a9fba87510371db3891e394766a7fdebaa8d19c6a062 2.85MB / 2.85MB                                           3.3s
 => => extracting sha256:577ff23cfe55ac8872bc433ce99971a34011e7a15f7c8afa3d6492c78d6d23e5                                                0.3s
 => => sha256:9e0ab466566e8a57436fd287a230ca562cc8aa0dd531504a2bae86f46c5d400a 130B / 130B                                               3.5s
 => => sha256:e47d3f9f3b24251c2ac7d5b04932ff23bf0099ec7c8676b6ebf86f061e1012fc 7.50kB / 7.50kB                                           3.5s
 => => sha256:9df685ee4175698291626f61c3afdd6618cdcd426b0273c4888b04607b972085 105.02MB / 105.02MB                                       4.9s
 => => sha256:3824c472a674d4e015ce79c8435392edc09924cca442f483835bbe9eae9ea52f 572.91MB / 572.91MB                                      11.1s
 => => sha256:03551a4901c735361e9e7e447828ecf77beeb2336f49fb788fd907cdb5fca972 153B / 153B                                               3.7s
 => => extracting sha256:c7b1e60e9d5a0f16eb1f998245666f7a64a44f8b1f2317bd31e8a658150c23d3                                                1.3s
 => => sha256:8ea82a97bc6ae43a7aed49d861ce05c3ed9757801016770a1101e784a5e5bc45 125.35MB / 125.35MB                                       5.9s
 => => sha256:d408b33f81ce05b78eed03e23d0081e7cdb3972c57c1103565f04f7332ed87fd 375.47MB / 375.47MB                                      15.7s
 => => sha256:3a387ede7ef122b7ad44078e16b8df873c87fa29cb1ef20e225b480be4769d34 201.39MB / 201.39MB                                      12.7s
 => => extracting sha256:beefab36cbfedf8896b5f9f0bc33336fa13c0f01a8cb2333128dd247895a5f3b                                                3.9s
 => => extracting sha256:de3224efe7269100000f1d5f451a8a6e5320b18160642c38bb97326907ecddea                                                0.2s
 => => extracting sha256:610099c6791eacee1e5a6d88d9ffc593485db3cc1e9bbdb9f7d112fa8fdd2725                                                0.4s
 => => extracting sha256:2c692cd1c1ae41798a701af5700224eea39e19736623388748e3a9a47782ba85                                                0.0s
 => => sha256:9407cf7758b440fb6f94e6159ac5e30436976a89995ea3f49bb22079ba9f206c 150B / 150B                                              15.9s
 => => sha256:7df537d35e3203cfb1a67c224e5b7f7769c6f47e7024705c26ce4a387402baad 653.48MB / 653.48MB                                      24.5s
 => => extracting sha256:67440d657e4fe10bca66a9fba87510371db3891e394766a7fdebaa8d19c6a062                                                0.2s
 => => extracting sha256:9e0ab466566e8a57436fd287a230ca562cc8aa0dd531504a2bae86f46c5d400a                                                0.0s
 => => extracting sha256:9df685ee4175698291626f61c3afdd6618cdcd426b0273c4888b04607b972085                                                4.9s
 => => extracting sha256:e47d3f9f3b24251c2ac7d5b04932ff23bf0099ec7c8676b6ebf86f061e1012fc                                                0.0s
 => => extracting sha256:3824c472a674d4e015ce79c8435392edc09924cca442f483835bbe9eae9ea52f                                                6.1s
 => => extracting sha256:03551a4901c735361e9e7e447828ecf77beeb2336f49fb788fd907cdb5fca972                                                0.0s
 => => extracting sha256:8ea82a97bc6ae43a7aed49d861ce05c3ed9757801016770a1101e784a5e5bc45                                                0.6s
 => => extracting sha256:3a387ede7ef122b7ad44078e16b8df873c87fa29cb1ef20e225b480be4769d34                                                0.9s
 => => extracting sha256:d408b33f81ce05b78eed03e23d0081e7cdb3972c57c1103565f04f7332ed87fd                                                7.6s
 => => extracting sha256:9407cf7758b440fb6f94e6159ac5e30436976a89995ea3f49bb22079ba9f206c                                                0.0s
 => => extracting sha256:7df537d35e3203cfb1a67c224e5b7f7769c6f47e7024705c26ce4a387402baad                                                3.1s
 => [ 2/14] RUN apt-get update                                                                                                          39.2s
 => [ 3/14] RUN apt-get install -y --no-install-recommends apt-utils                                                                     1.8s
 => [ 4/14] RUN apt-get install -y --no-install-recommends curl                                                                          2.6s
 => [ 5/14] RUN apt-get install -y --no-install-recommends wget                                                                          1.4s
 => [ 6/14] RUN apt-get install -y --no-install-recommends git                                                                           1.4s
 => [ 7/14] RUN python3 -m pip install --upgrade pip                                                                                     3.3s
 => [ 8/14] RUN pip install fairscale==0.4.13                                                                                            5.8s
 => [ 9/14] RUN pip install numpy==1.24.4                                                                                                1.4s
 => [10/14] RUN pip install immutabledict==4.1.0                                                                                         1.5s
 => [11/14] RUN pip install sentencepiece==0.1.99                                                                                        1.7s
 => [12/14] COPY . /workspace/gemma/                                                                                                     0.1s
 => [13/14] WORKDIR /workspace/gemma/                                                                                                    0.0s
 => [14/14] RUN pip install -e .                                                                                                         2.2s
 => exporting to image                                                                                                                   0.3s
 => => exporting layers                                                                                                                  0.3s
 => => writing image sha256:f256970b444877dc3e1bde548f82915bc2a8965542ef6e6e5c60c8f0497dfca1                                             0.0s
 => => naming to docker.io/library/gemma_xla:markusheimerl                                                                               0.0s
markusheimerl@t1v-n-a16d1e4e-w-0:~/gemma_pytorch$ docker run -t --rm --shm-size 4gb -e PJRT_DEVICE=TPU -v ${CKPT_PATH}:/tmp/ckpt ${DOCKER_URI} python scripts/run_xla.py --ckpt=/tmp/ckpt --variant="${VARIANT}" --quant
usage: run_xla.py [-h] --ckpt CKPT [--variant {2b,7b}] [--output_len OUTPUT_LEN] [--seed SEED] [--quant] [--prompt PROMPT]
run_xla.py: error: argument --variant: invalid choice: '' (choose from '2b', '7b')
markusheimerl@t1v-n-a16d1e4e-w-0:~/gemma_pytorch$ echo $VARIANT

markusheimerl@t1v-n-a16d1e4e-w-0:~/gemma_pytorch$ VARIANT=2b
markusheimerl@t1v-n-a16d1e4e-w-0:~/gemma_pytorch$ CKPT_PATH=/home/markusheimerl/gemma_ckpt/
markusheimerl@t1v-n-a16d1e4e-w-0:~/gemma_pytorch$ DOCKER_URI=gemma_xla:${USER}
markusheimerl@t1v-n-a16d1e4e-w-0:~/gemma_pytorch$ docker run -t --rm --shm-size 4gb -e PJRT_DEVICE=TPU -v ${CKPT_PATH}:/tmp/ckpt ${DOCKER_URI} python scripts/run_xla.py --ckpt=/tmp/ckpt --variant="${VARIANT}" --quant
Error in atexit._run_exitfuncs:
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/site-packages/torch_xla/__init__.py", line 142, in _prepare_to_exit
    _XLAC._prepare_to_exit()
RuntimeError: torch_xla/csrc/runtime/runtime.cc:17 : Check failed: !was_initialized
*** Begin stack trace ***
        tsl::CurrentStackTrace()

        torch_xla::runtime::GetComputationClient()



        _PyObject_MakeTpCall
        _PyEval_EvalFrameDefault

        PyVectorcall_Call

        Py_FinalizeEx
        Py_Exit


        PyRun_SimpleStringFlags

        Py_BytesMain
        __libc_start_main
        _start
*** End stack trace ***
ComputationClient already initialized
(The same atexit traceback is printed four times, with the copies interleaved.)
concurrent.futures.process._RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/concurrent/futures/process.py", line 239, in _process_worker
    r = call_item.fn(*call_item.args, **call_item.kwargs)
  File "/usr/local/lib/python3.8/concurrent/futures/process.py", line 198, in _process_chunk
    return [fn(*args) for args in chunk]
  File "/usr/local/lib/python3.8/concurrent/futures/process.py", line 198, in <listcomp>
    return [fn(*args) for args in chunk]
  File "/usr/local/lib/python3.8/site-packages/torch_xla/runtime.py", line 82, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/torch_xla/_internal/pjrt.py", line 58, in _run_thread_per_device
    initializer_fn(local_rank, local_world_size)
  File "/usr/local/lib/python3.8/site-packages/torch_xla/runtime.py", line 82, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/torch_xla/_internal/pjrt.py", line 117, in initialize_multiprocess
    devices = xm.get_xla_supported_devices()
  File "/usr/local/lib/python3.8/site-packages/torch_xla/core/xla_model.py", line 99, in get_xla_supported_devices
    xla_devices = _DEVICES.value
  File "/usr/local/lib/python3.8/site-packages/torch_xla/utils/utils.py", line 29, in value
    self._value = self._gen_fn()
  File "/usr/local/lib/python3.8/site-packages/torch_xla/core/xla_model.py", line 20, in <lambda>
    _DEVICES = xu.LazyProperty(lambda: torch_xla._XLAC._xla_get_devices())
RuntimeError: Bad StatusOr access: INTERNAL: Failed to get global TPU topology.
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "scripts/run_xla.py", line 259, in <module>
    main(args)
  File "scripts/run_xla.py", line 231, in main
    xmp.spawn(
  File "/usr/local/lib/python3.8/site-packages/torch_xla/runtime.py", line 82, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 38, in spawn
    return pjrt.spawn(fn, nprocs, start_method, args)
  File "/usr/local/lib/python3.8/site-packages/torch_xla/_internal/pjrt.py", line 200, in spawn
    run_multiprocess(spawn_fn, start_method=start_method)
  File "/usr/local/lib/python3.8/site-packages/torch_xla/runtime.py", line 82, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/torch_xla/_internal/pjrt.py", line 160, in run_multiprocess
    replica_results = list(
  File "/usr/local/lib/python3.8/site-packages/torch_xla/_internal/pjrt.py", line 161, in <genexpr>
    itertools.chain.from_iterable(
  File "/usr/local/lib/python3.8/concurrent/futures/process.py", line 484, in _chain_from_iterable_of_lists
    for element in iterable:
  File "/usr/local/lib/python3.8/concurrent/futures/_base.py", line 619, in result_iterator
    yield fs.pop().result()
  File "/usr/local/lib/python3.8/concurrent/futures/_base.py", line 444, in result
    return self.__get_result()
  File "/usr/local/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
    raise self._exception
RuntimeError: Bad StatusOr access: INTERNAL: Failed to get global TPU topology.
markusheimerl@t1v-n-a16d1e4e-w-0:~/gemma_pytorch$ 
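The first `docker run` fails with `invalid choice: ''` because `$VARIANT` is empty at that point, as the `echo $VARIANT` output confirms. This fits with `newgrp docker` starting a new shell. Variables assigned with a plain `NAME=value` before that point are not carried into the new shell. That is why setting them again afterwards, as done above, clears the argument error.

A minimal sketch of the mechanism, with an ordinary `bash -c` child shell standing in for `newgrp`:

```shell
#!/usr/bin/env bash
# Sketch: why a variable set before `newgrp docker` can vanish afterwards.
# `VARIANT=2b` on its own creates a shell variable. Shell variables are not
# part of the environment that a newly started shell receives.
unset VARIANT
VARIANT=2b
bash -c 'echo "new shell sees VARIANT=[${VARIANT:-}]"'   # prints: new shell sees VARIANT=[]

# `export` moves the variable into the environment, so child shells inherit it.
export VARIANT
bash -c 'echo "new shell sees VARIANT=[${VARIANT:-}]"'   # prints: new shell sees VARIANT=[2b]
```

The safest fix is the one already used in the log: assign `VARIANT`, `CKPT_PATH` and `DOCKER_URI` after running `newgrp docker`.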

Hi @markusheimerl, thanks for the issue! It looks like you are setting CKPT_PATH=/home/markusheimerl/gemma_ckpt/. This should be unrelated to the underlying failure you are experiencing, but CKPT_PATH should be the path to the weights file itself, not to the directory that contains it.
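You can point `CKPT_PATH` at the file instead of the directory with a small helper. This is a sketch that assumes the extracted Kaggle archive holds exactly one weights file ending in `.ckpt`. Both that extension and the helper name `find_ckpt` are assumptions, not part of gemma_pytorch.

```shell
#!/usr/bin/env bash
# Hypothetical helper: resolve a checkpoint directory to the single *.ckpt
# weights file inside it. Fails if there are zero or several candidates.
find_ckpt() {
  local dir="$1"
  local matches=()
  local f
  for f in "$dir"/*.ckpt; do
    # If nothing matches, the glob stays literal and this -f test skips it.
    [ -f "$f" ] && matches+=("$f")
  done
  if [ "${#matches[@]}" -ne 1 ]; then
    echo "expected exactly one .ckpt file in $dir, found ${#matches[@]}" >&2
    return 1
  fi
  printf '%s\n' "${matches[0]}"
}
```

For example, run `CKPT_PATH=$(find_ckpt /home/markusheimerl/gemma_ckpt)` before the `docker run` command. The `-v ${CKPT_PATH}:/tmp/ckpt` mount then binds the weights file rather than the directory.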

It looks to me like this is a Torch XLA issue. It might be fixed by using a newer version of the base container here. If not, we may need to file an issue with Torch XLA.

What I would recommend first is trying us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_tpuvm_20240226 and seeing if that fixes the issue. Otherwise, consider reaching out to the Torch XLA team. I can also see if I can get access to a VM with that topology to try to replicate your error.
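One way to try the newer container is to swap the base-image tag in `docker/xla.Dockerfile` and rebuild. This is a minimal sketch. It assumes the old tag `nightly_3.8_tpuvm_20231128`, taken from the build log above, appears verbatim in that file. `bump_base_image` is a hypothetical helper.

```shell
#!/usr/bin/env bash
# Sketch: point the XLA Dockerfile at the newer nightly base image.
# OLD_TAG comes from the build log above. Adjust it if your checkout differs.
OLD_TAG="nightly_3.8_tpuvm_20231128"
NEW_TAG="nightly_3.8_tpuvm_20240226"

bump_base_image() {
  local dockerfile="$1"
  if ! grep -q "$OLD_TAG" "$dockerfile"; then
    echo "tag $OLD_TAG not found in $dockerfile" >&2
    return 1
  fi
  # Rewrite the file in place and keep the original as "$dockerfile.bak".
  sed -i.bak "s/${OLD_TAG}/${NEW_TAG}/g" "$dockerfile"
  # Show the resulting base-image line for a quick visual check.
  grep '^FROM' "$dockerfile"
}
```

After running `bump_base_image docker/xla.Dockerfile`, rebuild with the same `docker build -f docker/xla.Dockerfile ./ -t ${DOCKER_URI}` command as before.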

Hi @markusheimerl, it seems you are using a v4-16 TPU, which has 2 host VMs. This multi-host setup is currently not supported.

To test it on TPU, I suggest running it on a v4-8 or v5e-8. Both are single-host TPU configurations with 1 VM, and you should be able to run the command on either one out of the box.
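For Cloud TPU v4, the number in an accelerator type such as `v4-16` counts TensorCores. Each v4 chip has 2 TensorCores and each host VM has 4 chips, so one host covers 8 TensorCores. That is why a v4-8 is one VM and a v4-16 is two.

A small sketch of a pre-flight check built on that rule. It covers v4 only, and `v4_host_count` and `check_single_host` are hypothetical helper names.

```shell
#!/usr/bin/env bash
# Hypothetical helpers, Cloud TPU v4 only: the N in "v4-N" counts TensorCores,
# and one v4 host VM has 4 chips x 2 TensorCores = 8 TensorCores.
v4_host_count() {
  local accel="$1"
  case "$accel" in
    v4-*) ;;
    *) echo "not a v4 accelerator type: $accel" >&2; return 1 ;;
  esac
  local cores="${accel#v4-}"
  if ! [[ "$cores" =~ ^[1-9][0-9]*$ ]] || (( cores < 8 || cores % 8 != 0 )); then
    echo "unexpected v4 size: $accel" >&2
    return 1
  fi
  echo $(( cores / 8 ))
}

# Refuse multi-host slices, since the single-host commands expect one VM.
check_single_host() {
  local hosts
  hosts=$(v4_host_count "$1") || return 1
  if (( hosts > 1 )); then
    echo "$1 spans $hosts host VMs; use a single-host slice such as v4-8" >&2
    return 1
  fi
  echo "$1 is a single-host slice"
}
```

You can read the accelerator type of an existing VM with `gcloud compute tpus tpu-vm describe NAME --zone=ZONE --format='value(acceleratorType)'` and pass it to `check_single_host`.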

Hi @michaelmoynihan, I also get the "Failed to get global TPU topology" error on a TPU v4-8, so I followed your advice:

What I would recommend first is trying us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_tpuvm_20240226 and seeing if that fixes the issue. Otherwise, consider reaching out to the Torch XLA team. I can also see if I can get access to a VM with that topology to try to replicate your error.

So I created a Dockerfile with these contents:

FROM us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_tpuvm_20240226

RUN pip install datasets peft transformers trl

Then I ran this:

sudo docker build -t my-tpu-pytorch-image .
sudo docker run -v /home/me/finetune:/workspace my-tpu-pytorch-image python /workspace/train.py

where train.py is this example script for training Gemma-7B: https://huggingface.co/google/gemma-7b/blob/main/examples/example_fsdp.py
This is the result:

(v_xla) me@t1v-n-w-0:~/finetune$ sudo docker run -v /home/me/finetune:/workspace my-tpu-pytorch-image python /workspace/train.py
WARNING:root:PJRT is now the default runtime. For more information, see https://github.com/pytorch/xla/blob/master/docs/pjrt.md
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1709816278.547533       1 pjrt_api.cc:100] GetPjrtApi was found for tpu at /usr/local/lib/python3.8/site-packages/torch_xla/lib/libtpu.so
I0000 00:00:1709816278.547628       1 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1709816278.547636       1 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.40. The framework PJRT API version is 0.40.
Traceback (most recent call last):
  File "/workspace/train.py", line 15, in <module>
    device = xm.xla_device()
  File "/usr/local/lib/python3.8/site-packages/torch_xla/core/xla_model.py", line 211, in xla_device
    return runtime.xla_device(n, devkind)
  File "/usr/local/lib/python3.8/site-packages/torch_xla/runtime.py", line 88, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/torch_xla/runtime.py", line 117, in xla_device
    return torch.device(torch_xla._XLAC._xla_get_default_device())
RuntimeError: Bad StatusOr access: INTERNAL: Failed to get global TPU topology.

I get the same error again (INTERNAL: Failed to get global TPU topology.), and it also looks like something may be wrong with PJRT. I will try to reproduce this in another environment.

I tried running this script on a TPU v3-8. After slight modifications (I switched the model to Gemma-2b because of a RESOURCE_EXHAUSTED error), I was able to start it without Docker using:

python train.py

The script works. It looks like I was using the wrong VM software version when creating the TPU, and I had forgotten to set the environment variables.
The correct way to create a TPU v4-8:

gcloud compute tpus tpu-vm create myname --zone=my-zone --accelerator-type=v4-8 --version=tpu-vm-v4-pt-2.0

and then set these environment variables:

PJRT_DEVICE=TPU XLA_USE_SPMD=1
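These variables have to reach the Python process itself. A minimal sketch that scopes them to a single command with `env`, so the calling shell is left unchanged. `with_tpu_env` is a hypothetical helper name.

```shell
#!/usr/bin/env bash
# Hypothetical helper: run one command with the two variables from the comment
# above set only for that command. The current shell is not modified.
with_tpu_env() {
  env PJRT_DEVICE=TPU XLA_USE_SPMD=1 "$@"
}

# Usage on the TPU VM, where train.py is the reporter's script:
#   with_tpu_env python train.py
```

Exporting both variables in a shell startup file would apply them to every command on the VM instead. The per-command form keeps the setting next to the command it affects.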