kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku

The latest update to PIP breaks installation

markriedl opened this issue

The latest update of pip seems to have changed dependency resolution in a way that makes mesh-transformer-jax hard to install. Installation now requires an older version of pip, specific versions of some packages such as transformers, and a forced installation order.

These are the steps I had to take:

  1. pip install pip==22.0
  2. Edit requirements.txt to pin these versions:
     • transformers==4.16.2
     • fastapi==0.73.0
     • uvicorn==0.17.1
  3. pip install jax==0.2.12 tensorflow==2.5.0 (as before, but this now has to come earlier)
  4. pip install -r mesh-transformer-jax/requirements.txt
  5. pip install mesh-transformer-jax/ jax==0.2.12 tensorflow==2.5.0
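
The steps above can be scripted so that the order is enforced and the requirements.txt edit is not forgotten. The following Python sketch is an illustration, not part of the repository. It assumes the repository is cloned into ./mesh-transformer-jax and that pip is invoked as "python -m pip" for the current interpreter. The helper names pin_requirements, install_commands and run are hypothetical.

```python
import re
import shlex
import subprocess
import sys
from pathlib import Path

# Pins from step 2 of the list above.
PINS = {
    "transformers": "4.16.2",
    "fastapi": "0.73.0",
    "uvicorn": "0.17.1",
}

# Package name at the start of a requirement line, e.g. "transformers" in "transformers>=4.0".
# Comment lines ("# ...") and option lines such as "-e ." do not match.
_NAME_RE = re.compile(r"^\s*([A-Za-z0-9][A-Za-z0-9._-]*)")


def pin_requirements(lines, pins):
    """Pin each package in `pins` to an exact version; append pins that were missing."""
    seen = set()
    out = []
    for line in lines:
        match = _NAME_RE.match(line)
        name = match.group(1).lower() if match else None
        if name in pins:
            # Replaces the whole line, so any extras or markers on it are dropped.
            out.append(f"{name}=={pins[name]}")
            seen.add(name)
        else:
            out.append(line)
    out.extend(f"{n}=={v}" for n, v in pins.items() if n not in seen)
    return out


def install_commands(repo="mesh-transformer-jax"):
    """pip commands for steps 1, 3, 4 and 5, in the order the thread requires."""
    pip = [sys.executable, "-m", "pip", "install"]  # assumption: target the current interpreter
    return [
        pip + ["pip==22.0"],
        pip + ["jax==0.2.12", "tensorflow==2.5.0"],
        pip + ["-r", f"{repo}/requirements.txt"],
        pip + [f"{repo}/", "jax==0.2.12", "tensorflow==2.5.0"],
    ]


def run(repo="mesh-transformer-jax", dry_run=True):
    """Run all five steps in order; with dry_run=True, only print them."""
    first, *rest = install_commands(repo)
    req = Path(repo) / "requirements.txt"

    def execute(cmd):
        print("$", shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)

    execute(first)  # step 1: downgrade pip
    pinned = ", ".join(f"{n}=={v}" for n, v in PINS.items())
    print(f"# step 2: pin {pinned} in {req}")
    if not dry_run:
        lines = req.read_text().splitlines()
        req.write_text("\n".join(pin_requirements(lines, PINS)) + "\n")
    for cmd in rest:  # steps 3-5: jax/tensorflow first, then requirements, then the repo
        execute(cmd)
```

Calling run() with the default dry_run=True only prints the plan; run(dry_run=False) performs the installs and rewrites requirements.txt in place. Pinning just these three packages follows the thread; it is not guaranteed to stop pip's resolver from searching for a long time.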

Same for me: installing the requirements takes forever.

Same here: even when following @markriedl's steps, pip still searches for hours. I know it would be more straightforward these days to just use the model at goose.ai, but I wanted to run some experiments with sampling and so on. It'd be great if this could be made to work again!

Perhaps I didn't transcribe the order of operations correctly. This notebook doesn't take long to run: https://colab.research.google.com/drive/17zvUhLcpjUKJdTRg00HYdGMEN3uoMy-M?usp=sharing

Thanks! It works again.

Thanks for this fix. Am I the only one for whom network.state = read_ckpt(network.state, "step_383500/", devices.shape[1]) causes a RAM overflow that terminates the session?

I'm getting AttributeError: module 'jax.random' has no attribute 'KeyArray' when running import optax in your example, @markriedl.

Sorry, @markriedl, I see the solution has been posted in #221.
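
The KeyArray error means that code imported by optax (directly or through a dependency such as chex) refers to jax.random.KeyArray, which the installed jax release does not define, so a mismatch between the jax and optax versions is one plausible cause. The actual fix is the one in #221. This Python sketch only reports installed versions against the pins used in this thread, using importlib.metadata from the standard library; the names EXPECTED, UNPINNED, report and print_report are hypothetical.

```python
from importlib import metadata

# Versions pinned in the installation steps of this thread.
EXPECTED = {
    "pip": "22.0",
    "jax": "0.2.12",
    "tensorflow": "2.5.0",
    "transformers": "4.16.2",
    "fastapi": "0.73.0",
    "uvicorn": "0.17.1",
}

# Packages the thread does not pin, listed only so their versions show up.
UNPINNED = ["jaxlib", "optax", "chex"]


def installed_version(name):
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def report(packages, expected=EXPECTED, lookup=installed_version):
    """Return (name, installed, expected, ok) rows.

    ok is True when the package is installed and either has no pin or matches it.
    """
    rows = []
    for name in packages:
        got = lookup(name)
        want = expected.get(name)
        ok = got is not None and (want is None or got == want)
        rows.append((name, got, want, ok))
    return rows


def print_report():
    """Print one line per package, flagging missing or mismatched versions."""
    for name, got, want, ok in report(list(EXPECTED) + UNPINNED):
        status = "ok" if ok else "CHECK"
        print(f"{status:5} {name:12} installed={got} expected={want or '-'}")
```

Running print_report() in the notebook before import optax shows what is actually installed; comparing the optax and chex versions with the installed jax version is a reasonable first check, not a diagnosis.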