google / paxml

Pax is a JAX-based machine learning framework for training large-scale models. Pax allows for advanced and fully configurable experimentation and parallelization, and has demonstrated industry-leading model FLOP utilization rates.

Pipeline Parallelism: USE_REPEATED_LAYERS bug

abhinavgoel95 opened this issue · comments

Hello!

I am trying to implement a 126-million-parameter GPT-3 model with pipeline parallelism in PAXML. I notice that setting USE_REPEATED_LAYER = True speeds up compilation and also reduces the memory requirement. However, when I set USE_REPEATED_LAYER = True together with pipeline parallelism, I get the following error.

System:

8x NVIDIA A100-SXM 80 GB

Gin Configs:

from __gin__ import dynamic_registration

import __main__ as train_script
from paxml import gin_utils
from paxml.tasks.lm import model_params_with_gin
from paxml.tasks.lm.params import datasets_gin
from praxis import optimizers
from praxis import schedules
from praxis.layers import activations
from praxis.layers import repeats
from jax import numpy as jnp

MAX_SL=2048
SUMMARY_INTERVAL_STEPS=100
CHECKPOINT_EVERY_N_STEPS=1000
EVAL_INTERVAL_STEPS=100
MAX_STEPS=600000
NUM_STAGES = 4
ICI_MESH_SHAPE=[%NUM_STAGES, 1, 1, 2]
PERCORE_BATCH_SIZE = 2

MODEL = @model_params_with_gin.TransformerLmSpmdPipeline()
model_params_with_gin.TransformerLmSpmdPipeline:
  USE_REPEATED_LAYER = True
  MAX_SEQ_LEN = %MAX_SL
  NUM_LAYERS = 12
  NUM_HEADS = 12
  MODEL_DIMS = 768
  HIDDEN_DIMS = 3072
  DIMS_PER_HEAD = 64
  VOCAB_SIZE = 51200
  TRAINABLE_POSITION_EMB = True
  TRAINABLE_PE_MAX_SEQ_LEN = %MAX_SL
  ACTIVATION_CLS = @activations.GELU.HParams()
  PACKED_INPUT = True
  USE_BIAS = False
  MAX_STEPS=%MAX_STEPS
  INIT_STD = 0.023
  EVAL_INTERVAL_STEPS = 100
  NUM_STAGES = %NUM_STAGES
  NUM_MICROBATCHES = 1
  ICI_MESH_SHAPE = %ICI_MESH_SHAPE
  FPROP_DTYPE = @jnp.bfloat16
  SUMMARY_INTERVAL_STEPS=%SUMMARY_INTERVAL_STEPS
  CHECKPOINT_EVERY_N_STEPS=%CHECKPOINT_EVERY_N_STEPS
  EVAL_INTERVAL_STEPS=%EVAL_INTERVAL_STEPS

OPTIMIZER = @optimizers.Adam.HParams()
optimizers.Adam.HParams:
  beta1 = 0.9
  beta2 = 0.95
  learning_rate = 6e-4
  epsilon_root = 0.0
  epsilon = 1e-8
  weight_decay = 0.1
  clip_threshold = 1.0
  clip_gradient_norm_to_value = 5.0


SCHEDULER = @schedules.LinearRampupCosineDecay.HParams()
schedules.LinearRampupCosineDecay.HParams:
  warmup_steps = 636
  decay_start = 637
  decay_end = 500000
  min_ratio = 0.1
  max = 1.0

DATASET = @datasets_gin.PileUnsupervisedDataset()
datasets_gin.PileUnsupervisedDataset:
  MAX_SEQ_LEN = %MAX_SL
  PERCORE_BATCH_SIZE = %PERCORE_BATCH_SIZE

## experiment == model + dataset
EXPERIMENT = @model_params_with_gin.Experiment()
model_params_with_gin.Experiment:
  model = %MODEL
  dataset = %DATASET
  optimizer = %OPTIMIZER
  scheduler = %SCHEDULER
  
train_script.run:
  experiment_config = %EXPERIMENT

Command:

#! /bin/bash

set -x

PYTHONPATH=/pax/paxml:/pax/praxis python3 /pax/paxml/paxml/main.py \
    --exp=tasks.lm.params.c4.PileSpmdAdam \
    --gin_file="/pax/paxml/configs/gpt3_126_pp.gin" \
    --tfds_data_dir="/pax/datasets" \
    --vocab_path='/pax/vocab/c4_en_301_5Mexp2_spm.model' \
    --pmap_use_tensorstore=True \
    --job_log_dir=/logs/ \
    --alsologtostderr 

set +x

Error:

Traceback (most recent call last):
  File "/pax/paxml/paxml/main.py", line 631, in <module>
    app.run(main, flags_parser=_gin_flags_parser)
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 303, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "/pax/paxml/paxml/main.py", line 588, in main
    run_with_gin()
  File "/usr/local/lib/python3.8/dist-packages/gin/config.py", line 1605, in gin_wrapper
    utils.augment_exception_message_and_reraise(e, err_str)
  File "/usr/local/lib/python3.8/dist-packages/gin/utils.py", line 41, in augment_exception_message_and_reraise
    raise proxy.with_traceback(exception.__traceback__) from None
  File "/usr/local/lib/python3.8/dist-packages/gin/config.py", line 1582, in gin_wrapper
    return fn(*new_args, **new_kwargs)
  File "/pax/paxml/paxml/main.py", line 535, in run
    run_experiment(
  File "/pax/paxml/paxml/main.py", line 290, in run_experiment
    train.train_and_evaluate(
  File "/pax/paxml/paxml/train.py", line 271, in train_and_evaluate
    train_and_evaluate_spmd_model(task_p, train_input_p, job_log_dir,
  File "/pax/paxml/paxml/train.py", line 851, in train_and_evaluate_spmd_model
    vars_weight_params = jax_task.model.abstract_init_with_metadata(
  File "/usr/local/lib/python3.8/dist-packages/flax/linen/transforms.py", line 1320, in wrapped_fn
    return jax.named_call(class_fn, name=full_name)(self, *args, **kwargs)
  File "/usr/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/usr/local/lib/python3.8/dist-packages/flax/linen/module.py", line 353, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/flax/linen/module.py", line 652, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/pax/praxis/praxis/base_layer.py", line 1231, in abstract_init_with_metadata
    variables_abstract = jax.eval_shape(init_fn, rngs)
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/api.py", line 3024, in eval_shape
    out = pe.abstract_eval_fun(wrapped_fun.call_wrapped,
  File "/usr/local/lib/python3.8/dist-packages/jax/interpreters/partial_eval.py", line 662, in abstract_eval_fun
    _, avals_out, _ = trace_to_jaxpr_dynamic(
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/profiler.py", line 294, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/jax/interpreters/partial_eval.py", line 1929, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/usr/local/lib/python3.8/dist-packages/jax/interpreters/partial_eval.py", line 1946, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/usr/local/lib/python3.8/dist-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/usr/local/lib/python3.8/dist-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/usr/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/pax/praxis/praxis/base_layer.py", line 1169, in force_init
    jax.tree_map(force, val)
  File "/pax/praxis/praxis/base_layer.py", line 1167, in force
    v.force_init(*args)
  File "/usr/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/pax/praxis/praxis/base_layer.py", line 1169, in force_init
    jax.tree_map(force, val)
  File "/pax/praxis/praxis/base_layer.py", line 1167, in force
    v.force_init(*args)
  File "/usr/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/pax/praxis/praxis/base_layer.py", line 1169, in force_init
    jax.tree_map(force, val)
  File "/pax/praxis/praxis/base_layer.py", line 1167, in force
    v.force_init(*args)
  File "/usr/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/pax/praxis/praxis/layers/pipeline.py", line 217, in force_init
    body_init_fn(self.body, None)
  File "/pax/praxis/praxis/layers/pipeline.py", line 162, in fn
    model.force_init(None)
  File "/usr/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/pax/praxis/praxis/base_layer.py", line 1169, in force_init
    jax.tree_map(force, val)
  File "/pax/praxis/praxis/base_layer.py", line 1167, in force
    v.force_init(*args)
  File "/usr/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
TypeError: force_init() takes 1 positional argument but 2 were given
  In call to configurable 'run' (<function run at 0x7fedd131ab80>)

Would you have any suggestions on how to fix this?

Hi Abhinav,

There is likely an issue in the init of nested pipeline/repeat layers. However, we currently don't recommend this setup: pipeline + repeat results in a two-level jax.scan, which can cause unexpected remat behavior, and you may end up computing more than you need.
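To illustrate the concern, here is a minimal plain-JAX sketch (not Pax code; the function and variable names are made up): a repeat layer is roughly an inner jax.lax.scan over layers, a pipeline is roughly an outer jax.lax.scan over pipeline iterations, and applying remat to the outer body forces the whole inner scan to be re-run during the backward pass.

import jax
import jax.numpy as jnp

def repeated_layers(x, stacked_w):
    # Inner scan: "repeat" over layers that share stacked weights.
    def layer(carry, w_i):
        return jnp.tanh(carry @ w_i), None
    y, _ = jax.lax.scan(layer, x, stacked_w)
    return y

@jax.checkpoint  # remat: only the iteration inputs are kept for the backward pass
def stage_body(x, stacked_w):
    return repeated_layers(x, stacked_w)

def pipeline(x, stage_weights):
    # Outer scan: one step per pipeline iteration; the backward pass of each
    # step re-runs the entire inner scan above.
    def step(carry, w):
        return stage_body(carry, w), None
    y, _ = jax.lax.scan(step, x, stage_weights)
    return y

x = jnp.ones((8, 16))
stage_weights = 0.01 * jnp.ones((4, 3, 16, 16))  # 4 stages, 3 repeated layers each
grads = jax.grad(lambda w: pipeline(x, w).sum())(stage_weights)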

For the memory concern, you can alternatively play with checkpoint_policy (praxis/layers/pipeline.py, line 112):

checkpoint_policy: AutodiffCheckpointType = AutodiffCheckpointType.SAVE_ITERATION_INPUT

You can try a few different options. It's a tradeoff between the buffers at the stage boundaries and the buffers within each stage. E.g., SAVE_NOTHING has the smallest boundary buffers, but the largest in-stage buffers.

(You can ignore SAVE_ITERATION_INPUT, since every option is combined with the iteration input being saved.)
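As a rough illustration of that tradeoff in plain JAX (this is not the Pax enum; the mapping from AutodiffCheckpointType values to jax.checkpoint policies below is an assumption for illustration only), the policy controls which residuals are saved for the backward pass versus recomputed:

import jax
import jax.numpy as jnp

def stage(x, w):
    h = jnp.dot(x, w)   # the matmul output is a candidate residual to save or recompute
    return jnp.tanh(h)

# Save every residual: no recompute, largest saved buffers.
save_everything = jax.checkpoint(stage, policy=jax.checkpoint_policies.everything_saveable)
# Save nothing: smallest saved buffers, full recompute in the backward pass.
save_nothing = jax.checkpoint(stage, policy=jax.checkpoint_policies.nothing_saveable)
# Save only matmul outputs: a middle ground.
save_dots = jax.checkpoint(stage, policy=jax.checkpoint_policies.checkpoint_dots)

x = jnp.ones((8, 128))
w = 0.01 * jnp.ones((128, 128))
for fn in (save_everything, save_nothing, save_dots):
    g = jax.grad(lambda w: fn(x, w).sum())(w)  # same gradients, different memory/compute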

Also, please keep in mind that pipelining on XLA:GPU still has some TODOs in the compiler backend; some HLO optimizations still need to be added to the GPU compiler passes.

This is already fixed. We no longer have force_init.