kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku

tpu_driver0.1 is not initialized on Colab (cannot run inference with GPT-J on Colab) [Again]

mosmos6 opened this issue

Since yesterday, all versions of tpu_driver0.1 have been disabled on Colab, so GPT-J and all of its derivatives can no longer be loaded. Because tpu_driver0.1 is unavailable, the connection to the TPU over gRPC fails with a deadline-exceeded error.
This issue is already being discussed on the Colab side, but I'm also posting it here because GPT-J users are likely to look for it in this repository first.

tpu_driver_20221011 is accessible, but it crashes GPT-J and its derivatives. For some reason, tpu_driver0.1 is currently still available on Kaggle.
If anyone knows a solution, please share it here or on the Colab issue.

[Screenshot 2023-02-12 084246]

I'd like to add that it has not been explicitly disabled, at least not yet. The driver is still registered and is not reported as invalid; the TPU simply fails to initialize. This seems to affect all 0.1 drivers, while 0.2 and later do initialize.
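For context, GPT-J Colab notebooks commonly selected the TPU driver in two steps: an HTTP request to the TPU host asking for a specific driver version, then pointing JAX at the TPU's gRPC address. The sketch below uses a hypothetical `tpu_driver_setup` helper that only builds those two strings. The port 8475 and the `/requestversion/<version>` path reflect the pattern those notebooks used and should be treated as an assumption, not a documented Google API. This split matches the behavior described above: the version request can be accepted while the failure only appears later, when the TPU is initialized.

```python
import os
from typing import Tuple


def tpu_driver_setup(colab_tpu_addr: str, driver_version: str) -> Tuple[str, str]:
    """Return (version_request_url, jax_backend_target) for a Colab TPU.

    Hypothetical helper. colab_tpu_addr looks like "10.0.0.2:8470"
    (the format of the COLAB_TPU_ADDR environment variable on Colab).
    """
    host, _, port = colab_tpu_addr.partition(":")
    if not host or not port:
        raise ValueError(f"expected host:port, got {colab_tpu_addr!r}")
    # Assumed endpoint: the driver-version service on port 8475 of the TPU host.
    request_url = f"http://{host}:8475/requestversion/{driver_version}"
    # JAX's tpu_driver backend talks to the TPU over gRPC at host:port.
    backend_target = f"grpc://{colab_tpu_addr}"
    return request_url, backend_target


if __name__ == "__main__":
    # Falls back to a placeholder address when not running on Colab.
    addr = os.environ.get("COLAB_TPU_ADDR", "10.0.0.2:8470")
    url, target = tpu_driver_setup(addr, "tpu_driver_20221011")
    print(url)
    print(target)
```

In a notebook, one would then send the request to `url` (for example with `requests.post`) and configure JAX's tpu_driver backend with `target`; keeping the version string as a parameter makes it easy to try a 0.1 driver against a 0.2 or dated driver such as tpu_driver_20221011.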

I suggest we give Google some time during office hours to respond or fix this. If they do not intend to fix it, MTJ (mesh-transformer-jax) will have to be updated to work with newer drivers.

@henk717 You're right, I've modified the title. I'm glad the model still runs on Kaggle, but 20 hours per week is a little tight.
Besides, downloading dependencies has been unusually slow, and Colab as a whole has been behaving strangely since Saturday.
Since I'm paying for a Colab Pro subscription, I hope this is resolved soon. It's already Monday.

Are there any plans to update mesh-transformer-jax to work with tpu_driver0.2 on Colab?

For now, the issue seems to be resolved and GPT-J loads properly.

The same issue has occurred again, so I'm reopening this issue.

I'm facing the same issue; running it fails with:

RuntimeError: Deadline exceeded: Failed to connect to remote server at address: grpc://10.89.22.138:8470. Error from gRPC: Deadline Exceeded.
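When this error appears, one quick diagnostic is to check whether the gRPC address is reachable at all before handing it to JAX. The sketch below is an assumption-laden helper, not part of mesh-transformer-jax: the names `parse_grpc_target` and `tcp_reachable` are hypothetical, and the example address is taken from the error message above. Note that a successful TCP connection does not prove the driver will initialize; as noted earlier in the thread, the 0.1 drivers register but the TPU still fails to come up.

```python
import socket
from typing import Tuple
from urllib.parse import urlparse


def parse_grpc_target(target: str) -> Tuple[str, int]:
    """Split a 'grpc://host:port' target into (host, port). Hypothetical helper."""
    parsed = urlparse(target)
    if parsed.scheme != "grpc" or parsed.hostname is None or parsed.port is None:
        raise ValueError(f"not a grpc://host:port target: {target!r}")
    return parsed.hostname, parsed.port


def tcp_reachable(host: str, port: int, timeout: float = 5.0) -> bool:
    """Return True if a plain TCP connection to host:port succeeds within timeout.

    Refused connections and timeouts both raise OSError subclasses,
    so either one is reported as unreachable.
    """
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False


if __name__ == "__main__":
    host, port = parse_grpc_target("grpc://10.89.22.138:8470")
    if not tcp_reachable(host, port, timeout=3.0):
        print(f"TPU host {host}:{port} is not reachable; JAX would hit the deadline")
```

Failing fast with a short timeout gives a clearer signal than waiting for the gRPC deadline, but it only separates "host unreachable" from "host reachable, driver not initializing".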

I modified mesh-transformer-jax to work with JAX 0.3.25 and tpu_driver0.2 in #256, so I'm going to close this issue.