uta-smile / RetroXpert



RuntimeError: size mismatch

judory opened this issue

Hi,

Thank you for sharing this awesome code.
However, I get the following error when I run 'python train.py --typed':
Please let me know how to fix it.

Traceback (most recent call last):
  File "train.py", line 322, in <module>
    h_pred, e_pred = GAT_model(g_dgl, x_atom)
  File "/home/yim/anaconda3/envs/retroxpert/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/yim/RetroXpert/model/gat.py", line 121, in forward
    h, _ = self.gat[l](g, h, merge='flatten')
  File "/home/yim/anaconda3/envs/retroxpert/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/yim/RetroXpert/model/gat.py", line 72, in forward
    outs = list(map(lambda x: x(g, h), self.heads))
  File "/home/yim/RetroXpert/model/gat.py", line 72, in <lambda>
    outs = list(map(lambda x: x(g, h), self.heads))
  File "/home/yim/anaconda3/envs/retroxpert/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/yim/RetroXpert/model/gat.py", line 45, in forward
    g.ndata['h'] = self.embed_node(h)
  File "/home/yim/anaconda3/envs/retroxpert/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/yim/anaconda3/envs/retroxpert/lib/python3.6/site-packages/torch/nn/modules/linear.py", line 87, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/yim/anaconda3/envs/retroxpert/lib/python3.6/site-packages/torch/nn/functional.py", line 1372, in linear
    output = input.matmul(weight.t())
RuntimeError: size mismatch, m1: [860 x 584], m2: [714 x 128] at /pytorch/aten/src/THC/generic/THCTensorMathBlas.cu:290
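How to read the last line: F.linear computes input.matmul(weight.t()), so m1 is the atom-feature tensor (860 rows, presumably the atoms in the batched graph, by 584 features) and m2 is the transposed weight of the first linear layer, self.embed_node (714 inputs by 128 outputs). A matrix product needs the inner dimensions to agree, so the layer was built for 714 input features while the features actually have 584 columns. The helper below is a hypothetical sketch (not part of RetroXpert) that pulls those two numbers out of this legacy PyTorch 1.x error format; newer PyTorch versions word the message differently, in which case it returns None.

```python
import re

# Matches the legacy THC message: "size mismatch, m1: [a x b], m2: [c x d]"
_PATTERN = re.compile(r"m1: \[(\d+) x (\d+)\], m2: \[(\d+) x (\d+)\]")


def diagnose_size_mismatch(message):
    """Return (actual_in_dim, expected_in_dim) for a linear-layer size mismatch.

    m1 is the layer input (rows x features); m2 is weight.t()
    (in_features x out_features). Returns None if the message does not
    match the legacy format.
    """
    match = _PATTERN.search(message)
    if match is None:
        return None
    _rows, actual, expected, _out = map(int, match.groups())
    return actual, expected


msg = "RuntimeError: size mismatch, m1: [860 x 584], m2: [714 x 128]"
print(diagnose_size_mismatch(msg))  # → (584, 714)
```

For this traceback it reports 584 actual versus 714 expected input features, which is why 'in_dim' has to be changed to match the features.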

commented

You need to modify in_dim according to your semi-templates:
parser.add_argument('--in_dim', type=int, default=47 + 530, help='dim of atom feature')
The 530 here comes from the output of 'python extract_semi_template_pattern.py'; change it to the value your own run produces.
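A minimal sketch of the same idea, assuming the 47 in the default stays fixed and only the semi-template count changes between datasets. The names BASE_ATOM_FEATURES, build_parser and check_in_dim are hypothetical illustrations, not RetroXpert code; the check simply compares in_dim against the width of the atom-feature tensor before the model is built, so a mismatch fails with a readable message instead of deep inside F.linear.

```python
import argparse

# Hypothetical constant: the "47" from the default in_dim above, assumed fixed.
BASE_ATOM_FEATURES = 47


def build_parser(num_semi_templates):
    """Build a parser whose --in_dim default is 47 + number of semi-templates."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--in_dim', type=int,
                        default=BASE_ATOM_FEATURES + num_semi_templates,
                        help='dim of atom feature')
    return parser


def check_in_dim(in_dim, feature_width):
    """Hypothetical guard: fail early if in_dim disagrees with the feature columns."""
    if in_dim != feature_width:
        raise ValueError(
            f"--in_dim is {in_dim} but atom features have {feature_width} "
            f"columns; set --in_dim {feature_width}")


args = build_parser(530).parse_args([])
print(args.in_dim)  # → 577
```

With the shapes from the traceback above, check_in_dim(714, 584) raises a ValueError that names the expected value, whereas check_in_dim(584, 584) passes silently.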