kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku

GPT-J inference on TPU

airesearch38 opened this issue

Is it possible to use a TPU for inference?

The guys at NLPCloud.io told me that's what they're doing, but I have no idea how they're doing it...
First, I don't know how to support advanced parameters like end_sequence (so the model stops generating when it reaches a specific token) or repetition penalty (see the Hugging Face parameters for text generation); a sketch of the latter follows below.
Secondly, the TPU IPs seem to rotate on a regular basis and there's nothing you can do about it, so I'm not sure how to serve a TPU behind a REST API...
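
For reference, "repetition penalty" in the Hugging Face sense is usually the CTRL-style logit rescaling. A minimal numpy sketch; apply_repetition_penalty and its argument names are illustrative, not part of this repo:

import numpy as np

def apply_repetition_penalty(logits, generated_ids, penalty=1.2):
    # CTRL-style repetition penalty, as applied by Hugging Face's
    # generate(): make every token that already appears in the output
    # less likely to be sampled again.
    logits = np.array(logits, dtype=np.float32)
    for token_id in set(generated_ids):
        score = logits[token_id]
        # With penalty > 1, positive logits shrink and negative logits
        # grow more negative, so a seen token's probability always drops.
        logits[token_id] = score / penalty if score > 0 else score * penalty
    return logits

penalty=1.0 is a no-op; values around 1.1 to 1.3 are typical.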

Thanks in advance!

You might consider running "device_serve.py" on the TPU together with the Streamlit approach from the following repo:

https://github.com/vicgalle/gpt-j-api

Interesting, thanks for the suggestion!

If I understand the code correctly, stop_sequence does not stop generation early; it just truncates the output after the model has finished generating:

if stop_sequence is not None and stop_sequence in text:
    text = text.split(stop_sequence)[0] + stop_sequence

So generation takes the same time whether the stop_sequence token is reached or not.
Am I correct?
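
A true early stop would presumably check during decoding rather than after, something like this rough eager-mode sketch, where sample_next_token is a hypothetical stand-in for one forward-plus-sampling step. As far as I can tell the repo compiles generation for a fixed output length, so this shows the idea rather than a drop-in patch:

def generate_with_early_stop(prompt_ids, sample_next_token, tokenizer,
                             stop_sequence, max_tokens=512):
    generated = list(prompt_ids)
    text = ""
    for _ in range(max_tokens):
        # One model step: feed the ids generated so far, get one token.
        generated.append(sample_next_token(generated))
        text = tokenizer.decode(generated[len(prompt_ids):])
        if stop_sequence in text:
            # Bail out as soon as the sequence shows up, instead of
            # always paying for max_tokens and splitting afterwards.
            return text.split(stop_sequence)[0] + stop_sequence
    return text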

And I don't see a way to handle the fact that TPU IPs are regularly changing...

I was using Streamlit as a quick web app for testing model inference and found it convenient. The floating TPU IP is indeed a separate issue. As for stop_sequence, I can't comment because I haven't run into any issues with it yet. In short, "device_serve.py" works on TPU; it could be a starting point.
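
One way to live with the rotating address is to re-resolve the TPU's IP on each request from a small proxy in front of it, instead of hard-coding the endpoint. A sketch, assuming a TPU node named "my-tpu"; the port, path, and payload are placeholders to be matched to whatever your server actually exposes:

import subprocess
import requests

def current_tpu_ip(tpu_name, zone):
    # Ask gcloud for the node's current internal IP at request time, so a
    # new address (e.g. after the TPU is recreated) is picked up.
    out = subprocess.check_output([
        "gcloud", "compute", "tpus", "describe", tpu_name,
        "--zone", zone,
        "--format", "value(networkEndpoints[0].ipAddress)",
    ])
    return out.decode().strip()

def complete(prompt, tpu_name="my-tpu", zone="us-central1-f"):
    # Placeholder endpoint details: adjust the port, path, and JSON shape
    # to whatever device_serve.py serves in your deployment.
    ip = current_tpu_ip(tpu_name, zone)
    resp = requests.post(f"http://{ip}:5000/complete",
                         json={"context": prompt}, timeout=120)
    resp.raise_for_status()
    return resp.json()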