BUG report about graph_sampler
lyyf2002 opened this issue
Luo Yangyifei commented
When I try to resample the inductive FB15K-237 data by running only the following code in 'graph_sampler.py', the resulting data differs from the data you published on Google Drive:
hop = 1
generate_subgraph_datasets(".", dataset="FB15K-237", splits = ['pretrain','dev','test'], kind = "union_prune_plus", hop=hop, sample_neg = False, inductive = True)
generate_subgraph_datasets(".", dataset="FB15K-237", splits = ['dev','test'], kind = "union_prune_plus", hop=hop, all_negs = True, sample_all_negs = False, inductive = True)
SubgraphFewshotDataset(".", shot = 3, dataset="FB15K-237", mode="pretrain", kind="union_prune_plus", hop=hop, preprocess = True, preprocess_50neg = False, inductive = True)
SubgraphFewshotDataset(".", shot = 3, dataset="FB15K-237", mode="dev", kind="union_prune_plus", hop=hop, preprocess = True, preprocess_50neg = True, inductive = True)
SubgraphFewshotDataset(".", shot = 3, dataset="FB15K-237", mode="test", kind="union_prune_plus", hop=hop, preprocess = True, preprocess_50neg = True, inductive = True)
I have now found the reason:
With 'inductive = True', the code appends a postfix to the file names, so the background graph is loaded from 'graph_inductive.pt' and 'path_graph_indutive.json'. In the paper, however, the dev/test background graph is not inductive.
I have solved this problem and will create a pull request as soon as possible.
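To illustrate the idea, here is a minimal sketch, not the actual patch. The helper name `background_postfix` is hypothetical, and it assumes that only the pretrain split should keep the inductive postfix while dev/test fall back to the files without a postfix:

```python
def background_postfix(mode: str, inductive: bool) -> str:
    """Return the file-name postfix for the background graph of a split.

    Hypothetical helper (not part of the repository). Assumption: only the
    pretrain split uses the inductive background graph; dev and test use
    the regular (non-inductive) one, as described in the paper.
    """
    if inductive and mode == "pretrain":
        return "_inductive"
    return ""


# Example: build the graph file name for each split.
for mode in ("pretrain", "dev", "test"):
    print(mode, "graph" + background_postfix(mode, inductive=True) + ".pt")
```

With this rule, dev and test load the graph file without the postfix even when 'inductive = True' is passed.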
Luo Yangyifei commented
In addition, lines 135-142 and 148-153 of 'load_kg_dataset.py' should not run during preprocessing:
if mode == "test" and inductive:
    print("subsample tasks!!!!!!!!!!!!!!!!!!!")
    self.test_tasks_idx = json.load(open(os.path.join(raw_data_paths, 'sample_test_tasks_idx.json')))
    for r in list(self.tasks.keys()):
        if r not in self.test_tasks_idx:
            self.tasks[r] = []
        else:
            self.tasks[r] = np.array(self.tasks[r])[self.test_tasks_idx[r]].tolist()
if mode == "test" and inductive:
    for idx, r in enumerate(list(self.all_rels)):
        if len(self.tasks[r]) == 0:
            del self.tasks[r]
            print("remove empty tasks!!!!!!!!!!!!!!!!!!!")
    self.all_rels = sorted(list(self.tasks.keys()))
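One possible way to express this is to guard the subsampling with the 'preprocess' flag. The sketch below is only an illustration under stated assumptions: the function name `maybe_subsample` is hypothetical, it works on plain dicts instead of the dataset object, and it uses plain list indexing in place of the NumPy indexing in the original lines:

```python
def maybe_subsample(tasks, test_tasks_idx, mode, inductive, preprocess):
    """Subsample inductive test tasks, but never while preprocessing.

    Hypothetical stand-alone version of the two blocks quoted above.
    `tasks` maps relation -> list of tasks; `test_tasks_idx` maps
    relation -> list of indices to keep.
    """
    if preprocess or not (mode == "test" and inductive):
        # Preprocessing (or any other split): keep every task untouched.
        return dict(tasks)

    kept = {}
    for rel, rel_tasks in tasks.items():
        idx = test_tasks_idx.get(rel)
        if not idx:
            # Relation has no sampled tasks: drop it, which mirrors
            # emptying the list and then removing empty tasks.
            continue
        kept[rel] = [rel_tasks[i] for i in idx]
    return kept


tasks = {"r1": ["a", "b", "c"], "r2": ["d"]}
sampled_idx = {"r1": [0, 2]}

# During preprocessing nothing is removed.
print(maybe_subsample(tasks, sampled_idx, "test", True, preprocess=True))
# When loading the test split, only the sampled tasks remain.
print(maybe_subsample(tasks, sampled_idx, "test", True, preprocess=False))
```

Keeping the subsampling out of the preprocessing path means the cached files cover all tasks, and the sampled subset is only applied when the test split is loaded.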