kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku

to_hf_weights script returns "Failed to allocate" error

guidomeijer opened this issue

I compressed the weights of a GPT-J-6B model using the slim_model script and am now trying to convert them to Hugging Face format. I'm running this on a v3-8 TPU, and so far everything has worked fine. However, when running the to_hf_weights script I get an out-of-memory error, which I find quite odd since the weights are already compressed and I'm running on a big TPU.

key shape (1, 2)
in shape (1, 2048)
dp 1
mp 1
2021-12-23 18:14:57.707572: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:1981] Execution of replica 0 failed: Resource exhausted: Failed to allocate request for 256.00MiB (268435456B) on device ordinal 0
Traceback (most recent call last):
  File "to_hf_weights.py", line 488, in <module>
    save_sharded_to_hf_format(input_ckpt, params, output_path, np_dtype, torch_dtype)
  File "to_hf_weights.py", line 464, in save_sharded_to_hf_format
    network = CausalTransformer(params_local)
  File "/home/guido/mesh-transformer-jax/mesh_transformer/transformer_shard.py", line 277, in __init__
    self.state = self.init_xmap(jnp.array(key.take(mp_per_host)), x)
  File "/home/guido/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 516, in fun_mapped
    out_flat = xmap_p.bind(
  File "/home/guido/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 652, in bind
    return core.call_bind(self, fun, *args, **params)  # type: ignore
  File "/home/guido/.local/lib/python3.8/site-packages/jax/core.py", line 1393, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/guido/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 655, in process
    return trace.process_xmap(self, fun, tracers, params)
  File "/home/guido/.local/lib/python3.8/site-packages/jax/core.py", line 600, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/guido/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 539, in xmap_impl
    return make_xmap_callable(fun, name, in_axes, out_axes_thunk, donated_invars, global_axis_sizes,
  File "/home/guido/.local/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1130, in execute_replicated
    out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
RuntimeError: Resource exhausted: Failed to allocate request for 256.00MiB (268435456B) on device ordinal 0

jax==0.2.12
jaxlib==0.1.67 (I tried 0.1.68 but got errors)
Any help would be greatly appreciated :)
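
For context on why a v3-8 can still OOM here: the log shows dp 1 / mp 1, so CausalTransformer is being initialized without model parallelism, i.e. the whole network lands on a single TPU core. A rough estimate (assuming GPT-J-6B's ~6.05B parameters get materialized in fp32, which the traceback suggests but doesn't prove) puts the weights alone above a single v3 core's 16 GiB of HBM:

    # Back-of-the-envelope check, assuming ~6.05B params in fp32 on one core
    n_params = 6.05e9        # approximate GPT-J-6B parameter count
    bytes_per_param = 4      # fp32
    print(f"{n_params * bytes_per_param / 2**30:.1f} GiB")  # ~22.5 GiB > 16 GiB HBM

Under that assumption, one more 256 MiB allocation failing on device ordinal 0 is exactly what you'd expect, regardless of how the checkpoint on disk was compressed.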

Hey there, I had the same problem today, also using a v3-8 TPU VM instance. Maybe not the most satisfying solution, but I solved it by passing the --cpu flag. Converting my fine-tuned model completed in 6 minutes on the CPU.

Not entirely sure why CPU isn't the default, but AFAIK there's no particular reason to do the conversion on a GPU/TPU in the first place.
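
If the --cpu flag isn't available in your checkout of to_hf_weights.py, a minimal sketch of the same idea (my assumption about what the flag does, not the script's documented behavior) is to pin JAX to the CPU backend before anything touches the TPU:

    import os
    os.environ["JAX_PLATFORM_NAME"] = "cpu"  # must be set before jax is first imported

    import jax
    print(jax.devices())  # expect CpuDevice(...) entries, no TPU cores

The host of a v3-8 TPU VM has a few hundred GiB of RAM, so the fp32 weights fit comfortably there, which is presumably why the conversion goes through on CPU.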