jax.dtypes.prng_key gives `AttributeError: module 'jax.dtypes' has no attribute 'prng_key'`
Joshuaalbert opened this issue
Joshua George Albert commented
Receiving this error from TensorFlow Probability's JAX backend:
```
value = 0.9189385332046727, dtype = <class 'numpy.float32'>

    def _default_convert_to_tensor(value, dtype=None):
      """Default tensor conversion function for array, bool, int, float, and complex."""
      if JAX_MODE:
        # TODO(b/223267515): We shouldn't need to specialize here.
        if hasattr(value, 'dtype') and jax.dtypes.issubdtype(
>           value.dtype, jax.dtypes.prng_key
        ):
E       AttributeError: module 'jax.dtypes' has no attribute 'prng_key'
```
Installed versions:
jax==0.4.8
jaxlib==0.4.7
tensorflow-probability==0.22.0
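The failing line reads `jax.dtypes.prng_key` directly, and this JAX release does not define that attribute. For code that has to tolerate such older releases, one defensive option is to feature-detect the attribute before using it in the `issubdtype` check. This is a sketch, not TFP's actual code or fix: `is_prng_key_array` is a hypothetical helper, and `dtypes_module` is meant to be `jax.dtypes` (it is a parameter only so the logic can be exercised without JAX installed).

```python
def is_prng_key_array(value, dtypes_module) -> bool:
    """Report whether `value` has a typed PRNG-key dtype without assuming
    that `dtypes_module.prng_key` exists.

    Hypothetical helper: pass `jax.dtypes` as `dtypes_module`. It is a
    parameter only so the logic can be exercised without JAX installed.
    """
    # Feature-detect instead of accessing jax.dtypes.prng_key directly;
    # releases such as jax 0.4.8 raise AttributeError on direct access.
    prng_key = getattr(dtypes_module, "prng_key", None)
    if prng_key is None:
        # Attribute missing: fall back to treating the value as an ordinary array.
        return False
    if not hasattr(value, "dtype"):
        # Python scalars and similar objects carry no dtype to inspect.
        return False
    return bool(dtypes_module.issubdtype(value.dtype, prng_key))
```

With a JAX release that provides `jax.dtypes.prng_key`, `is_prng_key_array(jax.random.key(0), jax.dtypes)` should return True and an ordinary float32 array should return False; on an older release the helper returns False instead of raising. Upgrading, as described in the follow-up comment, remains the simpler fix.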
Joshua George Albert commented
Upgrading to the following versions resolved the problem:
jax==0.4.26
jaxlib==0.4.26
tensorflow-probability==0.24.0
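If you hit the same error, a small script can compare your environment against the versions reported here. This is a hedged sketch: `REPORTED_WORKING` and the helper names are hypothetical, the thresholds are only the set that worked in this thread (not documented minimum versions), and the numeric parsing ignores pre-release and dev tags.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Versions reported in this thread as resolving the error.
# These are NOT official minimums, just the reporter's working set.
REPORTED_WORKING = {
    "jax": "0.4.26",
    "jaxlib": "0.4.26",
    "tensorflow-probability": "0.24.0",
}


def parse_version(text: str) -> tuple:
    """Turn '0.4.26', '0.4.26.dev0' or '2.16.0rc1' into a tuple of leading integers."""
    parts = []
    for piece in text.split(".")[:3]:
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def older_than_reported(installed: dict) -> list:
    """Return (package, installed, reported) for each package missing or below the reported set."""
    problems = []
    for name, reported in REPORTED_WORKING.items():
        have = installed.get(name)
        if have is None or parse_version(have) < parse_version(reported):
            problems.append((name, have, reported))
    return problems


def installed_versions() -> dict:
    """Read installed versions from package metadata; missing packages are omitted."""
    found = {}
    for name in REPORTED_WORKING:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            pass
    return found


if __name__ == "__main__":
    for name, have, reported in older_than_reported(installed_versions()):
        print(f"{name}: installed {have}, thread reports {reported} working")
```

Keeping the comparison in `older_than_reported` separate from the metadata lookup makes it easy to check against a hand-written dict, for example the failing set from the original report.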