google / paxml

Pax is a JAX-based machine learning framework for training large-scale models. Pax allows for advanced, fully configurable experimentation and parallelization, and has demonstrated industry-leading model FLOP utilization rates.

Pipeline Parallelism: F external/org_tensorflow/tensorflow/compiler/xla/array.h:446] Check failed: n < sizes_size Fatal Python error: Aborted

abhinavgoel95 opened this issue.

Hello!

I am trying to run a 126-million-parameter GPT-3 model with pipeline parallelism in PAXML. XLA compilation fails when NUM_MICROBATCHES > 1.
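As a rough check of the "126 million parameter" figure, the dimensions from the gin config below can be plugged into the usual decoder-only Transformer parameter count. This is a sketch under stated assumptions, not PAXML's own accounting: input and output embeddings are assumed to be tied, no projection biases are counted (the config sets USE_BIAS = False), and layer-norm parameters are ignored. Since NUM_HEADS × DIMS_PER_HEAD = 12 × 64 = 768 = MODEL_DIMS, the four attention projections are each MODEL_DIMS × MODEL_DIMS.

```python
# Rough parameter count for the 126M GPT-3 config in this issue.
# Assumptions (not taken from PAXML): tied input/output embeddings,
# no projection biases (USE_BIAS = False), layer-norm params ignored.

NUM_LAYERS = 12
MODEL_DIMS = 768
HIDDEN_DIMS = 3072
VOCAB_SIZE = 51200
MAX_SEQ_LEN = 2048  # trainable position embeddings (TRAINABLE_PE_MAX_SEQ_LEN)


def approx_param_count(num_layers, model_dims, hidden_dims, vocab_size, max_seq_len):
    token_emb = vocab_size * model_dims      # assumed shared with the softmax layer
    pos_emb = max_seq_len * model_dims       # one learned vector per position
    attention = 4 * model_dims * model_dims  # Q, K, V and output projections
    ffn = 2 * model_dims * hidden_dims       # up- and down-projection
    return token_emb + pos_emb + num_layers * (attention + ffn)


total = approx_param_count(NUM_LAYERS, MODEL_DIMS, HIDDEN_DIMS, VOCAB_SIZE, MAX_SEQ_LEN)
print(f"{total:,} parameters (~{total / 1e6:.1f}M)")
```

This prints `125,829,120 parameters (~125.8M)`, which is consistent with the 126M label.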

System:

8x NVIDIA A100-SXM 80 GB

Gin config:

from __gin__ import dynamic_registration

import __main__ as train_script
from paxml import gin_utils
from paxml.tasks.lm import model_params_with_gin
from paxml.tasks.lm.params import datasets_gin
from praxis import optimizers
from praxis import schedules
from praxis.layers import activations
from praxis.layers import repeats
from jax import numpy as jnp

MAX_SL=2048
SUMMARY_INTERVAL_STEPS=100
CHECKPOINT_EVERY_N_STEPS=1000
EVAL_INTERVAL_STEPS=100
MAX_STEPS=600000
NUM_STAGES = 4
ICI_MESH_SHAPE=[%NUM_STAGES, 1, 1, 2]
PERCORE_BATCH_SIZE = 2

MODEL = @model_params_with_gin.TransformerLmSpmdPipeline()
model_params_with_gin.TransformerLmSpmdPipeline:
  USE_REPEATED_LAYER = False
  MAX_SEQ_LEN = %MAX_SL
  NUM_LAYERS = 12
  NUM_HEADS = 12
  MODEL_DIMS = 768
  HIDDEN_DIMS = 3072
  DIMS_PER_HEAD = 64
  VOCAB_SIZE = 51200
  TRAINABLE_POSITION_EMB = True
  TRAINABLE_PE_MAX_SEQ_LEN = %MAX_SL
  ACTIVATION_CLS = @activations.GELU.HParams()
  PACKED_INPUT = True
  USE_BIAS = False
  MAX_STEPS=%MAX_STEPS
  INIT_STD = 0.023
  NUM_STAGES = %NUM_STAGES
  NUM_MICROBATCHES = 2
  ICI_MESH_SHAPE = %ICI_MESH_SHAPE
  FPROP_DTYPE = @jnp.bfloat16
  SUMMARY_INTERVAL_STEPS=%SUMMARY_INTERVAL_STEPS
  CHECKPOINT_EVERY_N_STEPS=%CHECKPOINT_EVERY_N_STEPS
  EVAL_INTERVAL_STEPS=%EVAL_INTERVAL_STEPS

OPTIMIZER = @optimizers.Adam.HParams()
optimizers.Adam.HParams:
  beta1 = 0.9
  beta2 = 0.95
  learning_rate = 6e-4
  epsilon_root = 0.0
  epsilon = 1e-8
  weight_decay = 0.1
  clip_threshold = 1.0
  clip_gradient_norm_to_value = 5.0


SCHEDULER = @schedules.LinearRampupCosineDecay.HParams()
schedules.LinearRampupCosineDecay.HParams:
  warmup_steps = 636
  decay_start = 637
  decay_end = 500000
  min_ratio = 0.1
  max = 1.0

DATASET = @datasets_gin.PileUnsupervisedDataset()
datasets_gin.PileUnsupervisedDataset:
  MAX_SEQ_LEN = %MAX_SL
  PERCORE_BATCH_SIZE = %PERCORE_BATCH_SIZE

## experiment == model + dataset
EXPERIMENT = @model_params_with_gin.Experiment()
model_params_with_gin.Experiment:
  model = %MODEL
  dataset = %DATASET
  optimizer = %OPTIMIZER
  scheduler = %SCHEDULER
  
train_script.run:
  experiment_config = %EXPERIMENT
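The pipeline settings above imply a few invariants that are worth checking whenever NUM_MICROBATCHES changes. The sketch below works through the arithmetic. It assumes that the global batch size is PERCORE_BATCH_SIZE times the number of devices and that the pipeline follows a GPipe-style schedule of NUM_MICROBATCHES + NUM_STAGES - 1 iterations; both are illustrative assumptions, not statements about PAXML internals.

```python
# Sanity checks for the pipeline settings in this issue's gin config.
# Assumptions (illustrative, not PAXML internals):
#   * global batch = PERCORE_BATCH_SIZE * number of devices
#   * GPipe-style schedule: NUM_MICROBATCHES + NUM_STAGES - 1 iterations
from math import prod

NUM_DEVICES = 8  # 8x A100 in this report
NUM_STAGES = 4
ICI_MESH_SHAPE = [NUM_STAGES, 1, 1, 2]
NUM_LAYERS = 12
PERCORE_BATCH_SIZE = 2


def pipeline_layout(num_microbatches):
    # The mesh must cover every device exactly once.
    assert prod(ICI_MESH_SHAPE) == NUM_DEVICES
    # Each stage gets the same number of Transformer layers.
    assert NUM_LAYERS % NUM_STAGES == 0
    global_batch = PERCORE_BATCH_SIZE * NUM_DEVICES
    # Each microbatch must hold a whole number of examples.
    assert global_batch % num_microbatches == 0
    iterations = num_microbatches + NUM_STAGES - 1
    # Each stage is busy for num_microbatches of the iterations,
    # so the idle ("bubble") fraction is (S - 1) / (M + S - 1).
    bubble = (NUM_STAGES - 1) / iterations
    return {
        "layers_per_stage": NUM_LAYERS // NUM_STAGES,
        "global_batch": global_batch,
        "microbatch_size": global_batch // num_microbatches,
        "iterations": iterations,
        "bubble_fraction": bubble,
    }


for m in (1, 2):
    print(m, pipeline_layout(m))
```

Under these assumptions, both NUM_MICROBATCHES = 1 and NUM_MICROBATCHES = 2 pass the divisibility checks: with two microbatches, each holds 8 of the 16 examples, each stage runs 3 layers, and the bubble drops from 75% to 60% of the schedule.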

Command:

#! /bin/bash

set -x

PYTHONPATH=/pax/paxml:/pax/praxis python3 /pax/paxml/paxml/main.py \
    --exp=tasks.lm.params.c4.PileSpmdAdam \
    --gin_file="/pax/paxml/configs/gpt3_126_pp.gin" \
    --tfds_data_dir="/pax/datasets" \
    --vocab_path='/pax/vocab/c4_en_301_5Mexp2_spm.model' \
    --pmap_use_tensorstore=True \
    --job_log_dir=/logs/ \
    --alsologtostderr 

set +x

XLA compile-time error:

2022-10-10 16:01:05.537760: F external/org_tensorflow/tensorflow/compiler/xla/array.h:446] Check failed: n < sizes_size 
Fatal Python error: Aborted

Current thread 0x00007f5c10b73740 (most recent call first):
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/dispatch.py", line 940 in backend_compile
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/profiler.py", line 294 in wrapper
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/dispatch.py", line 996 in compile_or_get_cached
  File "/usr/local/lib/python3.8/dist-packages/jax/interpreters/pxla.py", line 3048 in from_hlo
  File "/usr/local/lib/python3.8/dist-packages/jax/interpreters/pxla.py", line 2890 in compile
  File "/usr/local/lib/python3.8/dist-packages/jax/experimental/pjit.py", line 815 in _pjit_call_impl
  File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 685 in process_primitive
  File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 327 in bind_with_trace
  File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 324 in bind
  File "/usr/local/lib/python3.8/dist-packages/jax/experimental/pjit.py", line 385 in wrapped
  File "/pax/paxml/paxml/train.py", line 1087 in train_and_evaluate_spmd_model
  File "/pax/paxml/paxml/train.py", line 271 in train_and_evaluate
  File "/pax/paxml/paxml/main.py", line 290 in run_experiment
  File "/pax/paxml/paxml/main.py", line 535 in run
  File "/usr/local/lib/python3.8/dist-packages/gin/config.py", line 1582 in gin_wrapper
  File "/pax/paxml/paxml/main.py", line 588 in main
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 251 in _run_main
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 303 in run
  File "/pax/paxml/paxml/main.py", line 631 in <module>
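The failing check (`n < sizes_size` in XLA's `array.h`) fires inside `backend_compile`, so the Python traceback does not show which HLO instruction triggered it. One common way to give maintainers more detail is to have XLA dump the modules it compiles. The sketch below appends the standard `--xla_dump_to` and `--xla_dump_hlo_as_text` flags to `XLA_FLAGS`; the directory `/tmp/xla_dump` and the helper name `enable_xla_dump` are arbitrary examples. The variable has to be set before JAX initializes its backend, so in practice this belongs at the very top of the entry point, or the same flags can be exported in the launch script's environment.

```python
# Sketch: ask XLA to dump the HLO modules it compiles, so the module that
# trips the `n < sizes_size` check can be inspected or attached to the issue.
# /tmp/xla_dump and enable_xla_dump are hypothetical example names.
import os

DUMP_DIR = "/tmp/xla_dump"


def enable_xla_dump(dump_dir=DUMP_DIR):
    """Append XLA dump flags to XLA_FLAGS, keeping any flags already set."""
    existing = os.environ.get("XLA_FLAGS", "")
    extra = f"--xla_dump_to={dump_dir} --xla_dump_hlo_as_text"
    os.environ["XLA_FLAGS"] = f"{existing} {extra}".strip()
    return os.environ["XLA_FLAGS"]


# Call this before `import jax` (or any module that imports it):
# XLA parses XLA_FLAGS only once, when it is first initialized.
enable_xla_dump()
```

The text HLO files written to the dump directory can then be attached to the issue.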

The same configuration compiles and runs without errors when NUM_MICROBATCHES = 1.

It would be great if someone could look into what causes XLA compilation to fail when NUM_MICROBATCHES > 1.

Tracked internally in b/253051570.

Already fixed.