Pipeline Parallelism: USE_REPEATED_LAYERS bug
abhinavgoel95 opened this issue · comments
Hello!
I am trying to implement a 126-million-parameter GPT-3 with Pipeline Parallelism on PAXML. I notice that USE_REPEATED_LAYER=True helps speed up compilation and also reduces the memory requirement. However, when I set USE_REPEATED_LAYER=True together with Pipeline Parallelism, I get the following error.
System:
8X NVIDIA A100-SXM 80 GB
Gin Configs:
from __gin__ import dynamic_registration
import __main__ as train_script
from paxml import gin_utils
from paxml.tasks.lm import model_params_with_gin
from paxml.tasks.lm.params import datasets_gin
from praxis import optimizers
from praxis import schedules
from praxis.layers import activations
from praxis.layers import repeats
from jax import numpy as jnp
MAX_SL=2048
SUMMARY_INTERVAL_STEPS=100
CHECKPOINT_EVERY_N_STEPS=1000
EVAL_INTERVAL_STEPS=100
MAX_STEPS=600000
NUM_STAGES = 4
ICI_MESH_SHAPE=[%NUM_STAGES, 1, 1, 2]
PERCORE_BATCH_SIZE = 2
MODEL = @model_params_with_gin.TransformerLmSpmdPipeline()
model_params_with_gin.TransformerLmSpmdPipeline:
USE_REPEATED_LAYER = True
MAX_SEQ_LEN = %MAX_SL
NUM_LAYERS = 12
NUM_HEADS = 12
MODEL_DIMS = 768
HIDDEN_DIMS = 3072
DIMS_PER_HEAD = 64
VOCAB_SIZE = 51200
TRAINABLE_POSITION_EMB = True
TRAINABLE_PE_MAX_SEQ_LEN = %MAX_SL
ACTIVATION_CLS = @activations.GELU.HParams()
PACKED_INPUT = True
USE_BIAS = False
MAX_STEPS=%MAX_STEPS
INIT_STD = 0.023
EVAL_INTERVAL_STEPS = 100
NUM_STAGES = %NUM_STAGES
NUM_MICROBATCHES = 1
ICI_MESH_SHAPE = %ICI_MESH_SHAPE
FPROP_DTYPE = @jnp.bfloat16
SUMMARY_INTERVAL_STEPS=%SUMMARY_INTERVAL_STEPS
CHECKPOINT_EVERY_N_STEPS=%CHECKPOINT_EVERY_N_STEPS
EVAL_INTERVAL_STEPS=%EVAL_INTERVAL_STEPS
OPTIMIZER = @optimizers.Adam.HParams()
optimizers.Adam.HParams:
beta1 = 0.9
beta2 = 0.95
learning_rate = 6e-4
epsilon_root = 0.0
epsilon = 1e-8
weight_decay = 0.1
clip_threshold = 1.0
clip_gradient_norm_to_value = 5.0
SCHEDULER = @schedules.LinearRampupCosineDecay.HParams()
schedules.LinearRampupCosineDecay.HParams:
warmup_steps = 636
decay_start = 637
decay_end = 500000
min_ratio = 0.1
max = 1.0
DATASET = @datasets_gin.PileUnsupervisedDataset()
datasets_gin.PileUnsupervisedDataset:
MAX_SEQ_LEN = %MAX_SL
PERCORE_BATCH_SIZE = %PERCORE_BATCH_SIZE
## experiment == model + dataset
EXPERIMENT = @model_params_with_gin.Experiment()
model_params_with_gin.Experiment:
model = %MODEL
dataset = %DATASET
optimizer = %OPTIMIZER
scheduler = %SCHEDULER
train_script.run:
experiment_config = %EXPERIMENT
Command:
#! /bin/bash
set -x
PYTHONPATH=/pax/paxml:/pax/praxis python3 /pax/paxml/paxml/main.py \
--exp=tasks.lm.params.c4.PileSpmdAdam \
--gin_file="/pax/paxml/configs/gpt3_126_pp.gin" \
--tfds_data_dir="/pax/datasets" \
--vocab_path='/pax/vocab/c4_en_301_5Mexp2_spm.model' \
--pmap_use_tensorstore=True \
--job_log_dir=/logs/ \
--alsologtostderr
set +x
Error:
Traceback (most recent call last):
File "/pax/paxml/paxml/main.py", line 631, in <module>
app.run(main, flags_parser=_gin_flags_parser)
File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 303, in run
_run_main(main, args)
File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 251, in _run_main
sys.exit(main(argv))
File "/pax/paxml/paxml/main.py", line 588, in main
run_with_gin()
File "/usr/local/lib/python3.8/dist-packages/gin/config.py", line 1605, in gin_wrapper
utils.augment_exception_message_and_reraise(e, err_str)
File "/usr/local/lib/python3.8/dist-packages/gin/utils.py", line 41, in augment_exception_message_and_reraise
raise proxy.with_traceback(exception.__traceback__) from None
File "/usr/local/lib/python3.8/dist-packages/gin/config.py", line 1582, in gin_wrapper
return fn(*new_args, **new_kwargs)
File "/pax/paxml/paxml/main.py", line 535, in run
run_experiment(
File "/pax/paxml/paxml/main.py", line 290, in run_experiment
train.train_and_evaluate(
File "/pax/paxml/paxml/train.py", line 271, in train_and_evaluate
train_and_evaluate_spmd_model(task_p, train_input_p, job_log_dir,
File "/pax/paxml/paxml/train.py", line 851, in train_and_evaluate_spmd_model
vars_weight_params = jax_task.model.abstract_init_with_metadata(
File "/usr/local/lib/python3.8/dist-packages/flax/linen/transforms.py", line 1320, in wrapped_fn
return jax.named_call(class_fn, name=full_name)(self, *args, **kwargs)
File "/usr/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/usr/local/lib/python3.8/dist-packages/flax/linen/module.py", line 353, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/usr/local/lib/python3.8/dist-packages/flax/linen/module.py", line 652, in _call_wrapped_method
y = fun(self, *args, **kwargs)
File "/pax/praxis/praxis/base_layer.py", line 1231, in abstract_init_with_metadata
variables_abstract = jax.eval_shape(init_fn, rngs)
File "/usr/local/lib/python3.8/dist-packages/jax/_src/api.py", line 3024, in eval_shape
out = pe.abstract_eval_fun(wrapped_fun.call_wrapped,
File "/usr/local/lib/python3.8/dist-packages/jax/interpreters/partial_eval.py", line 662, in abstract_eval_fun
_, avals_out, _ = trace_to_jaxpr_dynamic(
File "/usr/local/lib/python3.8/dist-packages/jax/_src/profiler.py", line 294, in wrapper
return func(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/jax/interpreters/partial_eval.py", line 1929, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/usr/local/lib/python3.8/dist-packages/jax/interpreters/partial_eval.py", line 1946, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/usr/local/lib/python3.8/dist-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/usr/local/lib/python3.8/dist-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/usr/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/pax/praxis/praxis/base_layer.py", line 1169, in force_init
jax.tree_map(force, val)
File "/pax/praxis/praxis/base_layer.py", line 1167, in force
v.force_init(*args)
File "/usr/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/pax/praxis/praxis/base_layer.py", line 1169, in force_init
jax.tree_map(force, val)
File "/pax/praxis/praxis/base_layer.py", line 1167, in force
v.force_init(*args)
File "/usr/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/pax/praxis/praxis/base_layer.py", line 1169, in force_init
jax.tree_map(force, val)
File "/pax/praxis/praxis/base_layer.py", line 1167, in force
v.force_init(*args)
File "/usr/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/pax/praxis/praxis/layers/pipeline.py", line 217, in force_init
body_init_fn(self.body, None)
File "/pax/praxis/praxis/layers/pipeline.py", line 162, in fn
model.force_init(None)
File "/usr/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/pax/praxis/praxis/base_layer.py", line 1169, in force_init
jax.tree_map(force, val)
File "/pax/praxis/praxis/base_layer.py", line 1167, in force
v.force_init(*args)
File "/usr/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
TypeError: force_init() takes 1 positional argument but 2 were given
In call to configurable 'run' (<function run at 0x7fedd131ab80>)
Would you have any suggestions on how to fix this?
Hi Abhinav,
There is likely an issue in init for nested pipeline/repeat layers. However, we currently don't recommend this setup: pipeline+repeat results in a two-level jax.scan, which can cause unexpected rematerialization (remat) behavior, and you may end up computing more than needed.
For the memory concern, you can alternatively experiment with checkpoint_policy:
praxis/layers/pipeline.py; line 112
checkpoint_policy: AutodiffCheckpointType = AutodiffCheckpointType.SAVE_ITERATION_INPUT
You can try a few different options. It's a tradeoff between the buffers at the stage boundaries and the buffers within each stage. E.g., SAVE_NOTHING has the smallest boundary buffers, but the largest in-stage buffers.
(You can ignore SAVE_ITERATION_INPUT, since all configs will be combined with the iteration input being saved.)
Also, please keep in mind that pipeline on XLA:GPU still has some TODOs in the compiler backend. There are some missing HLO optimizations to be added to the GPU compiler passes.
This is already fixed. We no longer have force_init.