Pipeline Parallelism: F external/org_tensorflow/tensorflow/compiler/xla/array.h:446] Check failed: n < sizes_size Fatal Python error: Aborted
abhinavgoel95 opened this issue
Abhinav Goel commented
Hello!
I am trying to train a 126-million-parameter GPT-3 model with pipeline parallelism in PAXML. I run into errors whenever NUM_MICROBATCHES > 1.
System:
8x NVIDIA A100-SXM 80 GB
Gin Configs:
from __gin__ import dynamic_registration
import __main__ as train_script
from paxml import gin_utils
from paxml.tasks.lm import model_params_with_gin
from paxml.tasks.lm.params import datasets_gin
from praxis import optimizers
from praxis import schedules
from praxis.layers import activations
from praxis.layers import repeats
from jax import numpy as jnp
MAX_SL=2048
SUMMARY_INTERVAL_STEPS=100
CHECKPOINT_EVERY_N_STEPS=1000
EVAL_INTERVAL_STEPS=100
MAX_STEPS=600000
NUM_STAGES = 4
ICI_MESH_SHAPE=[%NUM_STAGES, 1, 1, 2]
PERCORE_BATCH_SIZE = 2
MODEL = @model_params_with_gin.TransformerLmSpmdPipeline()
model_params_with_gin.TransformerLmSpmdPipeline:
    USE_REPEATED_LAYER = False
    MAX_SEQ_LEN = %MAX_SL
    NUM_LAYERS = 12
    NUM_HEADS = 12
    MODEL_DIMS = 768
    HIDDEN_DIMS = 3072
    DIMS_PER_HEAD = 64
    VOCAB_SIZE = 51200
    TRAINABLE_POSITION_EMB = True
    TRAINABLE_PE_MAX_SEQ_LEN = %MAX_SL
    ACTIVATION_CLS = @activations.GELU.HParams()
    PACKED_INPUT = True
    USE_BIAS = False
    MAX_STEPS = %MAX_STEPS
    INIT_STD = 0.023
    NUM_STAGES = %NUM_STAGES
    NUM_MICROBATCHES = 2
    ICI_MESH_SHAPE = %ICI_MESH_SHAPE
    FPROP_DTYPE = @jnp.bfloat16
    SUMMARY_INTERVAL_STEPS = %SUMMARY_INTERVAL_STEPS
    CHECKPOINT_EVERY_N_STEPS = %CHECKPOINT_EVERY_N_STEPS
    EVAL_INTERVAL_STEPS = %EVAL_INTERVAL_STEPS
OPTIMIZER = @optimizers.Adam.HParams()
optimizers.Adam.HParams:
    beta1 = 0.9
    beta2 = 0.95
    learning_rate = 6e-4
    epsilon_root = 0.0
    epsilon = 1e-8
    weight_decay = 0.1
    clip_threshold = 1.0
    clip_gradient_norm_to_value = 5.0
SCHEDULER = @schedules.LinearRampupCosineDecay.HParams()
schedules.LinearRampupCosineDecay.HParams:
    warmup_steps = 636
    decay_start = 637
    decay_end = 500000
    min_ratio = 0.1
    max = 1.0
DATASET = @datasets_gin.PileUnsupervisedDataset()
datasets_gin.PileUnsupervisedDataset:
    MAX_SEQ_LEN = %MAX_SL
    PERCORE_BATCH_SIZE = %PERCORE_BATCH_SIZE
## experiment == model + dataset
EXPERIMENT = @model_params_with_gin.Experiment()
model_params_with_gin.Experiment:
    model = %MODEL
    dataset = %DATASET
    optimizer = %OPTIMIZER
    scheduler = %SCHEDULER
train_script.run:
    experiment_config = %EXPERIMENT
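For reference, the sharding arithmetic implied by these settings can be sanity-checked with a short standalone script. This is just a sketch restating the gin values above; it does not import PAXML or JAX, and the batch/layer split it computes is my reading of how these parameters combine, not PAXML's actual partitioning logic.

```python
# Sanity-check the pipeline/mesh arithmetic implied by the gin config above.
NUM_DEVICES = 8                      # 8x A100
ICI_MESH_SHAPE = [4, 1, 1, 2]        # first axis = NUM_STAGES
NUM_STAGES = ICI_MESH_SHAPE[0]
NUM_LAYERS = 12
PERCORE_BATCH_SIZE = 2
NUM_MICROBATCHES = 2

mesh_size = 1
for d in ICI_MESH_SHAPE:
    mesh_size *= d
assert mesh_size == NUM_DEVICES        # mesh must cover all devices: 4*1*1*2 == 8

assert NUM_LAYERS % NUM_STAGES == 0    # layers divide evenly across stages
layers_per_stage = NUM_LAYERS // NUM_STAGES          # 3

global_batch = PERCORE_BATCH_SIZE * NUM_DEVICES      # 16
assert global_batch % NUM_MICROBATCHES == 0
microbatch_size = global_batch // NUM_MICROBATCHES   # 8

print(layers_per_stage, global_batch, microbatch_size)  # 3 16 8
```

All of these divisibility checks pass for the config above, so the crash does not look like a simple shape-mismatch in the user-facing settings.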
Command:
#! /bin/bash
set -x
PYTHONPATH=/pax/paxml:/pax/praxis python3 /pax/paxml/paxml/main.py \
    --exp=tasks.lm.params.c4.PileSpmdAdam \
    --gin_file="/pax/paxml/configs/gpt3_126_pp.gin" \
    --tfds_data_dir="/pax/datasets" \
    --vocab_path='/pax/vocab/c4_en_301_5Mexp2_spm.model' \
    --pmap_use_tensorstore=True \
    --job_log_dir=/logs/ \
    --alsologtostderr
set +x
XLA compile-time error:
2022-10-10 16:01:05.537760: F external/org_tensorflow/tensorflow/compiler/xla/array.h:446] Check failed: n < sizes_size
Fatal Python error: Aborted
Current thread 0x00007f5c10b73740 (most recent call first):
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/dispatch.py", line 940 in backend_compile
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/profiler.py", line 294 in wrapper
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/dispatch.py", line 996 in compile_or_get_cached
  File "/usr/local/lib/python3.8/dist-packages/jax/interpreters/pxla.py", line 3048 in from_hlo
  File "/usr/local/lib/python3.8/dist-packages/jax/interpreters/pxla.py", line 2890 in compile
  File "/usr/local/lib/python3.8/dist-packages/jax/experimental/pjit.py", line 815 in _pjit_call_impl
  File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 685 in process_primitive
  File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 327 in bind_with_trace
  File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 324 in bind
  File "/usr/local/lib/python3.8/dist-packages/jax/experimental/pjit.py", line 385 in wrapped
  File "/pax/paxml/paxml/train.py", line 1087 in train_and_evaluate_spmd_model
  File "/pax/paxml/paxml/train.py", line 271 in train_and_evaluate
  File "/pax/paxml/paxml/main.py", line 290 in run_experiment
  File "/pax/paxml/paxml/main.py", line 535 in run
  File "/usr/local/lib/python3.8/dist-packages/gin/config.py", line 1582 in gin_wrapper
  File "/pax/paxml/paxml/main.py", line 588 in main
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 251 in _run_main
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 303 in run
  File "/pax/paxml/paxml/main.py", line 631 in <module>
Everything works when NUM_MICROBATCHES = 1. It would be great if someone could look into what causes XLA to fail during compilation once NUM_MICROBATCHES > 1.
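For context on what the NUM_MICROBATCHES setting asks for: microbatched pipelining splits the global batch and pushes each microbatch through the stages in turn. The toy sketch below illustrates only that dataflow; the stage functions and sizes are made up, and this is not PAXML's scheduler.

```python
# Toy illustration of microbatched pipeline dataflow (not PAXML's scheduler):
# split the global batch into microbatches and run each through all stages.
def pipeline_forward(batch, stage_fns, num_microbatches):
    assert len(batch) % num_microbatches == 0
    mb_size = len(batch) // num_microbatches
    out = []
    for i in range(num_microbatches):
        h = batch[i * mb_size:(i + 1) * mb_size]
        for stage in stage_fns:       # each stage would live on its own devices
            h = [stage(x) for x in h]
        out.extend(h)
    return out

# Four stages that each scale their input; the composed effect is x * 3.0,
# so the result must not depend on how many microbatches we use.
stages = [lambda x, w=w: x * w for w in (1.0, 2.0, 0.5, 3.0)]
result = pipeline_forward(list(range(16)), stages, num_microbatches=2)
assert result == [x * 3.0 for x in range(16)]
```

Since the microbatch split is semantically a no-op like this, the NUM_MICROBATCHES > 1 failure looks like a lowering/compilation issue rather than a modeling one.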
George Karpenkov commented
Tracked internally in b/253051570.
Qiao Zhang commented
Already fixed.