Bug in dataloader?
Brandonnogithub opened this issue
Hi guys, I am trying to reproduce your work. In the dataloader, I found this code:
```python
for sample_idx in range(self.num_sample):
    for query_idx in range(len(self.query_examples)):
        # If training, exclude the current example. Else keep all.
        if self.use_demo and args.demo_filter:
            # Demonstration filtering
            candidate = [support_idx for support_idx in support_indices
                         if support_idx != query_idx or mode != "train"]
            sim_score = []
            for support_idx in candidate:
                sim_score.append((support_idx, util.pytorch_cos_sim(self.support_emb[support_idx], self.query_emb[query_idx])))
            sim_score.sort(key=lambda x: x[1], reverse=True)
            if self.num_labels == 1:
                # Regression task
                limit_each_label = int(len(sim_score) // 2 * args.demo_filter_rate)
                count_each_label = {'0': 0, '1': 0}
                context_indices = []
                if args.debug_mode:
                    print("Query %s: %s" % (self.query_examples[query_idx].label, self.query_examples[query_idx].text_a))  # debug
                for support_idx, score in sim_score:
                    if count_each_label['0' if float(self.support_examples[support_idx].label) <= median_mapping[args.task_name] else '1'] < limit_each_label:
                        count_each_label['0' if float(self.support_examples[support_idx].label) <= median_mapping[args.task_name] else '1'] += 1
                        context_indices.append(support_idx)
                        if args.debug_mode:
                            print("    %.4f %s | %s" % (score, self.support_examples[support_idx].label, self.support_examples[support_idx].text_a))  # debug
            else:
                limit_each_label = int(len(sim_score) // self.num_labels * args.demo_filter_rate)
                count_each_label = {label: 0 for label in self.label_list}
                context_indices = []
                if args.debug_mode:
                    print("Query %s: %s" % (self.query_examples[query_idx].label, self.query_examples[query_idx].text_a))  # debug
                for support_idx, score in sim_score:
                    if count_each_label[self.support_examples[support_idx].label] < limit_each_label:
                        count_each_label[self.support_examples[support_idx].label] += 1
                        context_indices.append(support_idx)
                        if args.debug_mode:
                            print("    %.4f %s | %s" % (score, self.support_examples[support_idx].label, self.support_examples[support_idx].text_a))  # debug
        else:
            # Using demonstrations without filtering
            context_indices = [support_idx for support_idx in support_indices
                               if support_idx != query_idx or mode != "train"]
        # We'll subsample context_indices further later.
        self.example_idx.append((query_idx, context_indices, sample_idx))
```
Here the code computes the similarity scores.
But I don't understand why you use the loop `for sample_idx in range(self.num_sample)` at the outermost level: `sample_idx` is only used when appending the result to `self.example_idx`.
This code is really slow, since `num_sample` is set to 16, so the expensive similarity filtering is repeated 16 times per query.
I think you can remove the outer `for sample_idx in range(self.num_sample)` loop and change the last part to:
```python
for query_idx in range(len(self.query_examples)):
    ...  # same filtering logic as above
    # We'll subsample context_indices further later.
    for sample_idx in range(self.num_sample):
        self.example_idx.append((query_idx, context_indices, sample_idx))
```
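For clarity, here is a minimal self-contained sketch of the proposed hoisting (toy data and a hypothetical `filter_demonstrations` placeholder, not the repo's actual classes): the expensive per-query filtering runs once, and only the cheap append is repeated `num_sample` times.

```python
num_sample = 16
query_examples = ["q0", "q1", "q2"]   # stand-ins for the real examples
support_indices = list(range(8))

def filter_demonstrations(query_idx):
    """Placeholder for the cosine-similarity demonstration filtering above."""
    return [s for s in support_indices if s != query_idx]

example_idx = []
for query_idx in range(len(query_examples)):
    context_indices = filter_demonstrations(query_idx)  # expensive work, once per query
    # We'll subsample context_indices further later.
    for sample_idx in range(num_sample):                # cheap append, repeated 16x
        example_idx.append((query_idx, context_indices, sample_idx))
```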
I don't know whether I am right.
In my test, I found that after this change the results are different.
Hi,
I think the two implementations are essentially equivalent. The difference is due to different orders having different random sampling results.
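To illustrate with a toy, self-contained sketch (hypothetical names, not the repo's code): both loop orders produce the same set of (query, sample) pairs, but a shared RNG is consumed in a different order, so the randomly subsampled demonstrations differ even though either result is equally valid.

```python
import random

# query_idx -> candidate demonstration indices (toy data)
contexts = {0: [10, 11, 12, 13], 1: [20, 21, 22, 23]}
num_sample = 2

def subsample(candidates, rng, k=2):
    """Stand-in for the later random subsampling of context_indices."""
    return rng.sample(candidates, k)

rng = random.Random(42)
original = [(q, subsample(contexts[q], rng), s)
            for s in range(num_sample) for q in contexts]  # sample-major order

rng = random.Random(42)
proposed = [(q, subsample(contexts[q], rng), s)
            for q in contexts for s in range(num_sample)]  # query-major order

# Both orders cover the same (query_idx, sample_idx) pairs ...
assert sorted((q, s) for q, _, s in original) == sorted((q, s) for q, _, s in proposed)
# ... but the sampled demonstrations differ, because the RNG was
# consumed in a different order.
print(original == proposed)  # almost certainly False
```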
Thanks for your reply.