Sample from a partially known TensorShape inside the train_step function of a keras subclassed model
claCase opened this issue · comments
I'm trying to sample a random vector from a tfd distribution. When I call the sampling function inside the train_step invoked by the fit function of a Keras subclassed model, the batch dimension of the input shape is automatically dropped and reported as None. Is it possible to get around this problem? I see in the docs of the Distribution class that the shape of the sample must be statically known. The issue is not present when using the tf.keras.backend.random_normal function. I think this is similar to issue #425, where keras.Input is used.
The following sample code reproduces the issue:
import tensorflow as tf
import tensorflow_probability.python.distributions as tfd
class cl(tf.keras.layers.Layer):
    """Layer that samples from a `tfd.Normal` created when the layer is built.

    Part of the reproduction: `call` derives the sample shape from the
    *static* shape of its input, which is what fails when the batch
    dimension is unknown.
    """

    def __init__(self):
        super().__init__()

    def build(self, input_shape):
        """Create the Normal with one loc/scale entry per last-axis element."""
        super().build(input_shape)
        last_dim = input_shape[-1]
        self.normal = tfd.Normal(
            loc=[1.] * last_dim,
            scale=[1.] * last_dim,
        )

    def call(self, inputs, *args):
        """Draw a sample with sample shape `(static batch size, 2)`."""
        # Static (Python-level) batch size. The reported output shows
        # `inputs.shape` as (None, 10, 2) under `fit`, so this is None there.
        static_batch = inputs.shape[0]
        requested_shape = tf.TensorShape((static_batch, 2))
        # With a None entry this is the call the report associates with
        # "Cannot convert a partially known TensorShape to a Tensor: (None, 2)".
        return self.normal.sample(requested_shape)
class cm(tf.keras.models.Model):
    """Subclassed Keras model that exercises the `cl` layer from a custom `train_step`."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.normal = cl()

    def call(self, inputs, *args):
        """Print the dynamic and static input shapes, then return the layer's sample."""
        print(f"tf.shape: {tf.shape(inputs)}, inputs.shape:{inputs.shape}")
        return self.normal(inputs)

    def train_step(self, data):
        """Custom training step contrasting dynamic-shape and static-shape sampling.

        Returns a constant `{"loss": 1}`; no gradients are computed here.
        """
        # Dynamic batch size (a tensor) versus static batch size (Python value;
        # the reported output shows the static batch dimension as None).
        dynamic_batch = tf.shape(data)[0]
        static_batch = data.shape[0]

        # Sampling through the Keras backend with the dynamic batch size; the
        # report prints its shape as (None, 2).
        noise = tf.keras.backend.random_normal(shape=(dynamic_batch, 2))
        print(f"random normal from tf.keras.backend shape: {noise.shape}")

        # TensorArray built with the static batch size in its element shape.
        # It is never read afterwards, and the report does not mark this line
        # as the failing one.
        ta = tf.TensorArray(
            dtype=tf.float32,
            size=100,
            element_shape=tf.TensorShape((static_batch, 1)),
        )

        sample = self(data)  # Error
        return {"loss": 1}
# Driver for the reproduction: random input tensor of shape (100, 10, 2).
inputs = tf.random.normal(shape=(100, 10, 2))
c0 = cm()
c0.compile("adam")
# No batch_size is passed to fit here; this is the call that produces the
# output and ValueError shown in the report.
c0.fit(tf.constant(inputs),epochs=1)
which outputs:
random normal from tf.keras.backend shape: (None, 2)
tf.shape: Tensor("cm_35/Shape:0", shape=(3,), dtype=int32), inputs.shape:(None, 10, 2)
ValueError: Cannot convert a partially known TensorShape to a Tensor: (None, 2)
I solved the issue by passing the batch_size parameter to the fit function:
c0.fit(tf.constant(inputs),epochs=1, batch_size=10)