kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku

Error while running `to_hf_weights.py`: `ValueError: cannot reshape array of size 25804800 into shape (1,4096,50400)`

danyaljj opened this issue

$ python3 to_hf_weights.py --input-ckpt gs://danielk-files/gpt-j-checkpoints_slim/step_318500 --config configs/6B_roto_256.json --output-path gs://danielk-files/danielk-files/gpt-j-checkpoints_slim_hf/step_318500 --dtype fp32
to_hf_weights.py:101: UserWarning: WARNING: Dtype support other than fp16 is Experimental. Make sure to check weights after conversion to make sure dtype information is retained.
  warnings.warn(
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
/Users/danielk/opt/anaconda3/envs/jax_py38/lib/python3.8/site-packages/jax/experimental/maps.py:412: UserWarning: xmap is an experimental feature and probably has bugs!
  warn("xmap is an experimental feature and probably has bugs!")
key shape (1, 2)
in shape (1, 2048)
dp 1
mp 1
Total parameters: 6050886880
Reading and transforming layers/shards. This may take a while.
Reading/Transforming Layers:   0%|▋   | 1/287 [00:26<2:04:31, 26.12s/it]
to_hf_weights.py:384: UserWarning: The use of `x.T` on tensors of dimension other than 2 to reverse their shape is deprecated and it will throw an error in a future release. Consider `x.mT` to transpose batches of matricesor `x.permute(*torch.arange(x.ndim - 1, -1, -1))` to reverse the dimensions of a tensor. (Triggered internally at  /Users/distiller/project/pytorch/aten/src/ATen/native/TensorShape.cpp:2318.)
  x = torch.tensor(x.squeeze(0), dtype=torch_dtype).T
Reading/Transforming Layers:   1%|█▉   | 3/287 [00:26<41:19,  8.73s/it]
Traceback (most recent call last):
  File "to_hf_weights.py", line 488, in <module>
    save_sharded_to_hf_format(input_ckpt, params, output_path, np_dtype, torch_dtype)
  File "to_hf_weights.py", line 466, in save_sharded_to_hf_format
    save_pytree_as_hf(
  File "to_hf_weights.py", line 382, in save_pytree_as_hf
    x = unshard_leave(x, leave_name, old_shape, np_dtype=np_dtype)
  File "to_hf_weights.py", line 324, in unshard_leave
    x = reshard(
  File "to_hf_weights.py", line 244, in reshard
    out = np.reshape(x, old_shape)
  File "<__array_function__ internals>", line 5, in reshape
  File "/Users/danielk/opt/anaconda3/envs/jax_py38/lib/python3.8/site-packages/numpy/core/fromnumeric.py", line 299, in reshape
    return _wrapfunc(a, 'reshape', newshape, order=order)
  File "/Users/danielk/opt/anaconda3/envs/jax_py38/lib/python3.8/site-packages/numpy/core/fromnumeric.py", line 58, in _wrapfunc
    return bound(*args, **kwds)
ValueError: cannot reshape array of size 25804800 into shape (1,4096,50400)
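
A quick check of the numbers in the error message, as a sketch rather than a confirmed diagnosis: the target shape `(1, 4096, 50400)` needs 206438400 elements, and the array being reshaped has 25804800, which is exactly one eighth of that. The helper name `missing_shard_factor` below is hypothetical and is not part of `to_hf_weights.py`. Reading the factor of 8 as "only one shard's worth of data was loaded" is an assumption; the log alone does not confirm the cause.

```python
import math


def missing_shard_factor(actual_size, target_shape):
    """Return target_elements / actual_size if it is a whole number, else None.

    Hypothetical helper for illustration only; not part of to_hf_weights.py.
    """
    target_elements = math.prod(target_shape)
    if actual_size <= 0 or target_elements % actual_size != 0:
        return None
    return target_elements // actual_size


# Values copied from the ValueError in the traceback above.
target_shape = (1, 4096, 50400)
actual_size = 25804800

print(math.prod(target_shape))                          # 206438400
print(missing_shard_factor(actual_size, target_shape))  # 8
```

A whole-number factor like this may point to the array holding a fraction of the full tensor (for example, a single shard) rather than to corrupted data, so comparing the number of shard files actually read against the number the config expects could be a reasonable next step.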

Here is my environment:

$ pip3 list 
Package                  Version
------------------------ ---------
absl-py                  0.15.0
aiohttp                  3.8.1
aiohttp-cors             0.7.0
aioredis                 2.0.1
aiosignal                1.2.0
astunparse               1.6.3
async-timeout            4.0.2
attrs                    21.4.0
blessings                1.7
cachetools               4.2.4
certifi                  2021.10.8
charset-normalizer       2.0.12
chex                     0.1.1
clang                    5.0
click                    8.0.4
cloudpickle              1.3.0
colorama                 0.4.4
colorful                 0.5.4
Deprecated               1.2.13
dm-haiku                 0.0.5
dm-tree                  0.1.6
einops                   0.3.2
filelock                 3.6.0
flatbuffers              1.12
frozenlist               1.3.0
gast                     0.4.0
google-api-core          1.31.5
google-auth              1.35.0
google-auth-oauthlib     0.4.6
google-cloud-core        1.7.2
google-cloud-storage     1.36.2
google-crc32c            1.3.0
google-pasta             0.2.0
google-resumable-media   1.3.3
googleapis-common-protos 1.56.0
gpustat                  0.6.0
grpcio                   1.44.0
h5py                     3.1.0
huggingface-hub          0.4.0
idna                     3.3
importlib-metadata       4.11.3
importlib-resources      5.4.0
jax                      0.2.12
jaxlib                   0.1.68
jmp                      0.0.2
joblib                   1.1.0
jsonschema               4.4.0
keras                    2.6.0
Keras-Preprocessing      1.1.2
Markdown                 3.3.6
msgpack                  1.0.3
multidict                6.0.2
numpy                    1.19.5
nvidia-ml-py3            7.352.0
oauthlib                 3.2.0
opencensus               0.8.0
opencensus-context       0.1.2
opt-einsum               3.3.0
optax                    0.0.9
packaging                21.3
pathy                    0.6.1
pip                      21.2.4
prometheus-client        0.13.1
protobuf                 3.19.4
psutil                   5.9.0
py-spy                   0.3.11
pyasn1                   0.4.8
pyasn1-modules           0.2.8
pydantic                 1.9.0
pyparsing                3.0.7
pyrsistent               0.18.1
pytz                     2022.1
PyYAML                   6.0
ray                      1.4.1
redis                    4.1.4
regex                    2022.3.15
requests                 2.27.1
requests-oauthlib        1.3.1
rsa                      4.8
sacremoses               0.0.49
scipy                    1.8.0
setuptools               58.0.4
six                      1.15.0
smart-open               5.2.1
tabulate                 0.8.9
tensorboard              2.6.0
tensorboard-data-server  0.6.1
tensorboard-plugin-wit   1.8.1
tensorflow-cpu           2.6.3
tensorflow-estimator     2.6.0
termcolor                1.1.0
tokenizers               0.11.6
toolz                    0.11.2
torch                    1.11.0
tqdm                     4.63.0
transformers             4.17.0
typer                    0.4.0
typing-extensions        3.10.0.2
urllib3                  1.26.9
Werkzeug                 2.0.3
wheel                    0.37.1
wrapt                    1.12.1
yarl                     1.7.2
zipp                     3.7.0

and

$ python3 --version
Python 3.8.1

Platform: macOS Catalina (v10.15.7)

Update: the conversion worked after I switched to a Linux machine.