snap-stanford / csr

BUG report about graph_sampler

lyyf2002 opened this issue

When I tried to resample the inductive data of FB15K, I found that running just the following code in 'graph_sampler.py' produces data that differs from the data you published on Google Drive:

    hop = 1
    generate_subgraph_datasets(".", dataset="FB15K-237", splits = ['pretrain','dev','test'], kind = "union_prune_plus", hop=hop, sample_neg = False, inductive = True)
    generate_subgraph_datasets(".", dataset="FB15K-237", splits = ['dev','test'], kind = "union_prune_plus", hop=hop, all_negs = True, sample_all_negs = False, inductive = True)
    
    SubgraphFewshotDataset(".", shot = 3, dataset="FB15K-237", mode="pretrain", kind="union_prune_plus", hop=hop, preprocess = True, preprocess_50neg = False, inductive = True)
    SubgraphFewshotDataset(".", shot = 3, dataset="FB15K-237", mode="dev", kind="union_prune_plus", hop=hop, preprocess = True, preprocess_50neg = True, inductive = True)
    SubgraphFewshotDataset(".", shot = 3, dataset="FB15K-237", mode="test", kind="union_prune_plus", hop=hop, preprocess = True, preprocess_50neg = True, inductive = True)

I have now found the cause:
When 'inductive = True' is set, your code appends a postfix to the file names, so the loaded background graph is 'graph_inductive.pt' and 'path_graph_indutive.json'. In the paper, however, the dev/test background graph is not inductive.
I have solved this problem and will create a pull request as soon as possible.
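
A minimal sketch of the intended behavior, assuming the postfix should apply only to the pretrain split. The helper names `background_postfix` and `background_graph_file`, the `"_inductive"` postfix string, and the non-inductive file name `graph.pt` are assumptions for illustration and are not taken from the repository:

```python
def background_postfix(mode: str, inductive: bool) -> str:
    """Return the file-name postfix for the background graph.

    Hypothetical helper: only the pretrain split uses the inductive
    background graph; dev and test fall back to the full graph.
    """
    if inductive and mode == "pretrain":
        return "_inductive"
    return ""


def background_graph_file(mode: str, inductive: bool) -> str:
    """Build the background graph file name for a given split (assumed naming)."""
    return f"graph{background_postfix(mode, inductive)}.pt"


if __name__ == "__main__":
    for split in ("pretrain", "dev", "test"):
        print(split, background_graph_file(split, inductive=True))
```

With this rule, dev and test load the same background graph whether or not 'inductive' is set, while pretrain switches to the inductive one.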

In addition, lines 135-142 and 148-153 of 'load_kg_dataset.py' should not run during preprocessing:

        if mode == "test" and inductive:
            print("subsample tasks!!!!!!!!!!!!!!!!!!!")
            self.test_tasks_idx = json.load(open(os.path.join(raw_data_paths,  'sample_test_tasks_idx.json')))
            for r in list(self.tasks.keys()):
                if r not in self.test_tasks_idx:
                    self.tasks[r] = []
                else:
                    self.tasks[r] = np.array(self.tasks[r])[self.test_tasks_idx[r]].tolist()
        if mode == "test" and inductive:
            for idx, r in enumerate(list(self.all_rels)):
                if len(self.tasks[r]) == 0:
                    del self.tasks[r]
                    print("remove empty tasks!!!!!!!!!!!!!!!!!!!")
            self.all_rels = sorted(list(self.tasks.keys()))
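
One way to express this is to skip the subsampling whenever preprocessing is active. The sketch below is a standalone illustration, not the repository's code: the function `subsample_test_tasks` and its `preprocess` argument are hypothetical, and plain lists stand in for the NumPy indexing used above.

```python
def subsample_test_tasks(tasks, test_tasks_idx, mode, inductive, preprocess):
    """Return the tasks to keep for this run (hypothetical helper).

    During preprocessing, or outside the inductive test split, every task
    is kept unchanged. Otherwise each relation is reduced to the sampled
    indices, and relations left without tasks are dropped.
    """
    if preprocess or not (mode == "test" and inductive):
        return dict(tasks)

    kept = {}
    for rel, triples in tasks.items():
        idx = test_tasks_idx.get(rel)
        if idx is None:
            continue  # relation was not sampled, so it has no tasks
        selected = [triples[i] for i in idx]
        if selected:
            kept[rel] = selected
    return kept


if __name__ == "__main__":
    tasks = {"r1": ["a", "b", "c"], "r2": ["d"]}
    sampled = {"r1": [0, 2]}
    print(subsample_test_tasks(tasks, sampled, "test", True, preprocess=True))
    print(subsample_test_tasks(tasks, sampled, "test", True, preprocess=False))
```

Keeping the full task set during preprocessing means the cached subgraphs cover every relation, and the subsampling is applied only when the dataset is loaded for evaluation.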