tensorflow / similarity

TensorFlow Similarity is a python package focused on making similarity learning quick and easy.

Create Contrastive Sampler / Utilities for constructing datasets.

owenvallis opened this issue.

Constructing datasets for the contrastive model is currently ad hoc and could benefit from some utility functions. For example:

```python
import random

import numpy as np
import tensorflow as tf

# Assumes x_raw_train and y_raw_train are NumPy arrays holding the raw
# training examples and their integer labels, and that ds_info is the
# tfds.core.DatasetInfo returned by tfds.load(..., with_info=True).

# Compute the indices for the query, index, val, and train splits
query_idxs, index_idxs, val_idxs, train_idxs = [], [], [], []

for cid in range(ds_info.features["label"].num_classes):
    # Shuffle the indices of every example that belongs to class `cid`
    idxs = tf.random.shuffle(tf.where(y_raw_train == cid))
    idxs = tf.reshape(idxs, (-1,)).numpy()
    query_idxs.extend(idxs[:100])  # 100 query examples per class
    index_idxs.extend(idxs[100:200])  # 100 index examples per class
    val_idxs.extend(idxs[200:300])  # 100 validation examples per class
    train_idxs.extend(idxs[300:])  # The remaining are used for training

# Shuffle each split so the classes are interleaved
random.shuffle(query_idxs)
random.shuffle(index_idxs)
random.shuffle(val_idxs)
random.shuffle(train_idxs)


def create_split(idxs: list) -> tuple:
    """Gather the examples at `idxs` as float32 features and int64 labels."""
    idxs = np.asarray(idxs, dtype=np.int64)
    x = tf.convert_to_tensor(x_raw_train[idxs], dtype=tf.float32)
    y = tf.convert_to_tensor(y_raw_train[idxs], dtype=tf.int64)
    return x, y


x_query, y_query = create_split(query_idxs)
x_index, y_index = create_split(index_idxs)
x_val, y_val = create_split(val_idxs)
x_train, y_train = create_split(train_idxs)
```
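A minimal sketch of the kind of utility this issue asks for follows. It assumes the labels are available as a 1-D NumPy array. `per_class_split` is a hypothetical name, not an existing TensorFlow Similarity function. The sketch folds the per-class loop above into one reusable call and uses NumPy's `default_rng` in place of the mix of `tf.random.shuffle` and `random.shuffle`, so a single seed makes the split reproducible.

```python
import numpy as np


def per_class_split(labels, sizes, seed=None):
    """Hypothetical helper: split example indices per class.

    labels: 1-D integer array of class ids, one per example.
    sizes:  examples per class for each leading split, e.g.
            (100, 100, 100) for query, index, and validation.
    Returns len(sizes) + 1 shuffled index arrays; the last one holds
    every example not assigned to a leading split (e.g. training).
    """
    labels = np.asarray(labels)
    rng = np.random.default_rng(seed)
    splits = [[] for _ in range(len(sizes) + 1)]
    for cid in np.unique(labels):
        # Shuffle this class's indices so each split is a random sample.
        idxs = rng.permutation(np.flatnonzero(labels == cid))
        if len(idxs) < sum(sizes):
            raise ValueError(f"class {cid} has only {len(idxs)} examples")
        start = 0
        for i, size in enumerate(sizes):
            splits[i].append(idxs[start:start + size])
            start += size
        splits[-1].append(idxs[start:])  # remainder of this class
    # Concatenate across classes, then shuffle so classes are interleaved.
    return [rng.permutation(np.concatenate(s)) for s in splits]
```

With this helper, the loop above could collapse to `query_idxs, index_idxs, val_idxs, train_idxs = per_class_split(y_raw_train, (100, 100, 100), seed=42)`. The splits could then be gathered with `x_raw_train[query_idxs]`, again assuming NumPy arrays. Raising an error on small classes is a design choice: it avoids silently producing splits that are not class-balanced.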