google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/


Unexpected dtype returned from jnp.outer with mixed input dtypes

BodeTobias opened this issue

Description

I have observed an unexpected output dtype from the jax.numpy.outer function when one input has an unsigned integer dtype (uint) and the other a signed integer dtype (int). I would expect to get an integer dtype back.

Code

import jax
import jax.numpy as jnp
# Enable 64-bit dtypes before creating any arrays
jax.config.update("jax_enable_x64", True)

# Case 1: unsigned and signed integer inputs
left = jnp.ones(3, dtype=jnp.uint)
right = jnp.ones(5, dtype=jnp.int_)
out = jnp.outer(left, right)
print('left: ', left.dtype)
print('right: ', right.dtype)
print('out: ', out.dtype)

# Case 2: both inputs signed integers
left = jnp.ones(3, dtype=jnp.int_)
right = jnp.ones(5, dtype=jnp.int_)
out = jnp.outer(left, right)
print('left: ', left.dtype)
print('right: ', right.dtype)
print('out: ', out.dtype)

Output

left: uint64
right: int64
out: float64
left: int64
right: int64
out: int64
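One way to check that the float64 result comes from ordinary dtype promotion rather than from outer itself is to compare it with an explicit broadcasted product. This is a minimal sketch, assuming jnp.outer behaves like NumPy's outer (multiply the flattened inputs as a column times a row, after promoting them to a common dtype); explicit uint64/int64 dtypes are used instead of jnp.uint/jnp.int_ for clarity.

```python
import jax
import jax.numpy as jnp

# 64-bit dtypes are only available with x64 mode enabled.
jax.config.update("jax_enable_x64", True)

left = jnp.ones(3, dtype=jnp.uint64)
right = jnp.ones(5, dtype=jnp.int64)

# Outer product via the library function.
out_outer = jnp.outer(left, right)

# The same product as an explicit broadcast: column vector times row vector.
# The multiply applies the same uint64/int64 promotion as jnp.outer.
out_broadcast = left[:, None] * right[None, :]

print(out_outer.dtype, out_broadcast.dtype)
```

If both printed dtypes match, the surprising result is a property of mixing uint64 with int64, not of jnp.outer specifically.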

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.26
jaxlib: 0.4.26
numpy: 1.24.3
python: 3.11.9 (tags/v3.11.9:de54cf5, Apr 2 2024, 10:12:12) [MSC v.1938 64 bit (AMD64)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Windows', node='levy', release='10', version='10.0.22631', machine='AMD64')

Sorry about that. I just found that this is intended behavior when mixing 64-bit unsigned and signed integers (https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html).
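Under the linked promotion rules, no integer dtype can hold every uint64 value and every int64 value, so the pair is promoted to a floating-point dtype (NumPy does the same). The sketch below shows this with jnp.promote_types, plus a possible workaround: cast explicitly before calling outer. The workaround is an assumption about what the caller wants and is only safe if every unsigned value fits in int64 (values above 2**63 - 1 would not be represented correctly).

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

left = jnp.ones(3, dtype=jnp.uint64)
right = jnp.ones(5, dtype=jnp.int64)

# uint64 with int64 has no common integer type, so it promotes to a float.
print(jnp.promote_types(jnp.uint64, jnp.int64))

# A narrower unsigned type still has an integer result: uint32 fits in int64.
print(jnp.promote_types(jnp.uint32, jnp.int64))

# Workaround: cast to a common signed dtype first.
# Assumes every value in `left` is at most 2**63 - 1.
out = jnp.outer(left.astype(jnp.int64), right)
print(out.dtype)
```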