google / maxtext

A simple, performant and scalable Jax LLM!


Issues running decode example from README

MicPie opened this issue

Running

python3 MaxText/decode.py MaxText/configs/base.yml run_name=MY_JOB_NAME

from https://github.com/google/maxtext?tab=readme-ov-file#getting-started-local-development-for-single-host leads to the following error:

I0213 14:26:13.638493 140101499770880 logging_logger.py:49] Constructing tf.data.Dataset c4 for split validation, from gs://xyz_tpu/c4/en/3.0.1
Model path: assets/tokenizer
No existing checkpoints found, not restoring checkpoint.
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/MMP/maxtext/MaxText/decode.py", line 278, in <module>
    app.run(main)
  File "/home/MMP/.local/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/MMP/.local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/MMP/maxtext/MaxText/decode.py", line 275, in main
    decode_loop(pyconfig.config)
  File "/home/MMP/maxtext/MaxText/decode.py", line 191, in decode_loop
    kv_cache_annotations = max_utils.get_kv_cache_annotations(model, config, rng, mesh)
  File "/home/MMP/maxtext/MaxText/max_utils.py", line 539, in get_kv_cache_annotations
    abstract_state = jax.eval_shape(init_kv_cache_partial)
  File "/home/MMP/maxtext/MaxText/max_utils.py", line 531, in init_kv_cache
    model_vars = model.init({'params': rng, 'dropout': rng, 'aqt': rng},
  File "/home/MMP/maxtext/MaxText/layers/models.py", line 349, in __call__
    logits = self.decoder(
  File "/home/MMP/maxtext/MaxText/layers/models.py", line 241, in __call__
    y, _ = nn.scan(
  File "/home/MMP/.local/lib/python3.10/site-packages/flax/core/axes_scan.py", line 148, in scan_fn
    _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
  File "/home/MMP/.local/lib/python3.10/site-packages/flax/core/axes_scan.py", line 120, in body_fn
    broadcast_out, c, ys = fn(broadcast_in, c, *xs)
  File "/home/MMP/maxtext/MaxText/layers/models.py", line 93, in __call__
    attention_lnx = attention_layer(
  File "/home/MMP/maxtext/MaxText/layers/attentions.py", line 849, in __call__
    out = attention_op(query, key, value, decoder_segment_ids, model_mode)
  File "/home/MMP/maxtext/MaxText/layers/attentions.py", line 634, in __call__
    prefill_unnormalized_output, prefill_exponentials_max, prefill_exponentials_sum = self.apply_attention(
  File "/home/MMP/maxtext/MaxText/layers/attentions.py", line 169, in apply_attention
    return self.tpu_flash_attention(query, key, value, decoder_segment_ids), None, None
  File "/home/MMP/maxtext/MaxText/layers/attentions.py", line 242, in tpu_flash_attention
    x = wrap_flash_attention(query, key, value, decoder_segment_ids)
  File "/home/MMP/maxtext/MaxText/layers/attentions.py", line 235, in wrap_flash_attention
    return jax.vmap(splash_kernel)(query,key,value, segment_ids = decoder_segment_ids)
  File "/home/MMP/.local/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py", line 2303, in __call__
    return _splash_attention(
  File "/home/MMP/.local/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py", line 2269, in _splash_attention
    return _splash_attention_custom(
  File "/home/MMP/.local/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py", line 1176, in _splash_attention_custom
    return _splash_attention_forward(  # pytype: disable=wrong-arg-types
  File "/home/MMP/.local/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py", line 913, in _splash_attention_forward
    raise ValueError(f"{bkv=} must be a multiple of {NUM_LANES}.")
ValueError: bkv=64 must be a multiple of 128.
2024-02-13 14:26:15.940961: I external/xla/xla/pjrt/distributed/client.cc:134] Distributed task shutdown initiated.
2024-02-13 14:26:15.941302: I external/tsl/tsl/distributed_runtime/coordination/coordination_service.cc:1193] Shutdown barrier in coordination service has passed.
2024-02-13 14:26:15.941335: I external/tsl/tsl/distributed_runtime/coordination/coordination_service.cc:684] /job:jax_worker/replica:0/task:0 has disconnected from coordination service.
2024-02-13 14:26:15.941496: I external/xla/xla/pjrt/distributed/client.cc:136] Distributed task shutdown result: OK
2024-02-13 14:26:15.941510: I external/tsl/tsl/distributed_runtime/preemption/preemption_sync_manager.cc:166] Cancelled call to retrieve preemption notice. This is expected upon program shutdown.
2024-02-13 14:26:15.941699: I external/xla/xla/pjrt/distributed/service.cc:118] Jax service shutting down
2024-02-13 14:26:15.942328: I external/tsl/tsl/distributed_runtime/preemption/preemption_sync_manager.cc:139] Preemption sync protocol cancelled by notifier: CANCELLED: Preemption notifier is being deleted.. This is expected during program shutdown.
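
For context, the check that fails is the block-size validation in JAX's splash-attention Pallas kernel: the key/value block size bkv must be a multiple of NUM_LANES (128 on TPU), and the decode path here ends up with bkv=64. A minimal sketch of that guard, paraphrased from the traceback above (not the kernel's actual code):

def check_block_kv(bkv: int) -> None:
    # Simplified paraphrase of the guard raised at splash_attention_kernel.py line 913 above.
    NUM_LANES = 128  # lane width the TPU kernel assumes
    if bkv % NUM_LANES != 0:
        raise ValueError(f"{bkv=} must be a multiple of {NUM_LANES}.")

check_block_kv(64)  # reproduces: ValueError: bkv=64 must be a multiple of 128.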

Before that, I had only run the following to set everything up:

git clone https://github.com/google/maxtext.git
bash setup.sh
bash download_dataset.sh tpu-cluster gs://xyz_tpu

I also added the correct dataset_path to MaxText/configs/base.yml, and I'm running this on a v5p-8.
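
For reference, that edit is just the dataset_path key in MaxText/configs/base.yml; with the bucket from the log above (substitute your own), it would look roughly like:

dataset_path: "gs://xyz_tpu"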

Sorry -- set attention=dot_product.
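
In other words, the failing splash-attention kernel can be bypassed by selecting dot-product attention. Config overrides are passed on the command line the same way as run_name, so the working invocation should look like:

python3 MaxText/decode.py MaxText/configs/base.yml run_name=MY_JOB_NAME attention=dot_product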