YangLing0818 / IPDiff

[ICLR 2024] Protein-Ligand Interaction Prior for Binding-aware 3D Molecule Diffusion Models

Home Page: https://openreview.net/forum?id=qH9nrMNTIW

Request for Model Checkpoints or configuration

SeungbeomLee opened this issue.

Hi,

Thank you for your work and for making the code available. I have been trying to reproduce the results of the IPDiff paper using the model configuration provided in the GitHub repository, but unfortunately I haven't been able to achieve the results reported in the paper.

Could you please release the trained checkpoints for the model or provide additional details on the model configuration? I am wondering whether the configuration on GitHub is the same as the one used for the model reported in the IPDiff paper.

It would be incredibly helpful for advancing my work.
Thank you!

Hi,

Thank you very much for your interest in this work.

We have released the pre-trained models for both IPNet and IPDiff at pretrained-ipdiff and updated the ReadMe.md; please refer to it.

Thank you.

@ZerinHwang03
Hi.
Thanks for such an interesting project, but I'm having problems reproducing it using the instructions you provided.

  1. Training the network with train.py does not work: the loss is very large and does not converge (around 10-100000), whereas in TargetDiff it stays around 0.7-1.
  2. When I use the IPDiff and IPNet checkpoints you provided to re-evaluate the loss on the training set, it is still very large (around 10-10000).
  3. I ran into a problem when generating and evaluating samples with the provided checkpoints:
    [2024-05-20 13:21:30,524::evaluate::INFO] Evaluate done! 100 samples in total.
    [2024-05-20 13:21:30,525::evaluate::INFO] mol_stable: 0.0000
    [2024-05-20 13:21:30,525::evaluate::INFO] atm_stable: 0.0000
    [2024-05-20 13:21:30,525::evaluate::INFO] recon_success: 1.0000
    [2024-05-20 13:21:30,525::evaluate::INFO] eval_success: 0.0000
    [2024-05-20 13:21:30,525::evaluate::INFO] complete: 0.0000
    [2024-05-20 13:21:30,525::evaluate::INFO] JS bond distances of complete mols.
    [2024-05-20 13:21:30,525::evaluate::INFO] JSD_6-6|4: None
    [2024-05-20 13:21:30,525::evaluate::INFO] JSD_6-6|1: None
    [2024-05-20 13:21:30,525::evaluate::INFO] JSD_6-8|1: None
    [2024-05-20 13:21:30,525::evaluate::INFO] JSD_6-7|1: None
    [2024-05-20 13:21:30,525::evaluate::INFO] JSD_6-8|2: None
    [2024-05-20 13:21:30,525::evaluate::INFO] JSD_6-6|2: None
    [2024-05-20 13:21:30,525::evaluate::INFO] JSD_6-7|4: None
    [2024-05-20 13:21:30,525::evaluate::INFO] JSD_6-7|2: None
    /root/autodl-tmp/test/IPDiff-main/utils/evaluation/eval_bond_length.py:29: RuntimeWarning: invalid value encountered in true_divide
    bin_counts = np.array(bin_counts) / np.sum(bin_counts)
    [2024-05-20 13:21:30,526::evaluate::INFO] JSD_CC_2A: nan
    [2024-05-20 13:21:30,526::evaluate::INFO] JSD_All_12A: nan
    Traceback (most recent call last):
      File "eval_split.py", line 181, in <module>
        atom_type_js = eval_atom_type.eval_atom_type_distribution(success_atom_types)
      File "/root/autodl-tmp/test/IPDiff-main/utils/evaluation/eval_atom_type.py", line 30, in eval_atom_type_distribution
        pred_atom_distribution[k] = pred_counter[k] / total_num_atoms
    ZeroDivisionError: division by zero

I would like to ask whether the code you provided is the same as the code that produced the results in the paper. Looking forward to your answer!
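The ZeroDivisionError at the end of this log looks like a downstream symptom rather than a separate bug: eval_success and complete are both 0.0000, so no molecule reaches the atom-type statistics and total_num_atoms ends up as 0. A minimal sketch of a zero-guarded distribution follows. The function name is hypothetical, it assumes the input is a flat iterable of atomic numbers, and it is not the repo's eval_atom_type_distribution.

```python
from collections import Counter
from typing import Dict, Iterable, Optional


def atom_type_distribution(atom_types: Iterable[int]) -> Optional[Dict[int, float]]:
    """Normalized atom-type frequencies, or None when there are no atoms at all.

    Hypothetical helper: mirrors the pred_counter[k] / total_num_atoms step from
    the traceback, but returns None instead of dividing by zero.
    """
    counter = Counter(atom_types)
    total = sum(counter.values())
    if total == 0:
        # Nothing survived evaluation (e.g. complete: 0.0000), so there is no distribution.
        return None
    return {k: n / total for k, n in counter.items()}
```

Returning None rather than an all-zero dict keeps an empty evaluation from being mistaken for a real, if poor, distribution.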

@SeungbeomLee
Did you succeed in reproducing it?

Hi,

1 & 2) I re-trained the model with the code provided in this repo, without loading the pretrained IPDiff. The training log for the first 5k iterations is below:

[2024-05-20 14:27:07,079::train::INFO] [Train] Iter 200 | Loss 0.861447 (pos 0.680345 | v 0.001811) | Lr: 0.000500 | Grad Norm: 25.784241
[2024-05-20 14:32:57,606::train::INFO] [Train] Iter 400 | Loss 1.099283 (pos 0.945749 | v 0.001535) | Lr: 0.000500 | Grad Norm: 35.337120
[2024-05-20 14:38:37,370::train::INFO] [Train] Iter 600 | Loss 1.600580 (pos 1.499844 | v 0.001007) | Lr: 0.000500 | Grad Norm: 36.945099
[2024-05-20 14:43:47,960::train::INFO] [Train] Iter 800 | Loss 0.394610 (pos 0.276877 | v 0.001177) | Lr: 0.000500 | Grad Norm: 6.206097
[2024-05-20 14:49:15,411::train::INFO] [Train] Iter 1000 | Loss 0.672863 (pos 0.618365 | v 0.000545) | Lr: 0.000500 | Grad Norm: 39.251366
[2024-05-20 14:54:47,547::train::INFO] [Train] Iter 1200 | Loss 0.899528 (pos 0.825398 | v 0.000741) | Lr: 0.000500 | Grad Norm: 34.874828
[2024-05-20 14:59:42,994::train::INFO] [Train] Iter 1400 | Loss 0.775737 (pos 0.621840 | v 0.001539) | Lr: 0.000500 | Grad Norm: 21.123680
[2024-05-20 15:05:39,927::train::INFO] [Train] Iter 1600 | Loss 0.822615 (pos 0.742547 | v 0.000801) | Lr: 0.000500 | Grad Norm: 23.954535
[2024-05-20 15:11:48,410::train::INFO] [Train] Iter 1800 | Loss 0.430038 (pos 0.391938 | v 0.000381) | Lr: 0.000500 | Grad Norm: 9.278733
[2024-05-20 15:16:37,886::train::INFO] [Train] Iter 2000 | Loss 0.289589 (pos 0.254304 | v 0.000353) | Lr: 0.000500 | Grad Norm: 7.107880
[2024-05-20 15:22:40,041::train::INFO] [Train] Iter 2200 | Loss 1.123865 (pos 1.032209 | v 0.000917) | Lr: 0.000500 | Grad Norm: 18.195499
[2024-05-20 15:28:09,404::train::INFO] [Train] Iter 2400 | Loss 0.671753 (pos 0.536350 | v 0.001354) | Lr: 0.000500 | Grad Norm: 9.016789
[2024-05-20 15:32:43,111::train::INFO] [Train] Iter 2600 | Loss 0.705554 (pos 0.589482 | v 0.001161) | Lr: 0.000500 | Grad Norm: 16.378878
[2024-05-20 15:37:55,351::train::INFO] [Train] Iter 2800 | Loss 1.046879 (pos 0.925731 | v 0.001211) | Lr: 0.000500 | Grad Norm: 39.363972
[2024-05-20 15:44:15,764::train::INFO] [Train] Iter 3000 | Loss 0.459049 (pos 0.387101 | v 0.000719) | Lr: 0.000500 | Grad Norm: 11.292586
[2024-05-20 15:49:20,947::train::INFO] [Train] Iter 3200 | Loss 1.288119 (pos 1.209991 | v 0.000781) | Lr: 0.000500 | Grad Norm: 33.929810
[2024-05-20 15:54:35,396::train::INFO] [Train] Iter 3400 | Loss 0.333837 (pos 0.286994 | v 0.000468) | Lr: 0.000500 | Grad Norm: 2.009230
[2024-05-20 15:59:09,997::train::INFO] [Train] Iter 3600 | Loss 1.066855 (pos 0.905395 | v 0.001615) | Lr: 0.000500 | Grad Norm: 41.150974
[2024-05-20 16:05:01,452::train::INFO] [Train] Iter 3800 | Loss 0.698866 (pos 0.562015 | v 0.001369) | Lr: 0.000500 | Grad Norm: 54.497398
[2024-05-20 16:09:26,520::train::INFO] [Train] Iter 4000 | Loss 0.922557 (pos 0.819029 | v 0.001035) | Lr: 0.000500 | Grad Norm: 40.375744
[2024-05-20 16:13:58,687::train::INFO] [Train] Iter 4200 | Loss 1.498024 (pos 1.418779 | v 0.000792) | Lr: 0.000500 | Grad Norm: 51.309418
[2024-05-20 16:19:04,225::train::INFO] [Train] Iter 4400 | Loss 1.024247 (pos 0.977453 | v 0.000468) | Lr: 0.000500 | Grad Norm: 8.381104
[2024-05-20 16:24:12,817::train::INFO] [Train] Iter 4600 | Loss 1.225165 (pos 1.155714 | v 0.000695) | Lr: 0.000500 | Grad Norm: 150.647873
[2024-05-20 16:29:04,562::train::INFO] [Train] Iter 4800 | Loss 1.230983 (pos 1.128690 | v 0.001023) | Lr: 0.000500 | Grad Norm: 32.593872
[2024-05-20 16:34:33,317::train::INFO] [Train] Iter 5000 | Loss 0.501395 (pos 0.289580 | v 0.002118) | Lr: 0.000500 | Grad Norm: 46.966637
[2024-05-20 16:36:59,452::train::INFO] [Validate] Iter 05000 | Loss 0.999039 | Loss pos 0.934050 | Loss v 0.649892 e-3 | Avg atom auroc 0.920940
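The logged totals appear to decompose as Loss ≈ pos + loss_v_weight * v, with loss_v_weight = 100.0 taken from the training config printed later in this thread. This relation is inferred from the log values, not read from the source code. A small check on two lines of the log above:

```python
import math

# Assumed from the logged config ('loss_v_weight': 100.0); not read from the repo code.
LOSS_V_WEIGHT = 100.0


def recombined_loss(loss_pos: float, loss_v: float, v_weight: float = LOSS_V_WEIGHT) -> float:
    """Weighted sum the logged totals appear to follow: pos + v_weight * v."""
    return loss_pos + v_weight * loss_v


def consistent(loss: float, loss_pos: float, loss_v: float) -> bool:
    # Log values are printed with 6 decimals, so allow for rounding in the comparison.
    return math.isclose(loss, recombined_loss(loss_pos, loss_v), rel_tol=1e-6, abs_tol=1e-4)


# Iter 200 (train) and Iter 05000 (validation, where v is printed in units of 1e-3).
print(consistent(0.861447, 0.680345, 0.001811))     # True
print(consistent(0.999039, 0.934050, 0.649892e-3))  # True
```

The same relation holds for the lines of the reproduction log further down, so the large totals there come almost entirely from the pos term.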

It works in my conda environment on my device (3090Ti). Please check your conda environment and make sure the config files in the configs directory are modified correctly.

3) The evaluation results on my device are correct, and I'm not sure what issue you're encountering. Please check the paths to the dataset and the sampled results in training.yml, sampling.yml and eval_split.py, and make sure they are set correctly.

Thanks.
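Since every step above depends on the config paths resolving correctly, a quick pre-flight check can confirm that each path-valued entry exists before training or sampling starts. This is a minimal sketch: the dotted keys are the ones visible in the training config logged in this thread, the helper names are hypothetical, and it expects an already-parsed config dict (for example from PyYAML's yaml.safe_load). Relative paths such as ./datasets/crossdocked_v1.1_rmsd1.0 are resolved against the current working directory, so run it from the repository root.

```python
import os
from typing import Any, Dict, List, Tuple

# Path-valued keys seen in the logged training config (assumed list, extend as needed).
PATH_KEYS = ("data.path", "data.split", "net_cond.ckpt_path")


def lookup(cfg: Dict[str, Any], dotted: str) -> Any:
    """Follow a dotted key such as 'data.path' through nested dicts; None if absent."""
    node: Any = cfg
    for part in dotted.split("."):
        if not isinstance(node, dict) or part not in node:
            return None
        node = node[part]
    return node


def missing_paths(cfg: Dict[str, Any], keys=PATH_KEYS) -> List[Tuple[str, Any]]:
    """Return (key, value) pairs whose value is absent or does not exist on disk."""
    problems = []
    for key in keys:
        value = lookup(cfg, key)
        if not isinstance(value, str) or not os.path.exists(value):
            problems.append((key, value))
    return problems
```

An empty result means every listed path resolved; anything else names the exact key to fix in the YAML file.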

@ZerinHwang03
Thanks for your response. I checked the conda environment, and all but a few package versions match. Here is my training log:
root@autodl-container-4d6411b93c-86c0c4e8:~/autodl-tmp/test/IPDiff-main# python train.py
[2024-05-21 11:43:01,452::train::INFO] Namespace(config='./configs/training.yml', device='cuda', logdir='./logs', tag='', train_report_iter=100)
[2024-05-21 11:43:01,453::train::INFO] {'data': {'name': 'pl', 'path': './datasets/crossdocked_v1.1_rmsd1.0', 'split': './datasets/crossdocked_pocket10_pose_split.pt', 'transform': {'ligand_atom_mode': 'add_aromatic', 'random_rot': False}}, 'net_cond': {'ckpt_path': './pretrained_models/ipnet', 'hidden_dim': 128}, 'model': {'cond_dim': 128, 'model_mean_type': 'C0', 'beta_schedule': 'sigmoid', 'beta_start': 1e-07, 'beta_end': 0.002, 'v_beta_schedule': 'cosine', 'v_beta_s': 0.01, 'num_diffusion_timesteps': 1000, 'loss_v_weight': 100.0, 'sample_time_method': 'symmetric', 'time_emb_dim': 0, 'time_emb_mode': 'simple', 'center_pos_mode': 'protein', 'node_indicator': True, 'model_type': 'uni_o2', 'num_blocks': 1, 'num_layers': 9, 'hidden_dim': 128, 'n_heads': 16, 'edge_feat_dim': 4, 'num_r_gaussian': 20, 'knn': 32, 'num_node_types': 8, 'act_fn': 'relu', 'norm': True, 'cutoff_mode': 'knn', 'ew_net_type': 'global', 'num_x2h': 1, 'num_h2x': 1, 'r_max': 10.0, 'x2h_out_fc': False, 'sync_twoup': False}, 'train': {'seed': 2021, 'batch_size': 4, 'num_workers': 4, 'n_acc_batch': 1, 'max_iters': 1000000, 'val_freq': 1000, 'pos_noise_std': 0.1, 'max_grad_norm': 8.0, 'bond_loss_weight': 1.0, 'optimizer': {'type': 'adam', 'lr': 0.0005, 'weight_decay': 0, 'beta1': 0.95, 'beta2': 0.999}, 'scheduler': {'type': 'plateau', 'factor': 0.6, 'patience': 10, 'min_lr': 1e-06}}}
[2024-05-21 11:43:01,454::train::INFO] Loading dataset...
[2024-05-21 11:43:01,473::train::INFO] Training: 99990 Validation: 100
[2024-05-21 11:43:01,474::train::INFO] Building model...
Restored from ./pretrained_models/ipnet with 1 missing and 2 unexpected keys
Missing Keys: ['FusionGraph.0.lin.weight']
Unexpected Keys: ['FusionGraph.0.lin_src.weight', 'FusionGraph.0.lin_dst.weight']
protein feature dim: 27 ligand feature dim: 13
[2024-05-21 11:43:07,004::train::INFO] # trainable parameters: 2.8576 M
[2024-05-21 11:43:38,472::train::INFO] [Train] Iter 100 | Loss 96.282211 (pos 95.670990 | v 0.006112) | Lr: 0.000500 | Grad Norm: 358.884033
[2024-05-21 11:44:10,095::train::INFO] [Train] Iter 200 | Loss 138.689331 (pos 138.219299 | v 0.004700) | Lr: 0.000500 | Grad Norm: 1438.796509
[2024-05-21 11:44:40,733::train::INFO] [Train] Iter 300 | Loss 41.616707 (pos 41.034042 | v 0.005827) | Lr: 0.000500 | Grad Norm: 20765.884766
[2024-05-21 11:45:10,977::train::INFO] [Train] Iter 400 | Loss 464.610779 (pos 464.168579 | v 0.004422) | Lr: 0.000500 | Grad Norm: 107119.273438
[2024-05-21 11:45:42,385::train::INFO] [Train] Iter 500 | Loss 255.827133 (pos 255.262543 | v 0.005646) | Lr: 0.000500 | Grad Norm: 187880.984375
[2024-05-21 11:46:13,614::train::INFO] [Train] Iter 600 | Loss 182.836578 (pos 182.240417 | v 0.005962) | Lr: 0.000500 | Grad Norm: 71160.578125
[2024-05-21 11:46:43,976::train::INFO] [Train] Iter 700 | Loss 77.955620 (pos 77.129745 | v 0.008259) | Lr: 0.000500 | Grad Norm: 686.155029
[2024-05-21 11:47:14,988::train::INFO] [Train] Iter 800 | Loss 84.753128 (pos 84.115433 | v 0.006377) | Lr: 0.000500 | Grad Norm: 18564.908203
[2024-05-21 11:47:45,911::train::INFO] [Train] Iter 900 | Loss 1084.949341 (pos 1084.183472 | v 0.007659) | Lr: 0.000500 | Grad Norm: 1195725.250000
[2024-05-21 11:48:17,372::train::INFO] [Train] Iter 1000 | Loss 3077.959473 (pos 3077.344971 | v 0.006146) | Lr: 0.000500 | Grad Norm: 2227307.000000
Validate: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:30<00:00, 1.22s/it]
atom: (6, False) auc roc: 0.5000
atom: (6, True) auc roc: 0.5000
atom: (7, False) auc roc: 0.5000
atom: (7, True) auc roc: 0.5000
atom: (8, False) auc roc: 0.5000
atom: (8, True) auc roc: 0.5000
atom: (9, False) auc roc: 0.5000
atom: (15, False) auc roc: 0.5000
atom: (16, False) auc roc: 0.5000
atom: (16, True) auc roc: 0.5000
atom: (17, False) auc roc: 0.5000
[2024-05-21 11:48:47,857::train::INFO] [Validate] Iter 01000 | Loss 171204.422686 | Loss pos 171199.732366 | Loss v 46.895633 e-3 | Avg atom auroc 0.500000
[2024-05-21 11:48:47,858::train::INFO] [Validate] Best val loss achieved: 171204.422686
[2024-05-21 11:49:18,895::train::INFO] [Train] Iter 1100 | Loss 125.158424 (pos 124.654358 | v 0.005041) | Lr: 0.000500 | Grad Norm: 17831.619141
[2024-05-21 11:49:48,922::train::INFO] [Train] Iter 1200 | Loss 46.375423 (pos 45.016602 | v 0.013588) | Lr: 0.000500 | Grad Norm: 88107.304688
[2024-05-21 11:50:19,333::train::INFO] [Train] Iter 1300 | Loss 134.130661 (pos 133.674942 | v 0.004557) | Lr: 0.000500 | Grad Norm: 208998.437500
[2024-05-21 11:50:49,612::train::INFO] [Train] Iter 1400 | Loss 56.073097 (pos 55.330811 | v 0.007423) | Lr: 0.000500 | Grad Norm: 65657.328125
[2024-05-21 11:51:20,599::train::INFO] [Train] Iter 1500 | Loss 62.850895 (pos 62.448208 | v 0.004027) | Lr: 0.000500 | Grad Norm: 2004.350098
[2024-05-21 11:51:52,601::train::INFO] [Train] Iter 1600 | Loss 1281.011108 (pos 1280.392578 | v 0.006185) | Lr: 0.000500 | Grad Norm: 749706.625000
[2024-05-21 11:52:24,549::train::INFO] [Train] Iter 1700 | Loss 83.051231 (pos 82.621483 | v 0.004297) | Lr: 0.000500 | Grad Norm: 2012.388184
[2024-05-21 11:52:55,328::train::INFO] [Train] Iter 1800 | Loss 48.049915 (pos 47.353981 | v 0.006959) | Lr: 0.000500 | Grad Norm: 16269.761719
[2024-05-21 11:53:25,960::train::INFO] [Train] Iter 1900 | Loss 111.246574 (pos 108.708702 | v 0.025379) | Lr: 0.000500 | Grad Norm: 8672.382812
[2024-05-21 11:53:56,076::train::INFO] [Train] Iter 2000 | Loss 11823.634766 (pos 11822.794922 | v 0.008396) | Lr: 0.000500 | Grad Norm: 5034614.500000
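Compared with the reference log earlier in the thread, this run is clearly diverging: within 2,000 iterations the loss reaches the thousands and the grad norm the millions. To spot this early when comparing runs, the [Train] lines can be parsed and checked against a threshold. A minimal sketch; the regex follows the line format shown in this thread, the helper names are hypothetical, and the thresholds are assumptions chosen to sit well above the reference run (loss below 2 and grad norm at most about 151 over its first 5k iterations).

```python
import math
import re
from typing import Iterable, List, NamedTuple

NUM = r"([-+.\w]+)"  # also accepts tokens like 'nan' or 'inf'
TRAIN_RE = re.compile(
    rf"\[Train\] Iter (\d+) \| Loss {NUM} \(pos {NUM} \| v {NUM}\)"
    rf" \| Lr: {NUM} \| Grad Norm: {NUM}"
)


class TrainStep(NamedTuple):
    it: int
    loss: float
    pos: float
    v: float
    grad_norm: float


def parse_train_log(lines: Iterable[str]) -> List[TrainStep]:
    """Extract [Train] report lines; other lines (validation, setup) are skipped."""
    steps = []
    for line in lines:
        m = TRAIN_RE.search(line)
        if m:
            it, loss, pos, v, _lr, gn = m.groups()
            steps.append(TrainStep(int(it), float(loss), float(pos), float(v), float(gn)))
    return steps


def looks_diverged(steps: Iterable[TrainStep], loss_limit: float = 10.0,
                   grad_limit: float = 1000.0) -> bool:
    """True if any step has a non-finite value or exceeds the (assumed) limits."""
    for s in steps:
        if not (math.isfinite(s.loss) and math.isfinite(s.grad_norm)):
            return True
        if s.loss > loss_limit or s.grad_norm > grad_limit:
            return True
    return False
```

Non-finite values count as divergence explicitly, since a comparison like nan > 10 is always False and would otherwise hide them.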

Restored from ./pretrained_models/ipnet with 1 missing and 2 unexpected keys
Missing Keys: ['FusionGraph.0.lin.weight']
Unexpected Keys: ['FusionGraph.0.lin_src.weight', 'FusionGraph.0.lin_dst.weight']

Hi,

It seems that IPNet was not loaded correctly (note the missing and unexpected keys above); please check the file graphbap/bapnet.py.

For comparison, here is my training log, where IPNet is restored with 0 missing and 0 unexpected keys:

[2024-05-20 14:19:40,787::train::INFO] Namespace(config='/home/huangzl/Data/workspace/IPDiff-main/configs/training.yml', device='cuda', logdir='/home/huangzl/Data/workspace/IPDiff-main/logs', tag='', train_report_iter=200)
[2024-05-20 14:19:40,787::train::INFO] {'data': {'name': 'pl', 'path': '/home/huangzl/Data/datasets/molecule/targetdiff/crossdocked_v1.1_rmsd1.0', 'split': '/home/huangzl/Data/datasets/molecule/targetdiff/crossdocked_pocket10_pose_split.pt', 'transform': {'ligand_atom_mode': 'add_aromatic', 'random_rot': False}}, 'net_cond': {'ckpt_path': '/home/huangzl/Data/workspace/IPDiff-main/pretrained_models/ipnet', 'hidden_dim': 128}, 'model': {'cond_dim': 128, 'model_mean_type': 'C0', 'beta_schedule': 'sigmoid', 'beta_start': 1e-07, 'beta_end': 0.002, 'v_beta_schedule': 'cosine', 'v_beta_s': 0.01, 'num_diffusion_timesteps': 1000, 'loss_v_weight': 100.0, 'sample_time_method': 'symmetric', 'time_emb_dim': 0, 'time_emb_mode': 'simple', 'center_pos_mode': 'protein', 'node_indicator': True, 'model_type': 'uni_o2', 'num_blocks': 1, 'num_layers': 9, 'hidden_dim': 128, 'n_heads': 16, 'edge_feat_dim': 4, 'num_r_gaussian': 20, 'knn': 32, 'num_node_types': 8, 'act_fn': 'relu', 'norm': True, 'cutoff_mode': 'knn', 'ew_net_type': 'global', 'num_x2h': 1, 'num_h2x': 1, 'r_max': 10.0, 'x2h_out_fc': False, 'sync_twoup': False}, 'train': {'seed': 2021, 'batch_size': 4, 'num_workers': 4, 'n_acc_batch': 1, 'max_iters': 1000000, 'val_freq': 5000, 'pos_noise_std': 0.1, 'max_grad_norm': 8.0, 'bond_loss_weight': 1.0, 'optimizer': {'type': 'adam', 'lr': 0.0005, 'weight_decay': 0, 'beta1': 0.95, 'beta2': 0.999}, 'scheduler': {'type': 'plateau', 'factor': 0.6, 'patience': 10, 'min_lr': 1e-06}}}
[2024-05-20 14:19:40,791::train::INFO] Loading dataset...
[2024-05-20 14:19:40,938::train::INFO] Training: 99990 Validation: 100
[2024-05-20 14:19:40,940::train::INFO] Building model...
[2024-05-20 14:20:25,046::train::INFO] # trainable parameters: 2.8576 M
Restored from /home/huangzl/Data/workspace/IPDiff-main/pretrained_models/ipnet with 0 missing and 0 unexpected keys
protein feature dim: 27 ligand feature dim: 13

Thanks.
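The warning in the reproduction log (1 missing, 2 unexpected keys) did not stop training, so presumably part of IPNet kept its initial weights. A stricter restore fails fast instead. A minimal sketch, assuming the checkpoint is loaded with PyTorch's load_state_dict(..., strict=False), whose return value exposes missing_keys and unexpected_keys; the helper names are hypothetical and not part of graphbap/bapnet.py. Passing strict=True to load_state_dict raises on the same mismatch by itself, so this only matters if non-strict loading is needed for other reasons.

```python
from typing import Iterable, Optional


def restore_report(missing_keys: Iterable[str], unexpected_keys: Iterable[str]) -> Optional[str]:
    """Describe a key mismatch between checkpoint and model, or return None if they match."""
    missing, unexpected = sorted(missing_keys), sorted(unexpected_keys)
    if not missing and not unexpected:
        return None
    return f"missing={missing}, unexpected={unexpected}"


def check_restore(missing_keys: Iterable[str], unexpected_keys: Iterable[str], ckpt_path: str) -> None:
    # Typical call site (assumption about how the checkpoint is loaded):
    #   result = net.load_state_dict(state_dict, strict=False)
    #   check_restore(result.missing_keys, result.unexpected_keys, ckpt_path)
    report = restore_report(missing_keys, unexpected_keys)
    if report is not None:
        raise RuntimeError(f"Checkpoint {ckpt_path} does not match the model: {report}")
```

With the keys from the failing run, this raises before the first training step instead of after thousands of diverging iterations.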

@ZerinHwang03 Thanks for your answer. I think it might be a problem with the torch-geometric version. Could you tell me which version you use? I didn't find it in ipdiff.yml. Thank you!

I'm using torch-geometric 2.0.4:

torch-geometric in ./miniconda3/envs/avina/lib/python3.8/site-packages (2.0.4)

Thanks, the problem was caused by a different torch-geometric version. The model now trains normally!
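Because ipdiff.yml does not pin torch-geometric, a pre-flight version check can catch this mismatch before any training time is spent. A minimal sketch using the standard-library importlib.metadata (Python 3.8+); 2.0.4 is the version reported in this thread as working, and the function names are hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional

# Version reported in this thread as working; ipdiff.yml itself does not pin it.
EXPECTED_PYG = "2.0.4"


def installed_version(*dist_names: str) -> Optional[str]:
    """Return the version of the first installed distribution among dist_names, else None."""
    for name in dist_names:
        try:
            return version(name)
        except PackageNotFoundError:
            continue
    return None


def pyg_matches(expected: str = EXPECTED_PYG) -> bool:
    # Try both spellings, since some importlib.metadata versions may not normalise '-' vs '_'.
    found = installed_version("torch-geometric", "torch_geometric")
    if found != expected:
        shown = found or "not installed"
        print(f"torch-geometric: {shown}; this thread reports {expected} as working")
    return found == expected
```

Calling pyg_matches() at the top of train.py or in a setup script turns a silent weight mismatch into an explicit message.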