Torch_XLA crashes with 143 on larger pods such as v3-128
jianguoz opened this issue · comments
🐛 Bug
@JackCaoG When I run a general xla_dist
command such as
sudo python3 -m torch_xla.distributed.xla_dist --tpu=${TPU_NAME} --restart-tpuvm-pod-server -- pkill -f 'python'
The TPU VM shows
2023-01-16 20:44:15 172.16.80.29 [9] SCP: Attempting to connect to worker 9...
2023-01-16 20:44:18 172.16.80.8 [14] Terminated
2023-01-16 20:44:18 172.16.80.121 [8] Terminated
2023-01-16 20:44:18 172.16.80.32 [10] Terminated
2023-01-16 20:44:18 172.16.80.31 [2] Terminated
2023-01-16 20:44:18 172.16.80.34 [15] Terminated
2023-01-16 20:44:18 172.16.80.28 [1] Terminated
2023-01-16 20:44:18 172.16.80.81 [13] Terminated
2023-01-16 20:44:18 172.16.80.37 [5] Terminated
2023-01-16 20:44:18 172.16.80.7 [6] Terminated
2023-01-16 20:44:18 172.16.80.38 [4] Terminated
2023-01-16 20:44:18 172.16.80.2 [12] Terminated
2023-01-16 20:44:18 172.16.80.24 [3] Terminated
2023-01-16 20:44:18 172.16.80.25 [7] Terminated
2023-01-16 20:44:18 172.16.80.29 [9] Terminated
2023-01-16 20:44:18 172.16.80.23 [0] Terminated
2023-01-16 20:44:18 172.16.80.35 [11] Terminated
Process Process-1:
Traceback (most recent call last):
File "/usr/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
self.run()
File "/usr/lib/python3.8/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_dist.py", line 520, in _run_cmd
self._start_run(script_map)
File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_dist.py", line 514, in _start_run
xu.parallel_work(
File "/usr/local/lib/python3.8/dist-packages/torch_xla/utils/utils.py", line 293, in parallel_work
return [res for res in results] # Iterating to re-raise any exceptions
File "/usr/local/lib/python3.8/dist-packages/torch_xla/utils/utils.py", line 293, in <listcomp>
return [res for res in results] # Iterating to re-raise any exceptions
File "/usr/lib/python3.8/concurrent/futures/_base.py", line 619, in result_iterator
yield fs.pop().result()
File "/usr/lib/python3.8/concurrent/futures/_base.py", line 444, in result
return self.__get_result()
File "/usr/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
raise self._exception
File "/usr/lib/python3.8/concurrent/futures/thread.py", line 57, in run
result = self.fn(*self.args, **self.kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_dist.py", line 500, in _run_script
raise RuntimeError(
RuntimeError: Remote command exitted with code: 143
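For background (this is standard POSIX behavior, not something stated in the thread): exit status 143 is 128 + 15, meaning the remote process was killed by SIGTERM rather than failing on its own. A minimal Python sketch of the same status locally:

```python
import signal
import subprocess
import sys

# Shell convention: a process killed by signal N exits with status 128 + N.
# SIGTERM is 15, so 128 + 15 = 143 -- the code xla_dist reports above.
assert 128 + signal.SIGTERM == 143

# Reproduce locally: start a long-running child and terminate it, the way
# the remote worker processes were terminated.
proc = subprocess.Popen([sys.executable, "-c", "import time; time.sleep(60)"])
proc.terminate()        # sends SIGTERM
proc.wait()
print(proc.returncode)  # Python reports -15; a shell would report 143
```

So "Terminated" plus 143 means something sent SIGTERM to the remote commands; the question in this thread is what sent it.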
However, the TPU pod works fine on a single node (with general torch_xla commands), or with examples such as
gcloud alpha compute tpus tpu-vm ssh {TPU_NAME} --zone us-east1-d --command "python commands" --worker all
In general I do not have such issues on v3-32 (i.e., smaller pods).
I think the TPU VM itself does not shut down. I also tried the nightly version, and it does not work either.
To Reproduce
Steps to reproduce the behavior:
1. sudo python3 -m torch_xla.distributed.xla_dist --tpu=${TPU_NAME} --restart-tpuvm-pod-server -- pkill -f 'python'
Environment
- Reproducible on XLA backend [CPU/TPU]: TPU V3-128
- torch_xla version: tpu-vm-pt-1.13
Additional context
It seems like you are using xla_dist to run pkill -f python on all pods because you want to do some cleanup? This is not training, right?
@JackCaoG I tried xla_dist both with normal training commands and with pkill -f python; both fail and return 143. A normal training command looks like:
TPU_NAME=tpu-v3-128
python3 -m torch_xla.distributed.xla_dist \
--tpu=${TPU_NAME} --restart-tpuvm-pod-server -- \
python3 run_glue.py \
--model_name_or_path bert-base-cased \
--dataset_name SetFit/mrpc \
--do_train \
--do_eval \
--max_seq_length 128 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 32 \
--learning_rate 2e-5 \
--num_train_epochs 3 \
--run_name mrpc_v3-128_bs-64_lr-2e-5-bert \
--output_dir /tmp/mrpc-bert/ \
--overwrite_output_dir
A side question: do we have a better way to kill xla processes on all pods? Control+C does not work well, and pkill -f python sometimes causes crashes even on smaller TPU pods such as v3-32.
ctrl+c should work: the first ctrl+c tells xla_dist to exit, and it will try to do the cleanup. As long as you don't press ctrl+c twice, it works most of the time for me. Regarding cleanup, you can use
gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command=""
to run any bash command on all TPU hosts in a pod. You can pass pkill -f python that way.
@JackCaoG Thanks for the pkill suggestions. For normal runs, does the "Terminated" 143 error come from Google's side or from xla? So far I have not successfully run any xla code on a v3-128 TPU pod in europe-west4-a.
I am trying to schedule a TPU pod in another zone and will update soon.
@Liyang90 will take a look from our end. It is unclear to me what the error is. My natural guess would be an OOM as the pod size increases, but since pkill -f also failed, I suspect the failure is somehow related to the number of ssh connections we issue.
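For reference, the traceback above surfaces the failure through torch_xla's parallel_work, which iterates its futures so that any per-worker exception is re-raised in the parent process. A minimal stand-in sketch of that pattern (the function names and IPs below are illustrative, not torch_xla's actual code):

```python
from concurrent.futures import ThreadPoolExecutor

def parallel_work(num_workers, fn, *args):
    # Map fn over all workers; iterating the results re-raises the first
    # exception any worker hit, which is what aborts the whole run.
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        results = executor.map(fn, *args)
        return [res for res in results]  # iterating to re-raise any exceptions

def run_script(worker_ip):
    # Stand-in for the per-worker ssh command; one worker simulates the
    # remote command dying with status 143.
    if worker_ip.endswith(".24"):
        raise RuntimeError("Remote command exitted with code: 143")
    return 0

msg = ""
try:
    parallel_work(4, run_script,
                  ["172.16.80.23", "172.16.80.28", "172.16.80.31", "172.16.80.24"])
except RuntimeError as e:
    msg = str(e)
print(msg)  # the single worker failure that terminates xla_dist
```

This is why one failing ssh connection (out of 16 on a v3-128) is enough to bring the whole launcher down.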
@JackCaoG @li-yi-dong Rescheduled a new v3-128 pod with 16 ssh connections. Same 143 terminated errors as above when I try to run commands starting with sudo python3 -m torch_xla.distributed.xla_dist --tpu=${TPU_NAME} --restart-tpuvm-pod-server
For a sanity check, can you see if the resnet50 example with fake data works on v3-128 (using the instructions in https://cloud.google.com/tpu/docs/pytorch-pods)?
@JackCaoG We created several pods and tested. Both the resnet50 code and general XLA code such as **sudo** python3 -m torch_xla.distributed.xla_dist --tpu=us-tpu-32 --restart-tpuvm-pod-server -- python -c "print(1+1)" work only on v3-32; they do not work on v3-64 or v3-128. The errors are exactly the same as above. We also changed zones and it still does not work. Can you check the issue? Thanks :)
Hi @jianguoz, could you share more about how the TPU cluster was created, and in what environment the xla_dist module was called?
Hi @Liyang90, this is how we create the TPU cluster (note: I removed the --network and --subnetwork information from my yaml file):
export TPU_NAME=west-tpu-128 # change to your TPU name
export ZONE=europe-west4-a # change to your TPU zone
export ACCELERATOR_TYPE=v3-128 # you can also try out larger TPU pods
export RUNTIME_VERSION=tpu-vm-pt-1.13 # the XLA FSDP interface is supported in PyTorch/XLA
# add --network --subnetwork if use specific networks
gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --zone ${ZONE} \
--accelerator-type ${ACCELERATOR_TYPE} --version ${RUNTIME_VERSION} --tags=tpu-vm \
--metadata startup-script='#! /bin/bash
pip install timm
pip install ftfy
pip install regex
pip install wandb
pip install transformers
pip3 install -U scikit-learn
pip install nltk
pip install python-dotenv==0.20.0
pip install sentencepiece
pip install jsonlines
pip install accelerate
pip install fire==0.4.0
pip install evaluate
pip install google-cloud-storage
'
The xla_dist is called under the default python3 environment (where torch_xla is installed).
Do you ssh to worker 0 to run the xla_dist module, as indicated here: https://cloud.google.com/tpu/docs/pytorch-pods?
@Liyang90 Yes! I log in to worker 0 by default.
Login:
gcloud alpha compute tpus tpu-vm ssh {TPU-NAME} --zone {Zone} --project {Project}
Configure ssh under sudo:
sudo gcloud compute config-ssh
With sudo, running xla_dist returns 143:
sudo python3 -m torch_xla.distributed.xla_dist --tpu={TPU-NAME} --restart-tpuvm-pod-server -- python3 -c 'print(1+1)'
However, I found that if I remove sudo (after setting up gcloud compute config-ssh), it works without the 143 error. Can you check the potential issues or reasons here? If possible, can you add sudo support to xla?
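One plausible mechanism behind the sudo difference (an assumption, not confirmed in this thread): gcloud compute config-ssh writes keys and host entries under the invoking user's $HOME, and sudo typically switches $HOME to /root, so xla_dist run as root resolves a different ~/.ssh than the one that was configured. A small sketch of that path resolution (the home directories below are illustrative):

```python
import os
import subprocess
import sys

def ssh_config_path(home):
    # Resolve "~/.ssh/config" in a child process with a given $HOME,
    # the way ssh-based tools resolve the current user's ssh config.
    env = dict(os.environ, HOME=home)
    out = subprocess.run(
        [sys.executable, "-c",
         "import os; print(os.path.expanduser('~/.ssh/config'))"],
        env=env, capture_output=True, text=True,
    )
    return out.stdout.strip()

# Hypothetical normal user vs. root (what sudo usually switches HOME to).
user_path = ssh_config_path("/home/jianguoz")
root_path = ssh_config_path("/root")
print(user_path)  # /home/jianguoz/.ssh/config
print(root_path)  # /root/.ssh/config -- a different file, hence different keys
```

If that is the cause, running config-ssh and xla_dist as the same user (either both with sudo or both without) should behave consistently.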
Interesting, I've not tried using sudo with those commands. Can you confirm that without sudo things work as you expect? And what's the reason for needing sudo?
Hi @Liyang90, I found the potential reason relates to the user permissions and ssh setup needed to run xla_dist across multiple pod workers on Google Cloud TPUs, which sometimes differs from GPU setups.
Thanks for your help! I closed this issue :)