tf.vectorized_map not compatible with LinearGaussianStateSpaceModel forward_filter?
Qiustander opened this issue
Hi all, I have a batch of data and want to run a Kalman filter on each observation sequence. My observations have shape (batch, num_time_lens, feature_dim), so I use tf.vectorized_map to process them in parallel, together with TFP's official Kalman filter implementation (LinearGaussianStateSpaceModel.forward_filter).
tf.vectorized_map works fine with a tf.function-wrapped version of the Kalman filter, but fails when XLA compilation (jit_compile=True) is enabled. A reproducible example (adapted from the TFP docstrings):
```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

ndims = 2
step_std = 1.0
noise_std = 5.0
model = tfd.LinearGaussianStateSpaceModel(
    num_timesteps=100,
    transition_matrix=tf.linalg.LinearOperatorIdentity(ndims),
    transition_noise=tfd.MultivariateNormalDiag(
        scale_diag=step_std**2 * tf.ones([ndims])),
    observation_matrix=tf.linalg.LinearOperatorIdentity(ndims),
    observation_noise=tfd.MultivariateNormalDiag(
        scale_diag=noise_std**2 * tf.ones([ndims])),
    initial_state_prior=tfd.MultivariateNormalDiag(
        scale_diag=tf.ones([ndims])))

# Generate data: sample from the prior on sequences of observations.
x = model.sample(10)  # shape [10, 100, 2]

def kalman_filter_wrapper(observations):
    _, filtered_means, filtered_covs, _, _, _, _ = model.forward_filter(observations)
    return filtered_means

@tf.function(jit_compile=True)
def run_sim():
    means = tf.vectorized_map(kalman_filter_wrapper, x)
    return means

d = run_sim()
```
The error is:
```
2023-11-06 09:30:27.757506: W tensorflow/core/framework/op_kernel.cc:1828] OP_REQUIRES failed at xla_ops.cc:503 : INVALID_ARGUMENT: Detected unsupported operations when trying to compile graph __inference_run_sim_14763[_XlaMustCompile=true,config_proto=3175580994766145631,executor_type=11160318154034397263] on XLA_CPU_JIT: TensorListReserve (No registered 'TensorListReserve' OpKernel for XLA_CPU_JIT devices compatible with node {{function_node __inference_while_fn_14694}}{{node while_init/TensorArrayV2_11}}
(OpKernel was found, but attributes didn't match) Requested Attributes: element_dtype=DT_VARIANT, shape_type=DT_INT32){{function_node __inference_while_fn_14694}}{{node while_init/TensorArrayV2_11}}
```
What is the TensorListReserve operation? Is there any workaround? Thanks
Does running model.forward_filter(x) not work? Most TFP distributions are built to natively vectorize and broadcast across parameters.
Hi, thanks for your reply. It does work for batched input. But I am still curious about the TensorListReserve operation, which does not appear anywhere in the source code. Could you answer this question? Thanks
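On where TensorListReserve comes from: TFP does not call it directly. It is the low-level TensorFlow op that allocates a TensorList, and in TF2 a tf.TensorArray created inside a tf.function is implemented with TensorList ops (its node is named TensorArrayV2, matching while_init/TensorArrayV2_11 in the error). Loop constructs such as tf.scan accumulate per-step outputs in TensorArrays this way. The sketch below, using a hypothetical cumulative_sum function, shows the op appearing in a traced graph. The element_dtype=DT_VARIANT attribute in the error suggests a list whose elements are themselves lists, which is plausibly what tf.vectorized_map produces when it vectorizes such a loop; that reading of the error is an assumption, not confirmed against the TF sources.

```python
import tensorflow as tf

@tf.function
def cumulative_sum(values):
    # Hypothetical example: a tf.TensorArray used as a loop accumulator.
    n = tf.shape(values)[0]
    ta = tf.TensorArray(tf.float32, size=n)  # lowered to a TensorListReserve op
    total = tf.constant(0.0)
    for i in tf.range(n):  # AutoGraph converts this into a tf.while_loop
        total += values[i]
        ta = ta.write(i, total)
    return ta.stack()

concrete = cumulative_sum.get_concrete_function(
    tf.TensorSpec([None], tf.float32))
# Collect the op types in the outer traced graph.
op_types = {op.type for op in concrete.graph.get_operations()}
print("TensorListReserve" in op_types)
print(cumulative_sum(tf.constant([1.0, 2.0, 3.0])).numpy())
```

Under plain tf.function this runs fine; the op only becomes a problem when XLA has no kernel registered for the requested attributes, as in the error above.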