tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow

Home Page: https://www.tensorflow.org/probability/

tf.vectorized_map not compatible with LinearGaussianStateSpaceModel forward_filter?

Qiustander opened this issue

Hi all, I have a batch of data and want to run a Kalman filter on each observation set. My data has shape observations = (batch, num_time_lens, feature_dim), so I use tf.vectorized_map for parallel computation, together with TFP's official Kalman filter implementation (LinearGaussianStateSpaceModel.forward_filter).

However, tf.vectorized_map works fine with the tf.function-wrapped version of the Kalman filter, but fails when XLA compilation (jit_compile=True) is enabled. A reproducible example (adapted from the TFP docstrings):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
ndims = 2
step_std = 1.0
noise_std = 5.0
model = tfd.LinearGaussianStateSpaceModel(
    num_timesteps=100,
    transition_matrix=tf.linalg.LinearOperatorIdentity(ndims),
    transition_noise=tfd.MultivariateNormalDiag(
        scale_diag=step_std**2 * tf.ones([ndims])),
    observation_matrix=tf.linalg.LinearOperatorIdentity(ndims),
    observation_noise=tfd.MultivariateNormalDiag(
        scale_diag=noise_std**2 * tf.ones([ndims])),
    initial_state_prior=tfd.MultivariateNormalDiag(
        scale_diag=tf.ones([ndims])))

"""
Generate data 
"""
x = model.sample(10) # Sample from the prior on sequences of observations.

def kalman_filter_wrapper(input):
    _, filtered_means, filtered_covs, _, _, _, _ = model.forward_filter(input)
    return filtered_means

@tf.function(jit_compile=True)
def run_sim():
    means = tf.vectorized_map(kalman_filter_wrapper, x)
    return means

d = run_sim()

The error is:

2023-11-06 09:30:27.757506: W tensorflow/core/framework/op_kernel.cc:1828] OP_REQUIRES failed at xla_ops.cc:503 : INVALID_ARGUMENT: Detected unsupported operations when trying to compile graph __inference_run_sim_14763[_XlaMustCompile=true,config_proto=3175580994766145631,executor_type=11160318154034397263] on XLA_CPU_JIT: TensorListReserve (No registered 'TensorListReserve' OpKernel for XLA_CPU_JIT devices compatible with node {{function_node __inference_while_fn_14694}}{{node while_init/TensorArrayV2_11}}
	 (OpKernel was found, but attributes didn't match) Requested Attributes: element_dtype=DT_VARIANT, shape_type=DT_INT32){{function_node __inference_while_fn_14694}}{{node while_init/TensorArrayV2_11}}

What is the TensorListReserve operation? Is there a workaround? Thanks

Does running model.forward_filter(x) not work? Most TFP distributions are built to natively vectorize and broadcast across parameters.

Hi, thanks for your reply. It did work for batched input. But I am curious about the TensorListReserve operation, which does not appear anywhere in the source code. Could you explain where it comes from? Thanks
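For what it's worth: TensorListReserve is not an op you call directly. It is the graph-level op that tf.TensorArray lowers to in graph mode (TF2 TensorArrays are backed by variant-dtype TensorLists), so it appears in the trace whenever a loop accumulates values, e.g. the while loop inside forward_filter, even though it occurs nowhere in user code. A minimal sketch that makes the op visible in a traced graph:

```python
import tensorflow as tf

@tf.function
def accumulate():
    # A TensorArray written inside a loop. In graph mode this lowers to
    # TensorList ops: TensorListReserve to allocate the list, then
    # TensorListSetItem for each write.
    ta = tf.TensorArray(tf.float32, size=3)
    for i in tf.range(3):
        ta = ta.write(i, tf.cast(i, tf.float32) ** 2)
    return ta.stack()

graph = accumulate.get_concrete_function().graph
op_types = {op.type for op in graph.get_operations()}
print('TensorListReserve' in op_types)  # True
print(accumulate().numpy())             # [0. 1. 4.]
```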