tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow

Home Page: https://www.tensorflow.org/probability/

Can't jit PoissonLogNormalQuadratureCompound log_prob

GianmarcoCallegher opened this issue

If I try to jit the log_prob method of PoissonLogNormalQuadratureCompound using the JAX substrate:

from jax import jit
import tensorflow_probability.substrates.jax.distributions as tfd

jit(tfd.PoissonLogNormalQuadratureCompound(0.0, 1.0).log_prob)(1.)

I get the following error:

TypeError: Shapes must be 1D sequences of concrete values of integer type, got Traced<ShapedArray(int32[1])>with<DynamicJaxprTrace(level=1/0)>.

operation a:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    from line /var/folders/vr/x979b72d00jgztv5gxjgl5j00000gn/T/ipykernel_26756/1664260717.py:1 (<module>)
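The TypeError itself is raised by JAX, not TFP: it appears whenever an array shape is computed from a value that is only a tracer under jit (here, apparently an int32 value produced by convert_element_type somewhere inside log_prob). A minimal sketch reproducing the same class of failure in plain JAX follows; make_zeros is a hypothetical helper, not part of TFP. Marking the argument static with static_argnums is the usual JAX workaround when you control the function, but it does not by itself fix PoissonLogNormalQuadratureCompound, since the offending shape is computed inside the library.

```python
import jax
import jax.numpy as jnp


def make_zeros(n):
    # Hypothetical helper: the output *shape* depends on the *value* of n.
    return jnp.zeros(n)


# Under jit, n is an abstract tracer, so its value is unknown at trace time
# and cannot be used as a shape. JAX raises a TypeError (or a subclass of it).
try:
    jax.jit(make_zeros)(3)
    raised = False
    message = ""
except TypeError as err:
    raised = True
    message = str(err)

# Declaring n static makes it an ordinary Python int during tracing,
# so the shape is concrete and compilation succeeds.
static_zeros = jax.jit(make_zeros, static_argnums=0)(3)
print(raised, static_zeros.shape)
```

Note that each distinct static value triggers a fresh compilation, which is the trade-off of this workaround.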