Can't jit PoissonLogNormalQuadratureCompound log_prob
GianmarcoCallegher opened this issue
Gianmarco Callegher commented
If I try to `jit` the `log_prob` method of `PoissonLogNormalQuadratureCompound`:
```python
from jax import jit
import tensorflow_probability.substrates.jax.distributions as tfd

jit(tfd.PoissonLogNormalQuadratureCompound(0.0, 1.0).log_prob)(1.)
```
I get the following error:
```
TypeError: Shapes must be 1D sequences of concrete values of integer type, got
Traced<ShapedArray(int32[1])>with<DynamicJaxprTrace(level=1/0)>.
operation a:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
from line /var/folders/vr/x979b72d00jgztv5gxjgl5j00000gn/T/ipykernel_26756/1664260717.py:1 (<module>)
operation a:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
from line /var/folders/vr/x979b72d00jgztv5gxjgl5j00000gn/T/ipykernel_26756/1664260717.py:1 (<module>)
```
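For context, this `TypeError` is JAX's generic shape check: under `jit`, every array shape has to be a concrete Python integer when the function is traced, and the message suggests that a shape inside `log_prob` is being built from a traced `int32` value. The sketch below reproduces the same class of error in plain JAX, without TFP; `make_zeros` is a hypothetical function used only for illustration. It also shows the usual caller-side remedy, `static_argnums`, which does not directly help in this issue because the traced value appears to be created inside the distribution rather than passed in as an argument.

```python
import jax
import jax.numpy as jnp


def make_zeros(n):
    # The output shape depends on n, so JAX needs n as a concrete Python int.
    return jnp.zeros(n)


# Under plain jit, n is traced: its value is unknown while the shape is built.
raised = False
try:
    jax.jit(make_zeros)(3)
except TypeError as err:  # JAX's shape/concretization errors subclass TypeError
    raised = True
    print(str(err).splitlines()[0])

# Declaring n static bakes its value into the trace (one compilation per value).
zeros_static = jax.jit(make_zeros, static_argnums=0)
print(zeros_static(3).shape)  # (3,)
```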
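If the goal is only a jittable log-probability, one possible stopgap is to evaluate the Poisson-LogNormal compound directly in `jax.numpy`, so that every shape is fixed before tracing. This is a minimal sketch under stated assumptions, not TFP's implementation: the helper `poisson_lognormal_log_prob` is hypothetical, it uses 32-point Gauss-Hermite quadrature (TFP's class uses its own quadrature scheme and size, so values will differ slightly from its `log_prob`), and it assumes scalar `loc` and `scale`.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln, logsumexp

# Gauss-Hermite nodes and weights for the weight function exp(-t**2).
# They are plain NumPy arrays built once, so their shape (32,) is a Python
# constant and never becomes a JAX tracer.
_NODES, _WEIGHTS = np.polynomial.hermite.hermgauss(32)
_LOG_WEIGHTS = np.log(_WEIGHTS) - 0.5 * np.log(np.pi)


def poisson_lognormal_log_prob(x, loc, scale):
    """Hypothetical helper: log p(x) where x ~ Poisson(exp(z)), z ~ Normal(loc, scale).

    Assumes scalar loc and scale; x may be a scalar or an array of counts.
    """
    x = jnp.asarray(x, dtype=jnp.float32)[..., None]  # trailing axis for nodes
    # Change of variables z = loc + sqrt(2) * scale * t turns the Normal
    # expectation into a Gauss-Hermite sum (the 1/sqrt(pi) is in _LOG_WEIGHTS).
    z = loc + jnp.sqrt(2.0) * scale * _NODES
    # Poisson log-pmf with log-rate z, evaluated at every node.
    log_pmf = x * z - jnp.exp(z) - gammaln(x + 1.0)
    # log sum_i w_i / sqrt(pi) * Poisson(x | exp(z_i)), computed stably.
    return logsumexp(log_pmf + _LOG_WEIGHTS, axis=-1)


log_prob = jax.jit(poisson_lognormal_log_prob)
print(log_prob(1.0, 0.0, 1.0))
```

Because the number of nodes is fixed when the module is imported, only the values of `x`, `loc`, and `scale` are traced, and `log_prob` can be called again with a different `loc` or `scale` without retracing.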