Error while running `to_hf_weights.py`: `ValueError: cannot reshape array of size 25804800 into shape (1,4096,50400)`
danyaljj opened this issue
```
$ python3 to_hf_weights.py --input-ckpt gs://danielk-files/gpt-j-checkpoints_slim/step_318500 --config configs/6B_roto_256.json --output-path gs://danielk-files/danielk-files/gpt-j-checkpoints_slim_hf/step_318500 --dtype fp32
to_hf_weights.py:101: UserWarning: WARNING: Dtype support other than fp16 is Experimental. Make sure to check weights after conversion to make sure dtype information is retained.
  warnings.warn(
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
/Users/danielk/opt/anaconda3/envs/jax_py38/lib/python3.8/site-packages/jax/experimental/maps.py:412: UserWarning: xmap is an experimental feature and probably has bugs!
  warn("xmap is an experimental feature and probably has bugs!")
key shape (1, 2)
in shape (1, 2048)
dp 1
mp 1
Total parameters: 6050886880
Reading and transforming layers/shards. This may take a while.
Reading/Transforming Layers: 0%|▋ | 1/287 [00:26<2:04:31, 26.12s/it]to_hf_weights.py:384: UserWarning: The use of `x.T` on tensors of dimension other than 2 to reverse their shape is deprecated and it will throw an error in a future release. Consider `x.mT` to transpose batches of matricesor `x.permute(*torch.arange(x.ndim - 1, -1, -1))` to reverse the dimensions of a tensor. (Triggered internally at /Users/distiller/project/pytorch/aten/src/ATen/native/TensorShape.cpp:2318.)
  x = torch.tensor(x.squeeze(0), dtype=torch_dtype).T
Reading/Transforming Layers: 1%|█▉ | 3/287 [00:26<41:19, 8.73s/it]
Traceback (most recent call last):
  File "to_hf_weights.py", line 488, in <module>
    save_sharded_to_hf_format(input_ckpt, params, output_path, np_dtype, torch_dtype)
  File "to_hf_weights.py", line 466, in save_sharded_to_hf_format
    save_pytree_as_hf(
  File "to_hf_weights.py", line 382, in save_pytree_as_hf
    x = unshard_leave(x, leave_name, old_shape, np_dtype=np_dtype)
  File "to_hf_weights.py", line 324, in unshard_leave
    x = reshard(
  File "to_hf_weights.py", line 244, in reshard
    out = np.reshape(x, old_shape)
  File "<__array_function__ internals>", line 5, in reshape
  File "/Users/danielk/opt/anaconda3/envs/jax_py38/lib/python3.8/site-packages/numpy/core/fromnumeric.py", line 299, in reshape
    return _wrapfunc(a, 'reshape', newshape, order=order)
  File "/Users/danielk/opt/anaconda3/envs/jax_py38/lib/python3.8/site-packages/numpy/core/fromnumeric.py", line 58, in _wrapfunc
    return bound(*args, **kwds)
ValueError: cannot reshape array of size 25804800 into shape (1,4096,50400)
```
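A quick check of the numbers in the error: shape `(1, 4096, 50400)` needs 206,438,400 elements, exactly 8 times the 25,804,800 that were present, and 25,804,800 / 4096 = 6,300 = 50,400 / 8 (4096 and 50400 match GPT-J's hidden size and padded vocabulary size). One possible reading, which is an inference and not something confirmed in this issue, is that only one of eight model-parallel shards ended up in the array before `reshard` called `np.reshape`. The arithmetic as a small Python sketch, with the values copied from the traceback (`size_ratio` is a hypothetical helper, not part of `to_hf_weights.py`):

```python
from math import prod


def size_ratio(actual_size: int, target_shape: tuple) -> float:
    """How many times more elements the target shape needs than the array has."""
    return prod(target_shape) / actual_size


# Values copied from the ValueError in the traceback.
actual_size = 25_804_800
target_shape = (1, 4096, 50400)

print(f"target needs {prod(target_shape):,} elements")        # target needs 206,438,400 elements
print(f"ratio: {size_ratio(actual_size, target_shape)}")       # ratio: 8.0
print(f"width of the slice present: {actual_size // 4096}")    # width of the slice present: 6300
```

A whole-number ratio that matches a plausible shard count is what makes the missing-shard reading attractive; a non-integer ratio would point elsewhere.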
Here is my environment:
```
$ pip3 list
Package                  Version
------------------------ ---------
absl-py                  0.15.0
aiohttp                  3.8.1
aiohttp-cors             0.7.0
aioredis                 2.0.1
aiosignal                1.2.0
astunparse               1.6.3
async-timeout            4.0.2
attrs                    21.4.0
blessings                1.7
cachetools               4.2.4
certifi                  2021.10.8
charset-normalizer       2.0.12
chex                     0.1.1
clang                    5.0
click                    8.0.4
cloudpickle              1.3.0
colorama                 0.4.4
colorful                 0.5.4
Deprecated               1.2.13
dm-haiku                 0.0.5
dm-tree                  0.1.6
einops                   0.3.2
filelock                 3.6.0
flatbuffers              1.12
frozenlist               1.3.0
gast                     0.4.0
google-api-core          1.31.5
google-auth              1.35.0
google-auth-oauthlib     0.4.6
google-cloud-core        1.7.2
google-cloud-storage     1.36.2
google-crc32c            1.3.0
google-pasta             0.2.0
google-resumable-media   1.3.3
googleapis-common-protos 1.56.0
gpustat                  0.6.0
grpcio                   1.44.0
h5py                     3.1.0
huggingface-hub          0.4.0
idna                     3.3
importlib-metadata       4.11.3
importlib-resources      5.4.0
jax                      0.2.12
jaxlib                   0.1.68
jmp                      0.0.2
joblib                   1.1.0
jsonschema               4.4.0
keras                    2.6.0
Keras-Preprocessing      1.1.2
Markdown                 3.3.6
msgpack                  1.0.3
multidict                6.0.2
numpy                    1.19.5
nvidia-ml-py3            7.352.0
oauthlib                 3.2.0
opencensus               0.8.0
opencensus-context       0.1.2
opt-einsum               3.3.0
optax                    0.0.9
packaging                21.3
pathy                    0.6.1
pip                      21.2.4
prometheus-client        0.13.1
protobuf                 3.19.4
psutil                   5.9.0
py-spy                   0.3.11
pyasn1                   0.4.8
pyasn1-modules           0.2.8
pydantic                 1.9.0
pyparsing                3.0.7
pyrsistent               0.18.1
pytz                     2022.1
PyYAML                   6.0
ray                      1.4.1
redis                    4.1.4
regex                    2022.3.15
requests                 2.27.1
requests-oauthlib        1.3.1
rsa                      4.8
sacremoses               0.0.49
scipy                    1.8.0
setuptools               58.0.4
six                      1.15.0
smart-open               5.2.1
tabulate                 0.8.9
tensorboard              2.6.0
tensorboard-data-server  0.6.1
tensorboard-plugin-wit   1.8.1
tensorflow-cpu           2.6.3
tensorflow-estimator     2.6.0
termcolor                1.1.0
tokenizers               0.11.6
toolz                    0.11.2
torch                    1.11.0
tqdm                     4.63.0
transformers             4.17.0
typer                    0.4.0
typing-extensions        3.10.0.2
urllib3                  1.26.9
Werkzeug                 2.0.3
wheel                    0.37.1
wrapt                    1.12.1
yarl                     1.7.2
zipp                     3.7.0
```
and
```
$ python3 --version
Python 3.8.1
```
OS: macOS Catalina (v10.15.7)
Update: the conversion worked after I switched to a Linux machine.
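Unrelated to the failure, the log also shows a PyTorch deprecation warning triggered at line 384 of the script (`x = torch.tensor(x.squeeze(0), dtype=torch_dtype).T`): using `.T` to reverse the dimensions of a tensor that is not 2-D is deprecated. Below is a minimal sketch of a warning-free equivalent that uses the `permute` approach the warning itself suggests, assuming the intent is to reverse every dimension; `reverse_dims` is a hypothetical helper name, not something in `to_hf_weights.py`:

```python
import torch


def reverse_dims(x: torch.Tensor) -> torch.Tensor:
    """Reverse all dimensions of x, as the deprecated `.T` did for non-2-D tensors.

    For a 2-D tensor this is an ordinary transpose; for a 1-D tensor (e.g. a bias)
    it returns the same values unchanged, as a view.
    """
    return x.permute(*reversed(range(x.ndim)))


print(reverse_dims(torch.zeros(4096)).shape)      # torch.Size([4096])
print(reverse_dims(torch.zeros(3, 5)).shape)      # torch.Size([5, 3])
print(reverse_dims(torch.zeros(2, 3, 4)).shape)   # torch.Size([4, 3, 2])
```

In PyTorch 1.11 the warning only announces a future change; `.T` still reversed the dimensions at the time, so by itself it does not alter the converted values.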