Cannot do inference in float32
borisdayma opened this issue
Boris Dayma commented
If we try to perform inference in float32, we get the error:
AssertionError: Key and Value Dtypes should match
This error comes from this line.
The origin of the error is that the cache dtype is set to jnp.int8 if quantize_kvcache else jnp.bfloat16, but never to jnp.float32.
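A minimal sketch of one possible fix, assuming the cache dtype can simply follow the configured activation dtype when quantization is off. The helpers `select_cache_dtype` and `check_kv_dtypes` are hypothetical names for illustration, not MaxText functions; only the `quantize_kvcache` flag and the assertion message are taken from the report above.

```python
import jax.numpy as jnp


def select_cache_dtype(activation_dtype, quantize_kvcache: bool):
    """Hypothetical helper: choose a KV cache dtype that matches the activations."""
    if quantize_kvcache:
        # A quantized cache is stored as int8 regardless of the activation dtype.
        return jnp.dtype(jnp.int8)
    # Otherwise follow the configured dtype (bfloat16, float32, ...)
    # instead of hard-coding bfloat16.
    return jnp.dtype(activation_dtype)


def check_kv_dtypes(key, value):
    """Stand-in for the check that raises the error quoted in this issue."""
    assert key.dtype == value.dtype, "Key and Value Dtypes should match"


# With float32 activations, the cache is now allocated as float32 too.
cache_dtype = select_cache_dtype(jnp.float32, quantize_kvcache=False)
cached_key = jnp.zeros((1, 4, 2, 8), dtype=cache_dtype)   # shape is arbitrary
new_value = jnp.ones((1, 4, 2, 8), dtype=jnp.float32)
check_kv_dtypes(cached_key, new_value)  # passes: both are float32
```

With the hard-coded bfloat16 fallback, the same check would fail for float32 activations, which matches the reported AssertionError.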
Rafi Witten commented
What are you setting that triggers this? (Activations to float32?)
Boris Dayma commented
Yes, it's the dtype:
maxtext/MaxText/configs/base.yml
Line 61 in f52e6f7