tensorflow / lattice

Lattice methods in TensorFlow


Many-batches predictions

matibilkis opened this issue

Hi,

When I try to get predictions from Lattice models on more than one batch of data at once, an error is raised. Predicting on several batches in a single call is a useful way to get predictions efficiently, and it works with basic neural-network Keras models;
find some examples in this colab.

As far as I can tell from the API docs and the source code, this is related to the inputs accepted by the PWL calibration layers, but I wonder if there is an easy workaround.

In particular, this piece of code captures what I would like to do (it raises an error when called on batched_inputs):


import numpy as np
import tensorflow as tf
import tensorflow_lattice as tfl


class LatticeModel(tf.keras.Model):
    def __init__(self, nodes=(2, 2), nkeypoints=100):
        super().__init__()
        # One PWL calibrator per input feature, applied in parallel
        self.combined_calibrators = tfl.layers.ParallelCombination()
        for ind in range(len(nodes)):
            # A lattice dimension of size n expects inputs in [0, n - 1]
            calibration_layer = tfl.layers.PWLCalibration(
                input_keypoints=np.linspace(0, 1, nkeypoints),
                output_min=0.0, output_max=nodes[ind] - 1.0)
            self.combined_calibrators.append(calibration_layer)
        self.lattice = tfl.layers.Lattice(lattice_sizes=list(nodes), interpolation="simplex")

    def call(self, x):
        rescaled = self.combined_calibrators(x)  # (B, 2) -> (B, 2), calibrated
        feat = self.lattice(rescaled)            # (B, 2) -> (B, 1)
        return feat


# Define some input data: 100 samples with 2 features each
x1 = np.random.randn(100, 1).astype(np.float32)
x2 = np.random.randn(100, 1).astype(np.float32)

inputs = tf.concat([x1, x2], axis=-1)  # shape (100, 2)

# Initialize our model and feed it a batch of size 100
model = LatticeModel()
model(inputs)  # shape (100, 1)

# Now we would like to efficiently predict the output of the lattice model
# on many batches of data at once (in this case 2).
# This call raises an error.
batched_inputs = np.random.randn(2, 100, 1)
model(batched_inputs)

Thanks a lot!
Matías.

The model you have constructed here expects an input of shape (B, 2), but the input you pass in your last call has shape (2, 100, 1). Maybe you meant to pass an array of shape (2, 100, 2), i.e. 2 batches, each of shape (100, 2). A general way to approach this is to flatten the batch dimensions with tf.reshape, call the model once, and reshape the predictions back:

num_batches, batch_size, input_dim = 2, 100, 2
batched_inputs = np.random.randn(num_batches, batch_size, input_dim).astype(np.float32)
reshaped_batched_inputs = tf.reshape(batched_inputs, [-1, input_dim])  # shape (num_batches * batch_size, input_dim)
flat_preds = model(reshaped_batched_inputs)  # shape (num_batches * batch_size, 1)
preds = tf.reshape(flat_preds, [num_batches, batch_size])  # shape (num_batches, batch_size)
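The reshape trick relies on two properties: the model treats each row of its (N, input_dim) input independently, and tf.reshape, like NumPy's default reshape, uses row-major order, so row `b * batch_size + i` of the flattened array is sample `i` of batch `b`. The sketch below checks this with NumPy only; `toy_model` is a hypothetical row-wise stand-in for the lattice model, chosen so the snippet runs without TensorFlow.

```python
import numpy as np


def toy_model(flat_x):
    """Hypothetical stand-in for a row-wise model: maps (N, 2) -> (N, 1)."""
    w = np.array([[0.3], [-1.2]])
    return np.tanh(flat_x @ w)


num_batches, batch_size, input_dim = 2, 100, 2
rng = np.random.default_rng(0)
batched_inputs = rng.standard_normal((num_batches, batch_size, input_dim))

# Flatten the leading axes, predict once, then restore them (row-major order).
flat_preds = toy_model(batched_inputs.reshape(-1, input_dim))  # (200, 1)
preds = flat_preds.reshape(num_batches, batch_size)             # (2, 100)

# Reference: predict each batch separately and stack the results.
per_batch = np.stack([toy_model(b)[:, 0] for b in batched_inputs])
```

Because `preds` and `per_batch` agree, a single flattened call gives the same predictions as looping over batches.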

You can do the reshaping inside the model before and after the call to the layers.
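A minimal sketch of doing the reshaping inside the model, assuming the wrapped model maps each row of an (N, input_dim) input independently to an (N, out_dim) output, as the LatticeModel above does. `FlattenBatchesWrapper` is a hypothetical name, not part of TensorFlow Lattice, and a `tf.keras.layers.Dense(1)` layer stands in for the lattice model so the snippet runs without `tensorflow_lattice` installed; passing `LatticeModel()` as `inner` should work the same way.

```python
import numpy as np
import tensorflow as tf


class FlattenBatchesWrapper(tf.keras.Model):
    """Hypothetical wrapper: accepts inputs of shape (..., input_dim),
    flattens all leading axes, calls a row-wise inner model once, and
    restores the leading axes on the output."""

    def __init__(self, inner):
        super().__init__()
        self.inner = inner

    def call(self, x):
        x = tf.cast(x, tf.float32)
        leading_shape = tf.shape(x)[:-1]         # e.g. [num_batches, batch_size]
        input_dim = x.shape[-1]                  # static feature dimension
        flat_x = tf.reshape(x, [-1, input_dim])  # (num_batches * batch_size, input_dim)
        flat_out = self.inner(flat_x)            # (num_batches * batch_size, out_dim)
        out_dim = flat_out.shape[-1]
        # Restore the leading axes: (num_batches, batch_size, out_dim)
        return tf.reshape(flat_out, tf.concat([leading_shape, [out_dim]], axis=0))


# Stand-in for LatticeModel(): any layer mapping (N, 2) -> (N, 1) row by row.
inner = tf.keras.layers.Dense(1)
model = FlattenBatchesWrapper(inner)

batched_inputs = np.random.randn(2, 100, 2).astype(np.float32)
preds = model(batched_inputs)       # shape (2, 100, 1)
single = model(batched_inputs[0])   # a plain (100, 2) batch still works: (100, 1)
```

Since the wrapper only reshapes around the inner call, the same code handles a single batch and any number of stacked batches.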

cool, many thanks!