kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku

GPT-J-6B Inference Demo notebook giving errors when cores_per_replica=1

batrasakshi opened this issue

I am trying out the demo notebook without a TPU backend and updated these values in params:

       "cores_per_replica": 1,
       "per_replica_batch": 1,

While executing
network.state = read_ckpt_lowmem(network.state, "step_383500/", devices.shape[1],shards_out=cores_per_replica)
I am getting this error:
Incompatible checkpoints (1, 6300, 4096) vs (1, 50400, 4096)
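
The two shapes in that message differ only in their second dimension, and the larger is an exact multiple of the smaller. That hints that one shard was loaded where several were expected. Below is a minimal sketch of the arithmetic; `infer_shard_count` is a hypothetical helper written for illustration and is not part of mesh_transformer.

```python
from math import prod


def infer_shard_count(loaded_shape, expected_shape):
    """Estimate how many loaded pieces are needed to fill the expected shape.

    Hypothetical helper: compares total element counts and requires that
    the expected size is a whole multiple of the loaded size.
    """
    if len(loaded_shape) != len(expected_shape):
        raise ValueError("shapes have different ranks")
    loaded, expected = prod(loaded_shape), prod(expected_shape)
    if loaded == 0 or expected % loaded != 0:
        raise ValueError("expected size is not a multiple of the loaded size")
    return expected // loaded


# Shapes copied from the error message above.
shards = infer_shard_count((1, 6300, 4096), (1, 50400, 4096))
print(shards)  # 8
```

A result of 8 suggests the checkpoint on disk is split into 8 pieces, although the message alone does not confirm that.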

Full error:

AssertionError                            Traceback (most recent call last)
File /opt/conda/miniconda3/lib/python3.8/site-packages/mesh_transformer/checkpoint.py:217, in read_ckpt_lowmem(pytree, dir, shards_in, shards_out, load_opt)
    216 try:
--> 217     unsharded = _unshard()
    218 except AssertionError:

File /opt/conda/miniconda3/lib/python3.8/site-packages/mesh_transformer/checkpoint.py:210, in read_ckpt_lowmem.<locals>._unshard()
    208 unsharded.append(x)
--> 210 assert x.shape == old_flattened[device_index].shape, f"Incompatible checkpoints {x.shape} vs {old_flattened[device_index].shape}"
    211 device_index += 1

AssertionError: Incompatible checkpoints (1, 6300, 4096) vs (1, 50400, 4096)

During handling of the above exception, another exception occurred:

AssertionError                            Traceback (most recent call last)
Input In [31], in <cell line: 1>()
----> 1 network.state = read_ckpt_lowmem(network.state, "step_383500/", devices.shape[1],shards_out=cores_per_replica)

File /opt/conda/miniconda3/lib/python3.8/site-packages/mesh_transformer/checkpoint.py:222, in read_ckpt_lowmem(pytree, dir, shards_in, shards_out, load_opt)
    220     del pytree['opt_state']
    221     old_flattened, structure = jax.tree_flatten(pytree)
--> 222     unsharded = _unshard()
    224 loaded_pytree = jax.tree_unflatten(structure, unsharded)
    226 if not load_opt:

File /opt/conda/miniconda3/lib/python3.8/site-packages/mesh_transformer/checkpoint.py:210, in read_ckpt_lowmem.<locals>._unshard()
    207             x = reshard(x, old_flattened[device_index].shape)
    208         unsharded.append(x)
--> 210         assert x.shape == old_flattened[device_index].shape, f"Incompatible checkpoints {x.shape} vs {old_flattened[device_index].shape}"
    211         device_index += 1
    213 print(f"read from disk/gcs in {time.time() - start:.06}s")

AssertionError: Incompatible checkpoints (1, 6300, 4096) vs (1, 50400, 4096)

Resolved: it worked after specifying shards_in explicitly (the third argument, which the call above fills with devices.shape[1]).
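
To illustrate why the shards_in value matters, here is a simplified NumPy stand-in for the loading step. It assumes the checkpoint holds 8 shards of 6300 rows each, which is inferred from the shapes in the error and not confirmed by the thread. `fake_shard` and `load_unsharded` are hypothetical names, and the merge logic is a simplification that does not reproduce `read_ckpt_lowmem`. The width is shrunk from 4096 to 4 to keep the example light.

```python
import numpy as np

ROWS_PER_SHARD = 6300  # second dimension of the loaded array in the error
TOTAL_ROWS = 50400     # second dimension the network expected
D_MODEL = 4            # shrunk from 4096 for illustration only


def fake_shard(index):
    """Stand-in for one on-disk shard file (made-up contents)."""
    return np.full((ROWS_PER_SHARD, D_MODEL), index, dtype=np.float32)


def load_unsharded(shards_in, shards_out=1):
    """Read `shards_in` pieces and regroup them into `shards_out` pieces."""
    parts = [fake_shard(i) for i in range(shards_in)]
    merged = np.concatenate(parts, axis=0)
    return merged.reshape(shards_out, -1, D_MODEL)


expected_shape = (1, TOTAL_ROWS, D_MODEL)

# shards_in=1: only the first piece is read, so the shape check fails.
wrong = load_unsharded(shards_in=1)
print(wrong.shape, wrong.shape == expected_shape)  # (1, 6300, 4) False

# shards_in=8: all pieces are read and merged, so the shapes agree.
right = load_unsharded(shards_in=8)
print(right.shape, right.shape == expected_shape)  # (1, 50400, 4) True
```

Under the same assumption, the notebook call would become something like `read_ckpt_lowmem(network.state, "step_383500/", 8, shards_out=cores_per_replica)`, with the value 8 being an inference from the error message, not a documented setting.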