convert_to_tensor in index() prevents passing a dict of inputs
owenvallis opened this issue
ValueError Traceback (most recent call last)
<ipython-input-19-d4b57d7b51eb> in <module>()
1 index_data = [{"prompt_input": p, "response_input": r} for p, r in zip(prompt_index, response_index)]
----> 2 model.index({"prompt_input": prompt_index, "response_input": response_index}, y=y_index, data=index_data)
google3/third_party/py/tensorflow_similarity/models/similarity_model.py in index(self, x, y, data, build, verbose)
350 print("|-Computing embeddings")
351 with tf.device("/cpu:0"):
--> 352 x = tf.convert_to_tensor(np.array(x))
353
354 predictions = self.predict(x)
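This conversion was added to prevent a slowdown and possible memory leak when inputs are passed as Python lists instead of np.array or tensors. However, it breaks models that take multiple inputs: np.array() wraps a dict of named inputs in a single 0-d object array, which tf.convert_to_tensor() then rejects with the ValueError above.
We should check the input type first and handle the multi-input case properly.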
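A minimal sketch of the type check suggested above. The helper name `to_model_inputs` is hypothetical, and `convert` defaults to `np.asarray` only so the sketch runs without TensorFlow; inside `index()` it would be `tf.convert_to_tensor`. As an assumption, a dict or tuple is treated as a multi-input batch and converted per input, while a list is treated as one single-input batch (the list case the original conversion was guarding against).

```python
from typing import Any, Callable, Mapping

import numpy as np


def to_model_inputs(x: Any, convert: Callable[[Any], Any] = np.asarray) -> Any:
    """Hypothetical helper: convert inputs per input instead of via np.array(x)."""
    if isinstance(x, Mapping):
        # Multi-input model keyed by input name: convert each input separately.
        return {name: convert(value) for name, value in x.items()}
    if isinstance(x, tuple):
        # Positional multi-input model: convert each input separately.
        return tuple(convert(value) for value in x)
    # Single input (list, np.ndarray, or tensor): convert the whole batch at once.
    return convert(x)


prompt_index = [[1, 2], [3, 4]]
response_index = [[5, 6], [7, 8]]

# What index() did: np.array() wraps the dict in a 0-d object array.
wrapped = np.array({"prompt_input": prompt_index, "response_input": response_index})
print(wrapped.shape, wrapped.dtype)  # () object

# Converting per input keeps the dict structure that a multi-input model expects.
inputs = to_model_inputs({"prompt_input": prompt_index, "response_input": response_index})
print({name: value.shape for name, value in inputs.items()})
```

Converting each input separately still avoids handing predict() raw Python lists, which was the point of the original change, without collapsing a multi-input batch into one object array.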
Removing all tf.convert_to_tensor() calls before predict(). The previous change prevented the memory leak when multiple models were called in a loop, but it restricted predict() to a single tensor batch. That is too restrictive and prevents us from calling multi-headed models.