tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow

Home Page: https://www.tensorflow.org/probability/


jax.dtypes.prng_key gives `AttributeError: module 'jax.dtypes' has no attribute 'prng_key'`

Joshuaalbert opened this issue

I'm receiving this error:

value = 0.9189385332046727, dtype = <class 'numpy.float32'>

    def _default_convert_to_tensor(value, dtype=None):
      """Default tensor conversion function for array, bool, int, float, and complex."""
      if JAX_MODE:
        # TODO(b/223267515): We shouldn't need to specialize here.
        if hasattr(value, 'dtype') and jax.dtypes.issubdtype(
>           value.dtype, jax.dtypes.prng_key
        ):
E       AttributeError: module 'jax.dtypes' has no attribute 'prng_key'

Versions in use when the error occurred:
jax==0.4.8
jaxlib==0.4.7
tensorflow-probability==0.22.0
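A minimal sketch, assuming you want to detect the problem before importing TFP's JAX substrate: the traceback shows `_default_convert_to_tensor` referencing `jax.dtypes.prng_key`, so checking for that attribute tells you whether the installed JAX exposes what this code path expects. The helpers `has_prng_key_dtype` and `load_jax` are hypothetical, not part of JAX or TFP.

```python
import importlib
import types
from typing import Optional


def has_prng_key_dtype(jax_module: Optional[types.ModuleType]) -> bool:
    """Return True if ``jax_module.dtypes.prng_key`` exists."""
    if jax_module is None:
        # JAX is not installed, so the attribute cannot exist.
        return False
    dtypes = getattr(jax_module, "dtypes", None)
    return dtypes is not None and hasattr(dtypes, "prng_key")


def load_jax() -> Optional[types.ModuleType]:
    """Import JAX if it is installed, otherwise return None."""
    try:
        return importlib.import_module("jax")
    except ImportError:
        return None


if __name__ == "__main__":
    if has_prng_key_dtype(load_jax()):
        print("jax.dtypes.prng_key is available.")
    else:
        print("Installed JAX lacks jax.dtypes.prng_key; "
              "consider upgrading jax/jaxlib before using TFP's JAX substrate.")
```

Feature detection avoids hard-coding the exact JAX release that introduced the attribute.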

Upgrading to the following versions resolved the problem:

jax==0.4.26
jaxlib==0.4.26
tensorflow-probability==0.24.0
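A second sketch compares the installed packages against the combination reported as working here. `REPORTED_WORKING`, `parse_release`, `is_older`, and `check_installed` are hypothetical helpers; the reference versions are only those listed in this issue, not an official compatibility matrix, and being older than them does not by itself prove the error will occur. The regex reads only the leading numeric release segment; `packaging.version.Version` handles the full version grammar if you need it.

```python
import importlib.metadata
import re
from typing import Tuple

# Versions the issue author reported as working together
# (not an official compatibility matrix).
REPORTED_WORKING = {
    "jax": "0.4.26",
    "jaxlib": "0.4.26",
    "tensorflow-probability": "0.24.0",
}


def parse_release(version: str) -> Tuple[int, ...]:
    """Extract the leading numeric release, e.g. '0.4.26rc1' -> (0, 4, 26)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def is_older(installed: str, reference: str) -> bool:
    """Return True if ``installed`` is strictly older than ``reference``."""
    a, b = parse_release(installed), parse_release(reference)
    # Pad with zeros so that "0.4" compares equal to "0.4.0".
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a < b


def check_installed() -> None:
    """Print how each installed package compares to the reported versions."""
    for dist, reference in REPORTED_WORKING.items():
        try:
            installed = importlib.metadata.version(dist)
        except importlib.metadata.PackageNotFoundError:
            print(f"{dist}: not installed")
            continue
        status = "older than" if is_older(installed, reference) else "at or above"
        print(f"{dist}=={installed} is {status} the reported working {reference}")


if __name__ == "__main__":
    check_installed()
```

Run as a script, it prints one line per package; the failing combination above (jax 0.4.8, tensorflow-probability 0.22.0) would be flagged as older.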