google / maxtext

A simple, performant and scalable Jax LLM!

Cannot do inference in float32

borisdayma opened this issue

If we try to perform inference in float32, we get the error:

AssertionError: Key and Value Dtypes should match

This error comes from this line.

The root cause is that the KV-cache dtype is set to jnp.int8 if quantize_kvcache else jnp.bfloat16, so it can never be jnp.float32.
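A minimal JAX sketch of the mismatch, for illustration only. The function names are hypothetical, not MaxText's. The failure mode is an assumption: one of key/value is read back from the unquantized cache and the other is not. The "fixed" selection is also only one possible approach. It keeps int8 for quantized caches and otherwise reuses the activation dtype.

```python
import jax.numpy as jnp


def cache_dtype_current(quantize_kvcache):
    # Selection as described in the issue: int8 or bfloat16, never float32.
    return jnp.dtype(jnp.int8) if quantize_kvcache else jnp.dtype(jnp.bfloat16)


def cache_dtype_fixed(quantize_kvcache, dtype):
    # Hypothetical fix: without quantization, keep the cache in the activation dtype.
    return jnp.dtype(jnp.int8) if quantize_kvcache else jnp.dtype(dtype)


def kv_dtypes_match(key, value):
    # The condition checked by the reported assertion.
    return key.dtype == value.dtype


dtype = jnp.float32  # activations configured as float32
key = jnp.ones((1, 4, 2, 8), dtype=dtype)
value = jnp.ones((1, 4, 2, 8), dtype=dtype)

# Assumed failure mode: the key passes through the (unquantized) cache, the value does not.
key_current = key.astype(cache_dtype_current(False))
key_fixed = key.astype(cache_dtype_fixed(False, dtype))

print(kv_dtypes_match(key_current, value))  # False: bfloat16 vs float32
print(kv_dtypes_match(key_fixed, value))    # True: both float32
```

With the bfloat16 default, both selections agree, which would explain why the problem only shows up once activations are switched to float32.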

What are you setting that triggers this? (Activations to float32?)

Yes, it's the dtype setting:

dtype: "bfloat16"