GPT-J-6B Inference Demo notebook giving errors when cores_per_replica=1
batrasakshi opened this issue · comments
I am trying out the demo notebook without a TPU backend and changed the following in params:
"cores_per_replica": 1,
"per_replica_batch": 1,
While executing
network.state = read_ckpt_lowmem(network.state, "step_383500/", devices.shape[1], shards_out=cores_per_replica)
I get this error:
Incompatible checkpoints (1, 6300, 4096) vs (1, 50400, 4096)
Full error:
AssertionError Traceback (most recent call last)
File /opt/conda/miniconda3/lib/python3.8/site-packages/mesh_transformer/checkpoint.py:217, in read_ckpt_lowmem(pytree, dir, shards_in, shards_out, load_opt)
216 try:
--> 217 unsharded = _unshard()
218 except AssertionError:
File /opt/conda/miniconda3/lib/python3.8/site-packages/mesh_transformer/checkpoint.py:210, in read_ckpt_lowmem.<locals>._unshard()
208 unsharded.append(x)
--> 210 assert x.shape == old_flattened[device_index].shape, f"Incompatible checkpoints {x.shape} vs {old_flattened[device_index].shape}"
211 device_index += 1
AssertionError: Incompatible checkpoints (1, 6300, 4096) vs (1, 50400, 4096)
During handling of the above exception, another exception occurred:
AssertionError Traceback (most recent call last)
Input In [31], in <cell line: 1>()
----> 1 network.state = read_ckpt_lowmem(network.state, "step_383500/", devices.shape[1],shards_out=cores_per_replica)
File /opt/conda/miniconda3/lib/python3.8/site-packages/mesh_transformer/checkpoint.py:222, in read_ckpt_lowmem(pytree, dir, shards_in, shards_out, load_opt)
220 del pytree['opt_state']
221 old_flattened, structure = jax.tree_flatten(pytree)
--> 222 unsharded = _unshard()
224 loaded_pytree = jax.tree_unflatten(structure, unsharded)
226 if not load_opt:
File /opt/conda/miniconda3/lib/python3.8/site-packages/mesh_transformer/checkpoint.py:210, in read_ckpt_lowmem.<locals>._unshard()
207 x = reshard(x, old_flattened[device_index].shape)
208 unsharded.append(x)
--> 210 assert x.shape == old_flattened[device_index].shape, f"Incompatible checkpoints {x.shape} vs {old_flattened[device_index].shape}"
211 device_index += 1
213 print(f"read from disk/gcs in {time.time() - start:.06}s")
AssertionError: Incompatible checkpoints (1, 6300, 4096) vs (1, 50400, 4096)
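Update: it worked after specifying shards_in explicitly instead of passing devices.shape[1], which no longer matches the number of shards the checkpoint was saved with once cores_per_replica is set to 1.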
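
For anyone hitting the same assertion, here is a rough sketch of how the shard count could be read off the two shapes in the error message. It assumes the mismatch comes only from the checkpoint being split along one axis; `infer_shards_in` is a hypothetical helper written for this comment, not part of mesh_transformer.

```python
def infer_shards_in(shard_shape, full_shape):
    """Guess how many shards a checkpoint was saved with.

    Hypothetical helper (not part of mesh_transformer). Assumes the two
    shapes differ along exactly one axis and that the full size is an
    integer multiple of the per-shard size along that axis.
    """
    if len(shard_shape) != len(full_shape):
        raise ValueError("shapes must have the same number of dimensions")

    ratios = []
    for shard_dim, full_dim in zip(shard_shape, full_shape):
        if shard_dim == full_dim:
            continue
        if shard_dim == 0 or full_dim % shard_dim != 0:
            raise ValueError(f"{full_dim} is not a multiple of {shard_dim}")
        ratios.append(full_dim // shard_dim)

    if not ratios:
        return 1  # identical shapes: nothing was sharded
    if len(ratios) > 1:
        raise ValueError("shapes differ along more than one axis")
    return ratios[0]


# Shapes copied from the AssertionError above.
shards_in = infer_shards_in((1, 6300, 4096), (1, 50400, 4096))
print(shards_in)  # 8

# The value would then go in as the third positional argument (shards_in),
# in place of devices.shape[1]. Left commented out because it needs the
# notebook's network object and the downloaded checkpoint:
# network.state = read_ckpt_lowmem(network.state, "step_383500/", shards_in,
#                                  shards_out=cores_per_replica)
```

With the shapes from this issue the ratio is 50400 / 6300 = 8, so passing 8 as shards_in is probably what the fix amounts to here; treat that as an inference from the error message rather than a confirmed value.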