wengong-jin / hgraph2graph

Hierarchical Generation of Molecular Graphs using Structural Motifs

RuntimeError in the polymers/ folder

goznxn opened this issue

Hi Wengong,

Thank you for sharing this wonderful work.

I am trying to use the code in the polymers folder to generate molecules from the ZINC250K dataset, which I downloaded from the GitHub repository of your junction tree project.
I first ran get_vocab.py to build the vocabulary and then ran preprocess.py to generate the training data.

When I run vae_train.py, I get the following error:
Namespace(anneal_rate=0.9, atom_vocab=<poly_hgraph.vocab.Vocab object at 0x2b5a6fb51e10>, batch_size=20, beta=0.3, clip_norm=20.0, depthG=20, depthT=20, diterG=5, diterT=1, dropout=0.0, embed_size=250, epoch=20, hidden_size=250, latent_size=24, load_epoch=-1, lr=0.001, print_iter=50, rnn_type='LSTM', save_dir='models/', save_iter=-1, train='train_processed/', vocab='zinc_vocab.txt')
/home/zhg19014/.conda/envs/hgraph2graph/lib/python3.6/site-packages/torch/nn/_reduction.py:46: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.
  warnings.warn(warning.format(ret))
Model #Params: 5742K
Traceback (most recent call last):
  File "vae_train.py", line 80, in <module>
    loss, kl_div, wacc, iacc, tacc, sacc = model(*batch, beta=beta)
  File "/home/zhg19014/.conda/envs/hgraph2graph/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/gpfs/scratchfs1/zhg19014/motifgeneration/hgraph2graph/polymers/poly_hgraph/hgnn.py", line 76, in forward
    loss, wacc, iacc, tacc, sacc = self.decoder((root_vecs, tree_vecs, graph_vecs), graphs, tensors, orders)
  File "/home/zhg19014/.conda/envs/hgraph2graph/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/gpfs/scratchfs1/zhg19014/motifgeneration/hgraph2graph/polymers/poly_hgraph/decoder.py", line 254, in forward
    topo_scores = self.get_topo_score(src_tree_vecs, batch_idx, topo_vecs)
  File "/gpfs/scratchfs1/zhg19014/motifgeneration/hgraph2graph/polymers/poly_hgraph/decoder.py", line 137, in get_topo_score
    return self.topoNN( torch.cat([topo_vecs, topo_cxt], dim=-1) ).squeeze(-1)
  File "/home/zhg19014/.conda/envs/hgraph2graph/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/zhg19014/.conda/envs/hgraph2graph/lib/python3.6/site-packages/torch/nn/modules/container.py", line 92, in forward
    input = module(input)
  File "/home/zhg19014/.conda/envs/hgraph2graph/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/zhg19014/.conda/envs/hgraph2graph/lib/python3.6/site-packages/torch/nn/modules/linear.py", line 92, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/zhg19014/.conda/envs/hgraph2graph/lib/python3.6/site-packages/torch/nn/functional.py", line 1406, in linear
    ret = torch.addmm(bias, input, weight.t())
RuntimeError: size mismatch, m1: [872 x 500], m2: [274 x 250] at /opt/conda/conda-bld/pytorch_1556653183467/work/aten/src/THC/generic/THCTensorMathBlas.cu:268
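The numbers in the size-mismatch message appear to line up with the hyperparameters in the Namespace line. Because `F.linear` computes `addmm(bias, input, weight.t())`, m1 is the input (`[872 x 500]`) and m2 is the transposed weight (`[274 x 250]`), so the first layer of `topoNN` seems to expect 274 = 250 + 24 (`hidden_size + latent_size`) input columns, while the concatenated `[topo_vecs, topo_cxt]` tensor has 500 = 250 + 250 columns. The sketch below is a hypothetical pure-Python helper (`linear_output_shape` is not part of hgraph2graph or PyTorch) that reproduces only this shape check. The split of 500 into two `hidden_size` halves is an assumption consistent with the log, not a confirmed diagnosis.

```python
# Hypothetical helper (not from hgraph2graph or PyTorch): mimics the shape rule of
# torch.nn.functional.linear, which computes addmm(bias, input, weight.t()).
def linear_output_shape(input_shape, weight_shape):
    """Return the shape of input @ weight.t(), or raise on a size mismatch."""
    n_rows, n_cols = input_shape             # m1: [N x in]
    out_features, in_features = weight_shape
    m2 = (in_features, out_features)         # weight.t(): [in x out]
    if n_cols != m2[0]:
        raise RuntimeError(
            f"size mismatch, m1: [{n_rows} x {n_cols}], m2: [{m2[0]} x {m2[1]}]"
        )
    return (n_rows, out_features)


# Values copied from the Namespace line in the log.
hidden_size = 250
latent_size = 24

# Assumption: topoNN's first Linear layer was built for hidden_size + latent_size inputs.
expected_in = hidden_size + latent_size      # 274, matches m2's first dimension
# Assumption: topo_vecs and topo_cxt both arrived with hidden_size columns.
observed_in = hidden_size + hidden_size      # 500, matches m1's second dimension

out_features = 250                           # read off m2 in the error message
weight_shape = (out_features, expected_in)   # weight is [250 x 274]; weight.t() is [274 x 250]

try:
    linear_output_shape((872, observed_in), weight_shape)
except RuntimeError as err:
    print(err)  # size mismatch, m1: [872 x 500], m2: [274 x 250]

# With a latent_size-wide context, the same layer would accept the input.
print(linear_output_shape((872, expected_in), weight_shape))  # (872, 250)
```

If this reading holds, `topo_cxt` (which `get_topo_score` builds from `src_tree_vecs`) carries `hidden_size` columns where `latent_size` columns were expected, so it may be worth checking the width of the vectors that `hgnn.py` passes to the decoder before revisiting the vocabulary or preprocessing steps.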

Thanks.