kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku

`Incompatible checkpoints` error when running `slim_model.py`

danyaljj opened this issue

$ python3 slim_model.py --config configs/6B_roto_256.json
WARNING: Logging before InitGoogle() is written to STDERR
I0321 19:57:29.230169   28674 tpu_initializer_helper.cc:94] libtpu.so already in use by another process. Run "$ sudo lsof -w /dev/accel0" to figure out which process is using the TPU. Not attempting to load libtpu.so in this process.
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
jax devices: 1
jax runtime initialized in 0.0365555s
using checkpoint 383500
/home/danielk/.local/lib/python3.8/site-packages/jax/experimental/maps.py:412: UserWarning: xmap is an experimental feature and probably has bugs!
  warn("xmap is an experimental feature and probably has bugs!")
key shape (1, 2)
in shape (1, 2048)
dp 1
mp 1
Total parameters: 6050886880
read from disk/gcs in 129.839s
Traceback (most recent call last):
  File "/home/danielk/mesh-transformer-jax-master/mesh_transformer/checkpoint.py", line 164, in read_ckpt
    unsharded = _unshard(shards, old_flattened)
  File "/home/danielk/mesh-transformer-jax-master/mesh_transformer/checkpoint.py", line 161, in _unshard
    assert x.shape == old.shape, f"Incompatible checkpoints {x.shape} vs {old.shape}"
AssertionError: Incompatible checkpoints (1, 6300, 4096) vs (1, 50400, 4096)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "slim_model.py", line 69, in <module>
    network.state = read_ckpt(network.state, f"gs://{bucket}/{model_dir}/step_{ckpt_step}/", devices.shape[1])
  File "/home/danielk/mesh-transformer-jax-master/mesh_transformer/checkpoint.py", line 169, in read_ckpt
    unsharded = _unshard(shards, old_flattened)
  File "/home/danielk/mesh-transformer-jax-master/mesh_transformer/checkpoint.py", line 161, in _unshard
    assert x.shape == old.shape, f"Incompatible checkpoints {x.shape} vs {old.shape}"
AssertionError: Incompatible checkpoints (1,) vs (1, 4096)
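
For reference, the two shapes in the first assertion differ only on axis 1, and by an exact factor of 8 (50400 / 6300), while the log above reports `mp 1`. One possible reading, which is an assumption and not something confirmed in this thread, is that the checkpoint was written as 8 shards while this run built the network for a single device. The helper below is a hypothetical sketch (`shard_factor` is not part of mesh-transformer-jax) that only does this shape arithmetic:

```python
def shard_factor(found, expected):
    """Return (axis, factor) when the shapes differ on exactly one axis
    and `expected` is an integer multiple of `found` there; else None."""
    if len(found) != len(expected):
        return None
    # Indices of the axes on which the two shapes disagree.
    diffs = [i for i, (f, e) in enumerate(zip(found, expected)) if f != e]
    if len(diffs) != 1:
        return None
    axis = diffs[0]
    f, e = found[axis], expected[axis]
    if f == 0 or e % f != 0:
        return None
    return axis, e // f


# Shapes copied from the first AssertionError (x.shape vs old.shape).
print(shard_factor((1, 6300, 4096), (1, 50400, 4096)))  # (1, 8)

# Shapes from the second AssertionError have different ranks.
print(shard_factor((1,), (1, 4096)))  # None
```

The second assertion compares shapes of different rank, so this arithmetic does not explain it; the helper returns `None` for that pair.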

These are the shards that I downloaded from here: https://mystic.the-eye.eu/public/AI/GPT-J-6B/previous_checkpoints/step_384500/

Update: based on the following warning:

I0321 19:57:29.230169   28674 tpu_initializer_helper.cc:94] libtpu.so already in use by another process. Run "$ sudo lsof -w /dev/accel0" to figure out which process is using the TPU. Not attempting to load libtpu.so in this process.
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
jax devices: 1

it seems that another process (probably one of my earlier runs) was still holding the TPU, so this run fell back to CPU with a single device. I killed the process listed by `sudo lsof -w /dev/accel0`.
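
Because the CPU fallback only shows up as a warning, a guard before loading the checkpoint could make the run fail early instead. This is a minimal sketch under assumptions: `require_platform` and its parameters are hypothetical names made up for illustration, not part of mesh-transformer-jax or JAX, and the function takes plain platform strings so that it runs without JAX installed.

```python
def require_platform(platforms, wanted="tpu", min_count=1):
    """Raise if fewer than `min_count` devices of platform `wanted` are visible.

    `platforms` is a list of platform names such as ["tpu", "tpu"] or ["cpu"].
    Returns the number of matching devices when the check passes.
    """
    found = sum(1 for p in platforms if p == wanted)
    if found < min_count:
        raise RuntimeError(
            f"expected at least {min_count} {wanted} device(s), "
            f"found {found} in {platforms}"
        )
    return found


# The run above reported a single CPU device, so the guard trips.
try:
    require_platform(["cpu"])
except RuntimeError as err:
    print(err)
```

With JAX available, the list could be built as `[d.platform for d in jax.devices()]` and checked at the top of the script, before the checkpoint is read.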