uta-smile / RetroXpert

KeyError: 'pattern_feat' when training the EGAT model

CHANG-Shaole opened this issue

Firstly, thank you for sharing such exciting work!
After reading your paper, I tried to run the code following your instructions.
After preparing the data and extracting the semi-template patterns (i.e., running preprocessing.py and extract_semi_template_pattern.py), I got a KeyError when training the EGAT model, as shown below:

Namespace(batch_size=32, dataset='USPTO50K', epochs=80, exp_name='USPTO50K_typed', gat_layers=3, heads=4, hidden_dim=128, in_dim=714, load=False, logdir='logs', lr=0.0005, seed=123, test_on_train=False, test_only=False, typed=True, use_cpu=False, valid_only=False)
Counter({1: 3482, 0: 1415, 2: 102, 9: 1, 17: 1})
Counter({1: 27851, 0: 11296, 2: 849, 3: 4, 4: 4, 10: 2, 7: 1, 13: 1})
  0%|                                                  | 0/1251 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "train.py", line 283, in <module>
    for i, data in enumerate(progress_bar):
  File "/home/changshaole/anaconda3/envs/retroxpert/lib/python3.6/site-packages/tqdm/_tqdm.py", line 1000, in __iter__
    for obj in iterable:
  File "/home/changshaole/anaconda3/envs/retroxpert/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 346, in __next__
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/home/changshaole/anaconda3/envs/retroxpert/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/changshaole/anaconda3/envs/retroxpert/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/changshaole/Project/RetroXpert_0406/RetroXpert/data.py", line 44, in __getitem__
    x_pattern_feat = reaction_data['pattern_feat'].astype(np.float32)
KeyError: 'pattern_feat'

Do you have any idea what causes this error? I would be grateful for any advice.

Same error here.

Found the issue (at least in my case): if you are generating the product_pattern.txt file for a new dataset, you first need to set arg.extract_pattern = True and run extract_semi_template_pattern.py. Note that the method find_all_patterns calls sys.exit(), so this first run exits there and does not yet produce 'pattern_feat'. After the product patterns have been generated, rerun extract_semi_template_pattern.py to add 'pattern_feat' to reaction_data.
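To make this failure easier to diagnose, the direct lookup in data.py could be wrapped in a guard that explains what is missing. This is a minimal sketch: get_pattern_feat is a hypothetical helper, not part of RetroXpert, and it only assumes (as the traceback shows) that reaction_data is a dict whose 'pattern_feat' value is a NumPy array.

```python
import numpy as np


def get_pattern_feat(reaction_data):
    """Hypothetical helper: return pattern_feat as float32, or fail with an
    actionable message instead of a bare KeyError."""
    if 'pattern_feat' not in reaction_data:
        # The key is only written by the second run of
        # extract_semi_template_pattern.py, after product_pattern.txt exists.
        raise KeyError(
            "'pattern_feat' missing from reaction_data; rerun "
            "extract_semi_template_pattern.py after product_pattern.txt "
            "has been generated"
        )
    # Same conversion that data.py performs on line 44.
    return reaction_data['pattern_feat'].astype(np.float32)
```

Calling it in __getitem__ in place of reaction_data['pattern_feat'].astype(np.float32) keeps the behavior unchanged when the key is present and points at the two-step fix when it is not.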

In addition, you might need to change the input dimension of the GAT (in_dim), since a different training set can yield a different set of patterns and therefore a different pattern_feat dimension.
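One way to pick the new value is to read it off the preprocessed data rather than guess. The sketch below assumes, without having checked RetroXpert's model code, that pattern_feat is a per-atom matrix of shape (num_atoms, num_patterns) concatenated with the per-atom feature matrix along the last axis; infer_in_dim and pattern_width are hypothetical helpers.

```python
import numpy as np


def infer_in_dim(atom_feat: np.ndarray, pattern_feat: np.ndarray) -> int:
    """Input width under the (assumed) layout [atom features | pattern features]."""
    if atom_feat.shape[0] != pattern_feat.shape[0]:
        raise ValueError("atom_feat and pattern_feat need one row per atom")
    return atom_feat.shape[-1] + pattern_feat.shape[-1]


def pattern_width(samples) -> int:
    """Check that every sample's pattern_feat has the same width and return it.

    A mismatch suggests the data were built from different pattern files.
    """
    widths = {np.asarray(s['pattern_feat']).shape[-1] for s in samples}
    if len(widths) != 1:
        raise ValueError(f"inconsistent pattern_feat widths: {sorted(widths)}")
    return widths.pop()


# Toy arrays with made-up sizes, only to show the arithmetic.
atoms = np.zeros((5, 40))
patterns = np.zeros((5, 600))
print(infer_in_dim(atoms, patterns))  # 640
```

The resulting number would then replace the value shown in the Namespace above (in_dim=714 for USPTO50K), presumably through the in_dim argument of train.py.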