AI-Hypercomputer / maxtext

A simple, performant and scalable Jax LLM!



Create a user friendly inference demo

borisdayma opened this issue.

This is a feature request.

I like maxtext because it is very customizable and efficient for training.
The main issue I’m having is hacking together an inference function. The code is quite complex, so this is not straightforward.
The simple decode.py works, but it seems to be mainly an experimental development for streaming.

I think streaming will be really cool, but we would also benefit from an easy model.generate(input_ids, attention_mask, params) function:

  • it should allow prefill based on the length of input_ids (it is the user's responsibility to limit the number of distinct shapes to avoid recompilation)
  • it should allow batched input, with left padding to support different input lengths
  • it should be compilable with jit/pjit
  • it should support a few common sampling strategies: greedy, sampling (with temperature, top-k, top-p), and beam search
  • it should be usable without a separate engine/service, in case we want to make it part of a larger function that includes multiple models
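The host-side padding, position, and sampling pieces of the requested API could look roughly like the minimal JAX sketch below. It is not MaxText code: `left_pad_batch`, `left_pad_positions`, `sample_next_token` and the bucket sizes are hypothetical names chosen for illustration. Prompts are left-padded to one of a few bucket lengths so jit only sees a small number of shapes, position ids skip the padding, and sampling covers greedy, temperature, top-k and top-p. Beam search is left out.

```python
from functools import partial

import jax
import jax.numpy as jnp


def left_pad_batch(seqs, buckets=(16, 32, 64, 128), pad_id=0):
    """HYPOTHETICAL helper: left-pad prompts to the smallest bucket that fits the longest one.

    Limiting lengths to a few buckets keeps the number of compiled shapes small.
    Raises StopIteration if the longest prompt exceeds the largest bucket.
    """
    longest = max(len(s) for s in seqs)
    length = next(b for b in buckets if b >= longest)
    ids = [[pad_id] * (length - len(s)) + list(s) for s in seqs]
    mask = [[0] * (length - len(s)) + [1] * len(s) for s in seqs]
    return jnp.array(ids, dtype=jnp.int32), jnp.array(mask, dtype=jnp.int32)


def left_pad_positions(attention_mask):
    """Position ids for left-padded input: real tokens get 0, 1, 2, ...; pad slots get 0."""
    positions = jnp.cumsum(attention_mask, axis=-1) - 1
    return jnp.where(attention_mask > 0, positions, 0)


@partial(jax.jit, static_argnames=("temperature", "top_k", "top_p"))
def sample_next_token(logits, rng, temperature=1.0, top_k=0, top_p=1.0):
    """HYPOTHETICAL sampler: one token id per row of `logits` ([batch, vocab]).

    temperature == 0 means greedy; top_k == 0 and top_p == 1 disable those filters.
    """
    if temperature == 0.0:
        return jnp.argmax(logits, axis=-1)
    logits = logits / temperature
    if top_k > 0:
        # Mask everything below the k-th largest logit (ties may keep a few extra).
        kth_largest = jax.lax.top_k(logits, top_k)[0][..., -1:]
        logits = jnp.where(logits < kth_largest, -jnp.inf, logits)
    if top_p < 1.0:
        sorted_logits = jnp.sort(logits, axis=-1)[..., ::-1]  # descending
        sorted_probs = jax.nn.softmax(sorted_logits, axis=-1)
        # Keep a token if the probability mass strictly before it is < top_p,
        # so the most likely token is always kept.
        mass_before = jnp.cumsum(sorted_probs, axis=-1) - sorted_probs
        keep = mass_before < top_p
        cutoff = jnp.min(jnp.where(keep, sorted_logits, jnp.inf), axis=-1, keepdims=True)
        logits = jnp.where(logits < cutoff, -jnp.inf, logits)
    return jax.random.categorical(rng, logits, axis=-1)
```

Because `temperature`, `top_k` and `top_p` are static arguments, each distinct combination triggers one compilation, which fits the "few shapes / few settings" constraint in the list above.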

This PR looked interesting: #402
I think it was mainly for benchmarking, though: it didn’t stop once the entire batch had reached EOS, but it had a nice prefill functionality.
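A jit-compatible decode loop that stops as soon as every sequence in the batch has emitted EOS, instead of always running the full token budget, could be sketched with `jax.lax.while_loop`. This is a hypothetical sketch, not the code from #402 or a MaxText API: `toy_logits_fn` is a bigram lookup table standing in for the model (it ignores the attention mask), decoding is greedy for brevity, and a real implementation would prefill a KV cache once and then feed one token per step. Left padding is what puts the last prompt token at the same index for every row.

```python
from functools import partial

import jax
import jax.numpy as jnp


def toy_logits_fn(params, tokens, cur):
    """HYPOTHETICAL model stand-in: next-token logits depend only on the previous token.

    params: [vocab, vocab] table; returns [batch, vocab] logits for position `cur`.
    """
    return params[tokens[:, cur - 1]]


@partial(jax.jit, static_argnames=("max_new_tokens", "eos_id", "pad_id"))
def greedy_generate(params, input_ids, max_new_tokens, eos_id, pad_id=0):
    """Greedy decoding into a fixed-size buffer; exits early once all rows hit EOS."""
    batch, prompt_len = input_ids.shape
    total_len = prompt_len + max_new_tokens
    tokens = jnp.full((batch, total_len), pad_id, dtype=input_ids.dtype)
    tokens = tokens.at[:, :prompt_len].set(input_ids)
    done = jnp.zeros((batch,), dtype=bool)

    def cond(state):
        step, _, done = state
        # Stop when the budget is spent OR every sequence has finished.
        return (step < max_new_tokens) & ~jnp.all(done)

    def body(state):
        step, tokens, done = state
        cur = prompt_len + step  # index of the slot being filled
        logits = toy_logits_fn(params, tokens, cur)
        next_tok = jnp.argmax(logits, axis=-1).astype(tokens.dtype)
        next_tok = jnp.where(done, pad_id, next_tok)  # finished rows emit padding
        tokens = tokens.at[:, cur].set(next_tok)
        done = done | (next_tok == eos_id)
        return step + 1, tokens, done

    init = (jnp.array(0, dtype=jnp.int32), tokens, done)
    steps, tokens, _ = jax.lax.while_loop(cond, body, init)
    return tokens, steps
```

Finished rows keep being written with `pad_id`, so the output has a fixed shape `[batch, prompt_len + max_new_tokens]` (jit-friendly), while `steps` reports how many iterations actually ran.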