kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku


Update the readme with required and recommended hardware list

sxiii opened this issue

Greetings developers :)

This repo is very useful, but it would be even more useful if you could add information about minimum hardware requirements as well as recommended specs. For instance, what if I use a single GPU with 32 GB of VRAM? Would it work? Would it require tweaking? Or would I be better off using 4x 8 GB GPUs instead?

Sorry if this is a stupid question, but I feel like minimum and recommended specs for both training and running the model (I'm mostly interested in running it) would be very useful.

P.S. I'm speaking here about "GPT-J-6B".

Thanks!

I have been able to get it to run locally on an older 24GB Tesla P40 and also in Colab Pro with a TPU (v2). It's my first experience with JAX, so it was a bit (pun intended) of a learning experience.
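(If you're going the Colab TPU route, this is roughly the setup incantation, assuming the jax version that was current when this repo was active; newer jax releases dropped the colab_tpu helper, so treat this as a sketch, not gospel:)

```python
# Run in a Colab notebook with a TPU runtime selected.
# setup_tpu() points jax at the Colab TPU driver so the
# TPU cores become visible as devices.
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

import jax
print(jax.devices())  # should list 8 TpuDevice entries on a v2-8
```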

I recently tried it on just a beefy CPU-only workstation (16 cores, 96 GB RAM), and the wait for a response to a single-sentence prompt was excruciating.
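(For anyone unsure what they're actually running on, a quick sanity check is to ask jax which backend it picked up; nothing repo-specific here, just plain jax:)

```python
import jax

# On a CPU-only workstation this reports "cpu" and CpuDevice entries,
# which explains multi-minute generation times for a 6B-parameter model.
print(jax.default_backend())  # "cpu", "gpu", or "tpu"
print(jax.devices())
```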

From the resharding example:

```
# This was tested with an RTX 3090, peak memory usage is approximately 22.4GB during inference, and 19GB when loading the model
# The following environment variables were also used: XLA_PYTHON_CLIENT_PREALLOCATE=false XLA_PYTHON_CLIENT_ALLOCATOR=platform
```
Can confirm this works on my 3090 :)
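(For completeness, a minimal sketch of how those environment variables get applied. The flags themselves are straight from the comment quoted above; the key detail, as I understand it, is that they must be set before jax is imported, otherwise jax preallocates most of the VRAM up front:)

```python
import os

# Must be set before importing jax/jaxlib, or they have no effect.
# PREALLOCATE=false stops jax from grabbing ~90% of VRAM at startup;
# ALLOCATOR=platform allocates on demand so the 6B weights can fit.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import jax  # import only after the flags are set
print(jax.devices())  # expect a single GpuDevice on e.g. a 3090
```

(Equivalently, set them on the shell command line before launching the script, as the resharding example does.)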

Ok, so I suppose the answer would be "anything with 24 GB of VRAM or more" then, with both the Nvidia RTX 3090 and the Nvidia Tesla P40 confirmed. Thank you guys :)