DSL-Lab / SwinGNN

SwinGNN: diffusion model for graph generation

SwinGNN: Rethinking Permutation Invariance in Diffusion Models for Graph Generation

This repository contains the official PyTorch implementation of the SwinGNN model. The SwinGNN paper was accepted to TMLR (June 2024).

Pointers: arXiv | code

Sampling processes of learned models on grid, protein and molecule datasets

Get started

Install dependencies

# option 1: python 3.8 venv
python -m venv venvscorenet
source venvscorenet/bin/activate
pip install -U pip
pip install cython==0.29.32 
pip install pomegranate 
pip install -r setup/requirements.txt

# option 2: conda
conda create -n scorenet python=3.8 
conda activate scorenet 
pip install -U pip 
pip install cython==0.29.32 
pip install pomegranate 
pip install -r setup/requirements.txt

# compile ORCA for orbit statistics evaluation
export PROJ_DIR=$(pwd)
cd evaluation/orca && g++ -O2 -std=c++11 -o orca orca.cpp && cd "$PROJ_DIR"
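The compiled ORCA binary counts graphlet orbits and reads a plain-text edge list. The sketch below shows one way to prepare that input from Python and call the binary. It assumes the usual ORCA command-line form `orca node 4 <input> <output>`, an input file whose first line is `<num_nodes> <num_edges>` followed by one 0-indexed `u v` pair per line, and that passing `std` as the output makes ORCA print to stdout. `to_orca_input` and `run_orca` are hypothetical helpers, not functions from this repository.

```python
import subprocess
import tempfile
from pathlib import Path


def to_orca_input(num_nodes, edges):
    """Format an undirected graph as ORCA input text (hypothetical helper).

    Assumed format: a header "<num_nodes> <num_edges>", then one "u v" line
    per edge with 0-indexed node ids. Self-loops and duplicate edges are
    dropped because ORCA is assumed to reject them.
    """
    unique = sorted({(min(u, v), max(u, v)) for u, v in edges if u != v})
    lines = [f"{num_nodes} {len(unique)}"]
    lines += [f"{u} {v}" for u, v in unique]
    return "\n".join(lines) + "\n"


def run_orca(num_nodes, edges, orca_path="evaluation/orca/orca"):
    """Count 4-node graphlet orbits per node with the compiled binary (assumed CLI)."""
    with tempfile.TemporaryDirectory() as tmp:
        in_file = Path(tmp) / "graph.in"
        in_file.write_text(to_orca_input(num_nodes, edges))
        # "std" is assumed to make ORCA print the orbit counts to stdout.
        result = subprocess.run(
            [orca_path, "node", "4", str(in_file), "std"],
            capture_output=True, text=True, check=True,
        )
    return result.stdout


if __name__ == "__main__":
    # A triangle plus a duplicate edge and a self-loop, which are filtered out.
    print(to_orca_input(3, [(0, 1), (1, 2), (2, 0), (1, 0), (2, 2)]), end="")
```

Only `to_orca_input` runs without the binary; `run_orca` needs the compile step above to have succeeded.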

Install cython==0.29.32 and pomegranate first, as in the commands above, and only then run pip install -r setup/requirements.txt. Installing them in this order avoids build errors in the molsets (moses) package, which is required to compute the molecule metrics. Note that the original molsets (moses) package at https://github.com/molecularsets/moses is no longer actively maintained.
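Because the install order matters, it can help to confirm that the prerequisites are present before installing the rest of the requirements. The sketch below uses the standard-library `importlib.metadata` module; the helper name `missing_prerequisites` is hypothetical and not part of this repository, and the version pins are the ones from the commands above.

```python
from importlib import metadata


def missing_prerequisites(required):
    """Return a list of problems for packages that must be installed first.

    `required` maps a distribution name to an exact version string, or to
    None when any installed version is acceptable.
    """
    problems = []
    for name, wanted in required.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            problems.append(f"{name} is not installed")
            continue
        if wanted is not None and installed != wanted:
            problems.append(f"{name} {installed} installed, expected {wanted}")
    return problems


if __name__ == "__main__":
    # Pins taken from the install commands above.
    issues = missing_prerequisites({"Cython": "0.29.32", "pomegranate": None})
    if issues:
        print("Install these before setup/requirements.txt:", "; ".join(issues))
    else:
        print("Prerequisites found; safe to run pip install -r setup/requirements.txt")
```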

Setup datasets

# prepare datasets
python setup/gen_graph_data.py  # prepare various synthetic and real-world graph datasets
python setup/mol_preprocess.py --dataset ZINC250k  # prepare ZINC250k dataset
python setup/mol_preprocess.py --dataset QM9  # prepare QM9 dataset
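For intuition about what a synthetic dataset such as grid contains, the sketch below builds the dense adjacency matrix of an r×c 2D grid graph and optionally zero-pads it so that graphs of different sizes share one matrix shape. The helper `grid_adjacency` and the `max_nodes` padding parameter are hypothetical illustrations; the actual preprocessing lives in setup/gen_graph_data.py and may differ.

```python
import numpy as np


def grid_adjacency(rows, cols, max_nodes=None):
    """Dense adjacency matrix of a rows x cols 2D grid graph (hypothetical helper).

    Node (r, c) gets index r * cols + c. If max_nodes is given, the matrix is
    zero-padded to max_nodes x max_nodes so graphs of different sizes can be batched.
    """
    n = rows * cols
    size = n if max_nodes is None else max_nodes
    if size < n:
        raise ValueError("max_nodes is smaller than the number of grid nodes")
    adj = np.zeros((size, size), dtype=np.float32)
    for r in range(rows):
        for c in range(cols):
            i = r * cols + c
            if c + 1 < cols:  # edge to the right neighbour
                adj[i, i + 1] = adj[i + 1, i] = 1.0
            if r + 1 < rows:  # edge to the neighbour below
                adj[i, i + cols] = adj[i + cols, i] = 1.0
    return adj


if __name__ == "__main__":
    a = grid_adjacency(3, 3, max_nodes=12)
    print(a.shape, int(a.sum()) // 2)  # matrix shape and number of undirected edges
```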

Training command

Below we provide the training commands for SwinGNN on graph datasets and molecule datasets. Please refer to config/edm_swin_gnn for more training configurations.

# training cmds on graph dataset (without node/edge attributes), e.g., to train on grid dataset
python train.py -c config/edm_swin_gnn/grid_edm_swin_gnn_80.yaml --batch_size 10 -m=grid

# our code also supports DDP training
export NUM_GPUS=4
torchrun --nproc_per_node=$NUM_GPUS train.py -c config/edm_swin_gnn/grid_edm_swin_gnn_80.yaml --batch_size 40 --ddp -m=grid_ddp
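torchrun launches one process per GPU and sets environment variables such as WORLD_SIZE and LOCAL_RANK for each process. Comparing the single-GPU command (--batch_size 10) with the 4-GPU command (--batch_size 40) suggests that --batch_size is the global batch size, split evenly across processes; that reading is an assumption, not something stated here. The sketch shows the arithmetic with a hypothetical helper.

```python
import os


def per_process_batch_size(global_batch_size, world_size):
    """Split a global batch evenly across DDP processes (assumed convention)."""
    if global_batch_size % world_size != 0:
        raise ValueError(
            f"batch size {global_batch_size} is not divisible by {world_size} processes"
        )
    return global_batch_size // world_size


if __name__ == "__main__":
    # torchrun sets WORLD_SIZE and LOCAL_RANK; fall back to a single process otherwise.
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    print(f"GPU {local_rank}: {per_process_batch_size(40, world_size)} graphs per step")
```

Under this reading, scaling NUM_GPUS while keeping the per-GPU batch fixed means multiplying --batch_size by the same factor.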

# training cmds on molecule dataset (with node/edge attributes), e.g., to train on QM9 dataset
torchrun --nproc_per_node=$NUM_GPUS train.py -c config/edm_swin_gnn/qm9_edm_swin_gnn.yaml --feature_dims 60 --node_encoding one_hot --edge_encoding one_hot --batch_size 10240 --ddp -m qm9
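The flags --node_encoding one_hot and --edge_encoding one_hot request one-hot encoding of the categorical node and edge attributes. The sketch below illustrates the idea on a small molecule with dense arrays: each atom becomes a one-hot row over an assumed QM9 heavy-atom vocabulary (C, N, O, F), and each atom pair becomes a one-hot vector over {no bond, single, double, triple}. The vocabularies, the use of class 0 for "no bond", and the helper `one_hot_molecule` are assumptions for illustration, not the repository's exact encoding.

```python
import numpy as np

ATOM_TYPES = ["C", "N", "O", "F"]                    # assumed heavy-atom vocabulary
BOND_TYPES = ["none", "single", "double", "triple"]  # class 0 = no edge (assumption)


def one_hot_molecule(atoms, bonds):
    """Return one-hot node features (n, 4) and edge features (n, n, 4).

    atoms: list of element symbols, e.g. ["C", "C", "O"].
    bonds: list of (i, j, bond_type) tuples with bond_type in BOND_TYPES[1:].
    """
    n = len(atoms)
    x = np.zeros((n, len(ATOM_TYPES)), dtype=np.float32)
    for i, symbol in enumerate(atoms):
        x[i, ATOM_TYPES.index(symbol)] = 1.0

    e = np.zeros((n, n, len(BOND_TYPES)), dtype=np.float32)
    e[:, :, 0] = 1.0  # start with "no bond" for every atom pair
    for i, j, bond in bonds:
        k = BOND_TYPES.index(bond)
        for a, b in ((i, j), (j, i)):  # bonds are undirected
            e[a, b, :] = 0.0
            e[a, b, k] = 1.0
    return x, e


if __name__ == "__main__":
    # Heavy atoms of acetaldehyde: C-C=O.
    x, e = one_hot_molecule(["C", "C", "O"], [(0, 1, "single"), (1, 2, "double")])
    print(x.shape, e.shape)
    print(e.argmax(-1))  # bond class for each atom pair
```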

Testing command

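We release the checkpoints on Google Drive and OneDrive. Below we provide the sampling commands for SwinGNN on graph datasets and molecule datasets; the paths assume the checkpoints have been downloaded into a swinGNN-checkpoints/ directory in the project root.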

# ego-small
python eval.py -p swinGNN-checkpoints/ego_small/ego_small_dim_60/ego_small_dim_60.pth --use_ema 0.9 -m eval_ego_small 
python eval.py -p swinGNN-checkpoints/ego_small/ego_small_dim_96/ego_small_dim_96.pth --use_ema 0.99 -m eval_ego_small 

# community-small
python eval.py -p swinGNN-checkpoints/com_small/community_small_dim_60/community_small_dim_60.pth --use_ema 0.99 -m eval_com_small 
python eval.py -p swinGNN-checkpoints/com_small/community_small_dim_96/community_small_dim_96.pth --use_ema 0.95 -m eval_com_small

# grid
python eval.py -p swinGNN-checkpoints/grid/grid_dim_60/grid_dim_60.pth --use_ema 0.99 -m eval_grid 
python eval.py -p swinGNN-checkpoints/grid/grid_dim_96/grid_dim_96.pth --use_ema 0.95 -m eval_grid 

# dd-protein
python eval.py -p swinGNN-checkpoints/dd_protein/dd_dim_60/dd_dim_60.pth --use_ema 0.9999 -m eval_dd_protein
python eval.py -p swinGNN-checkpoints/dd_protein/dd_dim_96/dd_dim_96.pth --use_ema 0.9999 -m eval_dd_protein 

# qm9
python eval.py -p swinGNN-checkpoints/qm9/qm9_scalar_dim_60/qm9_scalar_dim_60.pth --use_ema 0.9999 -m eval_qm9
python eval.py -p swinGNN-checkpoints/qm9/qm9_scalar_dim_96/qm9_scalar_dim_96.pth --use_ema 0.9999 -m eval_qm9

# zinc250k
python eval.py -p swinGNN-checkpoints/zinc250k/zinc250k_scalar_dim_60/zinc250k_scalar_dim_60.pth --use_ema 0.9999 -m eval_zinc250k
python eval.py -p swinGNN-checkpoints/zinc250k/zinc250k_scalar_dim_96/zinc250k_scalar_dim_96.pth --use_ema 0.9999 -m eval_zinc250k
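The --use_ema values (0.9 to 0.9999) look like decay rates of exponential moving averages (EMA) of the model weights; our reading, which is an assumption, is that each checkpoint stores several EMA copies and the flag selects which one to sample with. An EMA with decay β updates each shadow parameter as θ_ema ← β·θ_ema + (1 - β)·θ after every optimizer step, so a larger β averages over more recent steps, roughly 1/(1 - β) of them. A minimal framework-free sketch over a dict of scalar parameters (helper names are hypothetical):

```python
def ema_update(ema_params, params, decay):
    """One EMA step: ema <- decay * ema + (1 - decay) * param, for each parameter."""
    if not 0.0 <= decay < 1.0:
        raise ValueError("decay must be in [0, 1)")
    return {
        name: decay * ema_params[name] + (1.0 - decay) * value
        for name, value in params.items()
    }


def effective_window(decay):
    """Rough number of recent steps an EMA with this decay averages over."""
    return 1.0 / (1.0 - decay)


if __name__ == "__main__":
    # Hypothetical training trace: one weight that jumps from 0 to 1 and stays there.
    for decay in (0.9, 0.99, 0.9999):
        ema = {"w": 0.0}
        for _ in range(100):
            ema = ema_update(ema, {"w": 1.0}, decay)
        print(f"decay={decay}: EMA w after 100 steps = {ema['w']:.4f}, "
              f"window ~ {effective_window(decay):.0f} steps")
```

The trace shows why the best decay differs per dataset: a high decay such as 0.9999 smooths over many steps and therefore only pays off after long training runs.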

Because sampling is stochastic and hardware differs across machines, your results may differ slightly from those reported in the paper.
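Fixing the random seeds before sampling removes the run-to-run part of this variation on a single machine; hardware differences and nondeterministic GPU kernels can still cause small deviations. The sketch below seeds Python, NumPy and, if installed, PyTorch. The `set_seed` helper is hypothetical; check eval.py for any seed option it may already expose.

```python
import random

import numpy as np

try:
    import torch  # optional: only seeded when PyTorch is installed
except ImportError:
    torch = None


def set_seed(seed):
    """Seed Python, NumPy and (if available) PyTorch RNGs (hypothetical helper)."""
    random.seed(seed)
    np.random.seed(seed)
    if torch is not None:
        torch.manual_seed(seed)  # also seeds CUDA devices in current PyTorch


if __name__ == "__main__":
    set_seed(0)
    first = (random.random(), np.random.rand())
    set_seed(0)
    second = (random.random(), np.random.rand())
    print("reproducible:", first == second)
```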

Citation and Acknowledgements

BibTeX.
If you find our code useful for your research, please cite the paper:

@article{yan2023swingnn,
  title={SwinGNN: Rethinking Permutation Invariance in Diffusion Models for Graph Generation},
  author={Yan, Qi and Liang, Zhengyang and Song, Yang and Liao, Renjie and Wang, Lele},
  journal={arXiv preprint arXiv:2307.01646},
  year={2023}
}

Acknowledgments and Disclosure of Funding.
This work was funded, in part, by NSERC DG Grants (No. RGPIN-2022-04636 and No. RGPIN-2019-05448), the NSERC Collaborative Research and Development Grant (No. CRDPJ 543676-19), the Vector Institute for AI, Canada CIFAR AI Chair, and Oracle Cloud credits. Resources used in preparing this research were provided, in part, by the Province of Ontario, the Government of Canada through the Digital Research Alliance of Canada (alliance.can.ca), and companies sponsoring the Vector Institute (www.vectorinstitute.ai/#partners), Advanced Research Computing at the University of British Columbia, and the Oracle for Research program. Additional hardware support was provided by John R. Evans Leaders Fund CFI grant and the Digital Research Alliance of Canada under the Resource Allocation Competition award.

Contact

Please submit a GitHub issue or contact qi.yan@ece.ubc.ca if you have any questions or find any bugs.

License: MIT License