Unexpected dtype returned from jnp.outer with mixed input dtypes
BodeTobias opened this issue
Description
I have observed an unexpected output dtype from jax.numpy.outer when one input has an int dtype and the other a uint dtype. I would expect to get an int back.
Code
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
left = jnp.ones(3, dtype=jnp.uint)
right = jnp.ones(5, dtype=jnp.int_)
out = jnp.outer(left, right)
print('left: ', left.dtype)
print('right: ', right.dtype)
print('out: ', out.dtype)
left = jnp.ones(3, dtype=jnp.int_)
right = jnp.ones(5, dtype=jnp.int_)
out = jnp.outer(left, right)
print('left: ', left.dtype)
print('right: ', right.dtype)
print('out: ', out.dtype)
Output
left: uint64
right: int64
out: float64
left: int64
right: int64
out: int64
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.26
jaxlib: 0.4.26
numpy: 1.24.3
python: 3.11.9 (tags/v3.11.9:de54cf5, Apr 2 2024, 10:12:12) [MSC v.1938 64 bit (AMD64)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Windows', node='levy', release='10', version='10.0.22631', machine='AMD64')
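Sorry about that. I just found that this is intended behavior when mixing 64-bit signed and unsigned integers: no integer type can represent every value of both uint64 and int64, so JAX promotes the pair to a floating-point type (https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html).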
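A minimal sketch that reproduces the promotion directly and shows one possible workaround. It assumes x64 mode is enabled, as in the report. `jnp.promote_types` reports the dtype that a binary operation on the two dtypes would use, so the float64 result does not depend on `jnp.outer` itself. The explicit `astype` cast is an illustration, not an officially recommended fix.

```python
import jax
import jax.numpy as jnp

# 64-bit dtypes are only available when x64 mode is enabled (as in the report).
jax.config.update("jax_enable_x64", True)

# uint64 and int64 have no common integer supertype in JAX's promotion
# lattice, so the pair promotes to a float.
print(jnp.promote_types(jnp.uint64, jnp.int64))  # float64

# Mixing a narrower unsigned type with int64 stays integral.
print(jnp.promote_types(jnp.uint32, jnp.int64))  # int64

left = jnp.ones(3, dtype=jnp.uint64)
right = jnp.ones(5, dtype=jnp.int64)

# Possible workaround: cast so both operands share a dtype before the outer product.
out = jnp.outer(left.astype(jnp.int64), right)
print(out.dtype)  # int64
print(out.shape)  # (3, 5)
```

Note that casting uint64 to int64 wraps any value above 2**63 - 1. This workaround is only safe if the unsigned values are known to fit in int64.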