PatriciaXiao / LG-ODE

LG-ODE

LG-ODE is a framework for learning continuous multi-agent system dynamics from irregularly-sampled partial observations, taking the underlying graph structure into account.

See our NeurIPS 2020 paper, Learning Continuous System Dynamics from Irregularly-Sampled Partial Observations, for more details.

This implementation of LG-ODE is built on the PyTorch Geometric API.

Data Generation

Generate simulated datasets (spring, charged particles) by running:

cd data
python generate_dataset.py 
python generate_dataset.py --num-train 1000 --num-test 200

By default this generates the springs dataset; use --simulation charged to generate the charged particles dataset.
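To give a feel for what the springs simulation models, here is a minimal pure-Python sketch of particles coupled by springs along the edges of an interaction graph. It assumes unit masses, zero-rest-length Hooke's-law springs and a semi-implicit Euler integrator; the function names and parameters (spring_forces, simulate, k, dt) are illustrative assumptions and are not taken from generate_dataset.py.

```python
def spring_forces(positions, edges, k=1.0):
    """Hooke's-law force on each 2-D particle from zero-rest-length springs.

    Assumption: each undirected edge (i, j) is listed once and pulls
    particles i and j toward each other with force -k * (x_i - x_j).
    """
    forces = [[0.0, 0.0] for _ in positions]
    for i, j in edges:
        for d in range(2):
            f = -k * (positions[i][d] - positions[j][d])
            forces[i][d] += f  # force on i from j
            forces[j][d] -= f  # equal and opposite force on j
    return forces


def simulate(positions, velocities, edges, k=1.0, dt=0.01, steps=100):
    """Roll out a trajectory with semi-implicit Euler (unit masses).

    Returns a list of snapshots; each snapshot is a list of (x, y) tuples.
    """
    pos = [list(p) for p in positions]
    vel = [list(v) for v in velocities]
    trajectory = [[tuple(p) for p in pos]]
    for _ in range(steps):
        forces = spring_forces(pos, edges, k)
        for i in range(len(pos)):
            for d in range(2):
                vel[i][d] += dt * forces[i][d]  # update velocity first...
                pos[i][d] += dt * vel[i][d]     # ...then position with the new velocity
        trajectory.append([tuple(p) for p in pos])
    return trajectory


# Three particles; only the pairs 0-1 and 1-2 are connected by springs.
traj = simulate(
    positions=[(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)],
    velocities=[(0.0, 0.0)] * 3,
    edges=[(0, 1), (1, 2)],
)
print(len(traj))  # 101 snapshots: the initial state plus 100 steps
```

Because every spring pulls its two endpoints with equal and opposite forces, the net force over all particles is zero, so the centroid of a system that starts at rest stays fixed throughout the rollout.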

Because the full simulated datasets are too large to include, we provide a small toy dataset sampled from the springs dataset, which can be found under data/example_data.

The motion dataset can be downloaded from CMU MoCap.

Setup

This implementation is based on pytorch_geometric. To run the code, you need the following dependencies:

  • torchdiffeq
pip install torchdiffeq

  • torch_geometric (requires the latest version of PyTorch)
pip install torch_geometric

  • torch_sparse
pip install git+https://github.com/rusty1s/pytorch_sparse.git

  • torch_scatter
pip install torch_scatter

Usage

Run the following script to train on the sampled data from the springs system:

python run_models.py

Key options of this script:

  • --sample-percent-train: The percentage of the training data that is observed.

  • --sample-percent-test: The percentage of the test data that is observed.

  • --solver: The ODE solver to use.

  • --extrap: Set to True to run in extrapolation mode; otherwise the model runs in interpolation mode.
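The options above control how partial, irregular observations are drawn and what the model is asked to predict. The sketch below illustrates one plausible reading of them and is not the repository's code: it assumes the sample percentage is a fraction in [0, 1], that each agent is sampled independently (so different agents are observed at different times), and that extrapolation means conditioning on the first half of the timeline and predicting the second half. The helper names sample_observations and split_for_mode are hypothetical.

```python
import random


def sample_observations(num_agents, num_steps, percent, seed=0):
    """Hypothetical helper: for each agent, keep a random subset of time
    indices. Sampling each agent independently yields irregular,
    per-agent observation times (partial observations)."""
    rng = random.Random(seed)
    keep = max(1, round(percent * num_steps))  # assumes percent is a fraction in [0, 1]
    return [sorted(rng.sample(range(num_steps), keep)) for _ in range(num_agents)]


def split_for_mode(times, num_steps, extrap):
    """Hypothetical split of one agent's observed time indices.

    Interpolation (assumed): condition on the observed points and
    reconstruct those same points.
    Extrapolation (assumed): condition on the points in the first half of
    the timeline and predict the points in the second half.
    """
    if not extrap:
        return times, times
    half = num_steps // 2
    observed = [t for t in times if t < half]
    targets = [t for t in times if t >= half]
    return observed, targets


obs = sample_observations(num_agents=3, num_steps=10, percent=0.4, seed=1)
for agent, times in enumerate(obs):
    cond, target = split_for_mode(times, num_steps=10, extrap=True)
    print(agent, times, cond, target)
```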

Other optional hyperparameters are described in run_models.py.
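The --solver option chooses the numerical integrator used to roll the learned ODE forward in time. LG-ODE depends on torchdiffeq, whose odeint function selects its integrator through a method argument (for example "euler", "rk4" or "dopri5"); we assume --solver is passed through to that argument, so check run_models.py for the values it actually accepts. To show why the choice matters, this dependency-free sketch integrates a single unit spring (x'' = -x, exact solution x(t) = cos t) with fixed-step Euler and fourth-order Runge-Kutta; the function names are illustrative and are not part of torchdiffeq or LG-ODE.

```python
import math


def spring_rhs(state):
    """Right-hand side of x'' = -x written as a first-order system (x, v)."""
    x, v = state
    return (v, -x)


def euler_step(state, dt):
    dx, dv = spring_rhs(state)
    return (state[0] + dt * dx, state[1] + dt * dv)


def rk4_step(state, dt):
    def shifted(k, c):
        # state + c * k, component-wise
        return (state[0] + c * k[0], state[1] + c * k[1])

    k1 = spring_rhs(state)
    k2 = spring_rhs(shifted(k1, dt / 2))
    k3 = spring_rhs(shifted(k2, dt / 2))
    k4 = spring_rhs(shifted(k3, dt))
    return tuple(
        s + dt / 6 * (a + 2 * b + 2 * c + d)
        for s, a, b, c, d in zip(state, k1, k2, k3, k4)
    )


SOLVERS = {"euler": euler_step, "rk4": rk4_step}


def integrate(method, t_end=10.0, dt=0.1):
    """Integrate from x(0) = 1, v(0) = 0 with a fixed step size."""
    step = SOLVERS[method]
    state = (1.0, 0.0)
    for _ in range(round(t_end / dt)):
        state = step(state, dt)
    return state


for method in SOLVERS:
    x, _ = integrate(method)
    print(f"{method}: |x(10) - cos(10)| = {abs(x - math.cos(10.0)):.2e}")
```

With the same step size, Euler's error grows steadily (for this oscillator it even inflates the amplitude), while RK4 stays close to the exact solution at the cost of four right-hand-side evaluations per step. Adaptive solvers such as dopri5 instead adjust the step size to meet error tolerances.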

Citation

Please consider citing the following paper when using our code for your application.

@inproceedings{LG-ODE,
  title={Learning Continuous System Dynamics from Irregularly-Sampled Partial Observations},
  author={Zijie Huang and Yizhou Sun and Wei Wang},
  booktitle={Advances in Neural Information Processing Systems},
  year={2020}
}

Note

Environment problems

If you encounter the following error:

>>> import torch_sparse
libc++abi.dylib: terminating with uncaught exception of type std::length_error: vector

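First make sure you have the latest version of gcc (on macOS, by installing the Xcode command line tools):

xcode-select --install

Then reinstall the PyTorch Geometric packages from the prebuilt wheels (this wheel index targets the CPU build of PyTorch 1.11.0; adjust the URL to match your installed PyTorch and CUDA versions):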

pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.11.0+cpu.html
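Because this error aborts the whole interpreter rather than raising a Python exception, it can help to test each compiled dependency in a separate process. The following is a small standard-library sketch (the helper name check_imports is hypothetical): it runs the import in a fresh interpreter for each module and reports which ones fail, so a crash in one package does not stop the check.

```python
import subprocess
import sys


def check_imports(modules):
    """Try importing each module in a fresh interpreter.

    A native crash (such as an uncaught C++ exception aborting the process)
    only kills the child process, so the caller can still report it.
    Returns a dict mapping module name -> True if the import succeeded.
    """
    results = {}
    for name in modules:
        proc = subprocess.run(
            [sys.executable, "-c", f"import {name}"],
            capture_output=True,
            text=True,
        )
        results[name] = proc.returncode == 0
    return results


if __name__ == "__main__":
    deps = ["torch", "torch_scatter", "torch_sparse", "torch_geometric", "torchdiffeq"]
    for name, ok in check_imports(deps).items():
        print(f"{name}: {'ok' if ok else 'FAILED'}")
```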
