vgsatorras / few-shot-gnn

get_task_batch function from Generator

ranyaphat29 opened this issue

```python
for batch_counter in range(batch_size):
    positive_class = random.randint(0, n_way - 1)

    # Sample random classes for this TASK
    classes_ = list(self.data.keys())
    sampled_classes = random.sample(classes_, n_way)
    indexes_perm = np.random.permutation(n_way * num_shots)

    counter = 0
    for class_counter, class_ in enumerate(sampled_classes):
        if class_counter == positive_class:
            # We take num_shots + one sample for one class
            samples = random.sample(self.data[class_], num_shots+1)
            # Test sample is loaded
            batch_x[batch_counter, :, :, :] = samples[0]
            labels_x[batch_counter, class_counter] = 1
            labels_x_global[batch_counter] = self.class_encoder[class_]
            samples = samples[1::]
        else:
            samples = random.sample(self.data[class_], num_shots)
```
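
For reference, here is a stripped-down, self-contained sketch of what the quoted loop appears to do for a single task. It is not the repository's code: `sample_task` is a hypothetical helper, `data` is assumed to be a plain dict from class name to a list of samples, and NumPy arrays are replaced by Python lists. It shows three things: `positive_class` picks which of the `n_way` task slots also supplies the query (test) sample; that class contributes `num_shots + 1` samples while the others contribute `num_shots`; and the one-hot label is indexed by the slot `class_counter`, with the global class identity kept separately (the role `labels_x_global` plays via `self.class_encoder[class_]`).

```python
import random


def sample_task(data, n_way, num_shots, rng=random):
    """Hypothetical single-task version of the quoted loop (lists instead of arrays).

    data: dict mapping a class name to a list of samples (an assumption).
    Returns (query, one_hot, query_class, support), where support is a list
    of (sample, slot) pairs and slot is the task-local index 0..n_way-1.
    """
    # Slot (0..n_way-1) whose class also provides the query sample
    positive_class = rng.randint(0, n_way - 1)
    # n_way distinct classes for this task, in random order
    sampled_classes = rng.sample(list(data.keys()), n_way)

    one_hot = [0] * n_way
    query, query_class = None, None
    support = []
    for class_counter, class_ in enumerate(sampled_classes):
        if class_counter == positive_class:
            # num_shots support samples plus one query, drawn without replacement
            samples = rng.sample(data[class_], num_shots + 1)
            query = samples[0]
            one_hot[class_counter] = 1  # label = slot in this task, not a global id
            query_class = class_        # global identity kept separately
            samples = samples[1:]
        else:
            samples = rng.sample(data[class_], num_shots)
        support.extend((s, class_counter) for s in samples)
    return query, one_hot, query_class, support


# Usage: 7 toy classes with 5 samples each; a 5-way, 2-shot task
toy_data = {c: [f"{c}{i}" for i in range(5)] for c in "abcdefg"}
query, one_hot, query_class, support = sample_task(toy_data, 5, 2, random.Random(0))
print(query_class, one_hot)
```

Indexing the label by slot rather than by `class_` fits episodic few-shot training, where the `n_way` label columns refer to positions within the current task. In the snippet, `class_` is a key of `self.data` and becomes an integer only through `self.class_encoder`, so it is not directly usable as a column index into `labels_x`.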

Why do we need `positive_class`? What is the meaning of `indexes_perm`? And why do we set `labels_x[batch_counter, class_counter] = 1` using `class_counter` instead of `class_`?
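
Neither `indexes_perm` nor `counter` is used in the quoted lines, so their purpose cannot be read off the snippet. One plausible reading, offered purely as an assumption about the unquoted rest of the loop: `counter` is incremented once per support sample, and the `counter`-th sample is written to support position `indexes_perm[counter]`, so the `n_way * num_shots` support samples land in random positions instead of being grouped by class. A minimal sketch of that idea, with `scatter_support` as a hypothetical helper and `random.shuffle` standing in for `np.random.permutation`:

```python
import random


def scatter_support(support, rng=random):
    """Hypothetical: write the counter-th support item to slot perm[counter]."""
    perm = list(range(len(support)))  # like np.random.permutation(len(support))
    rng.shuffle(perm)
    slots = [None] * len(support)
    for counter, item in enumerate(support):
        slots[perm[counter]] = item
    return slots, perm


# Usage: a 3-way, 2-shot support set, listed grouped by class as it is sampled
support = [("a0", 0), ("a1", 0), ("b0", 1), ("b1", 1), ("c0", 2), ("c1", 2)]
slots, perm = scatter_support(support, random.Random(1))
print(perm)
print(slots)
```

If this reading is right, the shuffle would keep a model from inferring a support sample's class from its position alone, since without it the samples would always arrive in class order.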

Thanks,