tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow

Home Page: https://www.tensorflow.org/probability/

Sample from a partially known TensorShape inside the train_step function of a keras subclassed model

claCase opened this issue

I'm trying to sample a random vector from a tfd distribution. When the sampling call runs inside train_step, which is invoked by the fit function of a Keras subclassed model, the batch dimension of the input shape is dropped and reported as None. Is it possible to get around this problem? According to the docs of the Distribution class, the shape of the sample must be statically known. The issue does not occur with the tf.keras.backend.random_normal function. I think this is similar to #425, where keras.Input is used.

The following sample code reproduces the issue.

import tensorflow as tf 
import tensorflow_probability.python.distributions as tfd 

class cl(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()

    def build(self, input_shape):
        super().build(input_shape)
        self.normal = tfd.Normal(loc=[1.] * input_shape[-1], scale=[1.] * input_shape[-1])

    def call(self, inputs, *args):
        batch_shape = inputs.shape[0]
        return self.normal.sample(tf.TensorShape((batch_shape, 2)))


class cm(tf.keras.models.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.normal = cl()

    def call(self, inputs, *args):
        print(f"tf.shape: {tf.shape(inputs)}, inputs.shape:{inputs.shape}")
        sample = self.normal(inputs)
        return sample

    def train_step(self, data):
        B = tf.shape(data)[0]
        B2 = data.shape[0]
        rnd = tf.keras.backend.random_normal(shape=(B, 2))
        print(f"random normal from tf.keras.backend shape: {rnd.shape}")
        ta = tf.TensorArray(dtype=tf.float32, size=100, element_shape=tf.TensorShape((B2, 1)))
        sample = self(data)  # Error
        return {"loss":1}

inputs = tf.random.normal(shape=(100, 10, 2))
c0 = cm()
c0.compile("adam")
c0.fit(tf.constant(inputs),epochs=1)

Running it outputs:

random normal from tf.keras.backend shape: (None, 2)
tf.shape: Tensor("cm_35/Shape:0", shape=(3,), dtype=int32), inputs.shape:(None, 10, 2)

ValueError: Cannot convert a partially known TensorShape to a Tensor: (None, 2)
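The difference between the two shapes printed above can be reproduced in isolation. The sketch below is only an illustration, not part of the original report: it uses a hypothetical function, batch_sizes, traced with an input signature whose batch dimension is None, which is assumed here to mimic how fit traces train_step. The static shape (inputs.shape) has no batch size at trace time, while tf.shape(inputs) yields a tensor that is resolved when the function actually runs.

```python
import tensorflow as tf


# Hypothetical helper: the None in the signature stands in for the
# unknown batch dimension seen inside train_step.
@tf.function(input_signature=[tf.TensorSpec(shape=(None, 10, 2), dtype=tf.float32)])
def batch_sizes(inputs):
    static_batch = inputs.shape[0]       # Python value, None while tracing
    dynamic_batch = tf.shape(inputs)[0]  # scalar int32 tensor, known at run time
    return tf.constant(static_batch is None), dynamic_batch


static_is_unknown, dynamic_batch = batch_sizes(tf.zeros((100, 10, 2)))
print(bool(static_is_unknown), int(dynamic_batch))  # True 100
```

Building a tf.TensorShape from the static value therefore produces the partially known shape (None, 2) named in the error message.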

I solved the issue by passing the batch_size parameter to the fit function:

c0.fit(tf.constant(inputs),epochs=1, batch_size=10)