Create Contrastive Sampler / Utilities for constructing datasets.
owenvallis opened this issue · comments
Owen Vallis commented
Constructing datasets for the contrastive model is currently ad hoc and could benefit from some utility functions. For example:
# Per-class sizes of the held-out splits. Whatever is left over in a class
# after these three slices becomes training data.
QUERY_PER_CLASS = 100
INDEX_PER_CLASS = 100
VAL_PER_CLASS = 100

# Compute the indices for the query, index, val, and train splits.
query_idxs, index_idxs, val_idxs, train_idxs = [], [], [], []
index_end = QUERY_PER_CLASS + INDEX_PER_CLASS
val_end = index_end + VAL_PER_CLASS
for cid in range(ds_info.features["label"].num_classes):
    # Positions of the examples labelled `cid`, shuffled so each split gets a
    # random sample of the class, then flattened to one dimension.
    idxs = tf.random.shuffle(tf.where(y_raw_train == cid))
    idxs = tf.reshape(idxs, (-1,))
    # NOTE: slicing past the end is not an error, so a class with fewer than
    # `val_end` examples silently produces short or empty later splits.
    query_idxs.extend(idxs[:QUERY_PER_CLASS])
    index_idxs.extend(idxs[QUERY_PER_CLASS:index_end])
    val_idxs.extend(idxs[index_end:val_end])
    train_idxs.extend(idxs[val_end:])  # The remaining are used for training

# The lists above are grouped by class; shuffle each one in place to mix them.
random.shuffle(query_idxs)
random.shuffle(index_idxs)
random.shuffle(val_idxs)
random.shuffle(train_idxs)
def create_split(idxs: list) -> tuple:
    """Gather the examples at `idxs` from the raw training data as tensors.

    Args:
        idxs: Positions into `x_raw_train` / `y_raw_train`. Each entry is
            converted with `int()` before it is used as an index.

    Returns:
        An `(x, y)` pair: the selected examples as a float32 tensor and the
        matching labels as an int64 tensor, both in the order given by `idxs`.
    """
    positions = [int(idx) for idx in idxs]
    features = np.array([x_raw_train[pos] for pos in positions])
    labels = np.array([y_raw_train[pos] for pos in positions])
    x = tf.convert_to_tensor(features, dtype=tf.float32)
    y = tf.convert_to_tensor(labels, dtype=tf.int64)
    return x, y
# Materialise each index list into an (x, y) tensor pair, in the order
# query, index, val, train.
(x_query, y_query), (x_index, y_index), (x_val, y_val), (x_train, y_train) = (
    create_split(split_idxs)
    for split_idxs in (query_idxs, index_idxs, val_idxs, train_idxs)
)