Reimplementation of Neural Relational Inference (NRI), proposed in the following paper: Kipf, Thomas, et al. "Neural relational inference for interacting systems." International Conference on Machine Learning. PMLR, 2018.
Result figures from Neural Relational Inference for Interacting Systems
We recommend using a conda virtual environment. An `environment.yml` file is provided; run the following command to set up the required environment:
conda env create --name recoveredenv --file environment.yml
Next, install the project as a local package (named `src`). The `-e` flag makes the install editable, so changes to the source take effect without reinstalling, and `.` refers to the current folder. This approach takes advantage of Python's package system.
pip install -e .
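To confirm that the editable install worked, you can check where Python resolves the package from; after `pip install -e .` it should point into your working copy rather than a separate copy in `site-packages`. A minimal sketch using the standard library's `importlib.util.find_spec` (the helper name `package_location` is hypothetical and not part of this repository):

```python
import importlib.util
from typing import Optional


def package_location(name: str) -> Optional[str]:
    """Return the directory/file a top-level module would be imported from, or None."""
    spec = importlib.util.find_spec(name)
    if spec is None:
        return None  # not importable in this environment
    # Packages expose their directories; plain modules expose a source file.
    if spec.submodule_search_locations:
        return list(spec.submodule_search_locations)[0]
    return spec.origin


if __name__ == "__main__":
    # Should print a path inside the cloned repository if the editable install worked.
    print(package_location("src"))
```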
Use `scripts/generate_dataset.py` to generate simulation data. You can use simulation data for both training and testing. All model input data is saved in the `data` folder. The `data` folder already contains two `.npy` files, which are provided ONLY for testing.
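For intuition about what simulation data of this kind looks like, here is a toy sketch in the spirit of the "springs" system from the NRI paper: particles connected by a random undirected graph of springs obeying Hooke's law, integrated with semi-implicit Euler and saved with `np.save`. It is not `scripts/generate_dataset.py`; the function name, its parameters, the unit masses, the absence of box boundaries, and the array layout `(time, particle, xy)` are all assumptions for illustration.

```python
import os
import tempfile

import numpy as np


def simulate_springs(n_balls=5, n_steps=1000, sample_every=100, dt=0.001,
                     spring_k=0.1, edge_prob=0.5, seed=0):
    """Toy spring system (hypothetical stand-in for the repo's generator).

    Returns positions and velocities of shape (n_samples, n_balls, 2)
    and a symmetric 0/1 adjacency matrix of shape (n_balls, n_balls).
    """
    rng = np.random.default_rng(seed)
    # Random undirected interaction graph: symmetric, no self-loops.
    upper = np.triu(rng.random((n_balls, n_balls)) < edge_prob, k=1)
    edges = (upper | upper.T).astype(np.int64)

    pos = rng.normal(scale=0.5, size=(n_balls, 2))
    vel = rng.normal(scale=0.5, size=(n_balls, 2))
    positions, velocities = [], []
    for step in range(n_steps):
        if step % sample_every == 0:
            positions.append(pos.copy())
            velocities.append(vel.copy())
        # Hooke's law: F_i = -k * sum_j A_ij * (x_i - x_j)
        diff = pos[:, None, :] - pos[None, :, :]  # diff[i, j] = x_i - x_j
        force = -spring_k * (edges[:, :, None] * diff).sum(axis=1)
        # Semi-implicit Euler step with unit masses.
        vel = vel + dt * force
        pos = pos + dt * vel
    return np.stack(positions), np.stack(velocities), edges


if __name__ == "__main__":
    loc, vel, edges = simulate_springs()
    # Hypothetical file name; written to a temp dir to avoid touching data/.
    path = os.path.join(tempfile.gettempdir(), "loc_springs_example.npy")
    np.save(path, loc)
    print(np.load(path).shape)  # (10, 5, 2)
```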
Run the following commands to train the encoder and the decoder, respectively. The best model across all epochs is selected on the validation set and then evaluated on the test set.
/scripts$ python train_enc.py
/scripts$ python train_dec.py
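Validation-based model selection like the one described above usually follows a common pattern: after each epoch, evaluate on the validation set and snapshot the weights whenever the validation loss improves. Below is a framework-agnostic sketch; `select_best_model`, `train_one_epoch`, and `validate` are hypothetical names standing in for the repository's actual training code, and using a loss (lower is better) as the criterion is an assumption.

```python
import copy
from typing import Callable, Dict, Optional, Tuple


def select_best_model(
    model_state: Dict,
    train_one_epoch: Callable[[Dict], Dict],
    validate: Callable[[Dict], float],
    n_epochs: int,
) -> Tuple[Optional[Dict], int, float]:
    """Return (best_state, best_epoch, best_val_loss) across all epochs."""
    best_state, best_epoch, best_loss = None, -1, float("inf")
    state = model_state
    for epoch in range(n_epochs):
        state = train_one_epoch(state)
        val_loss = validate(state)
        if val_loss < best_loss:
            best_loss, best_epoch = val_loss, epoch
            # Deep copy: later epochs may mutate `state` in place.
            best_state = copy.deepcopy(state)
    return best_state, best_epoch, best_loss


if __name__ == "__main__":
    def toy_train(state):
        state["w"] += 1.0  # mutates in place, like an optimizer step
        return state

    def toy_validate(state):
        return (state["w"] - 3.0) ** 2

    best, epoch, loss = select_best_model({"w": 0.0}, toy_train, toy_validate, n_epochs=5)
    print(best, epoch, loss)  # {'w': 3.0} 2 0.0
```

In PyTorch, the snapshot would typically be `copy.deepcopy(model.state_dict())`, since `state_dict()` returns tensors that share storage with the live parameters.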
You can further adjust the training arguments; for details, run `python train_enc.py -h`.
A GPU is not required for training; the model can be trained in a short time on a CPU.
We provide `run_decoder.py` and `run_encoder.py` for generating trajectories from trained models. The steps are as follows:
- Train a new model or use an existing one. All trained models are saved in the `saved_model` folder. To use a model, specify its path in the arguments of `run_decoder.py` and `run_encoder.py`, respectively.
- Specify the datasets and network structural arguments in `run_decoder.py` and `run_encoder.py`.
- Run the script. The model output is saved in the `saved_results` folder.
- For decoder output, you can use `traj_plot.ipynb` to generate a GIF visualization.
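Generating a trajectory from a one-step predictor amounts to an autoregressive rollout: predict the next state, feed the prediction back in, repeat, and save the stacked result. A minimal sketch, assuming the trained decoder can be wrapped as a function `step_fn(state) -> next_state`; `rollout`, `save_trajectory`, and the file name `pred_traj.npy` are hypothetical, and only the `saved_results` folder name comes from this README.

```python
import os
import tempfile
from typing import Callable

import numpy as np


def rollout(step_fn: Callable[[np.ndarray], np.ndarray],
            x0: np.ndarray, n_steps: int) -> np.ndarray:
    """Autoregressive rollout: each prediction becomes the next input."""
    traj = [np.asarray(x0, dtype=float)]
    for _ in range(n_steps):
        traj.append(np.asarray(step_fn(traj[-1]), dtype=float))
    return np.stack(traj)  # shape: (n_steps + 1, *x0.shape)


def save_trajectory(traj: np.ndarray, out_dir: str = "saved_results",
                    name: str = "pred_traj.npy") -> str:
    """Save a predicted trajectory as .npy and return the file path."""
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, name)
    np.save(path, traj)
    return path


if __name__ == "__main__":
    # Stand-in "decoder": constant-velocity update on rows of [x, y, vx, vy].
    def constant_velocity(state, dt=0.1):
        nxt = state.copy()
        nxt[:, :2] += dt * state[:, 2:]
        return nxt

    x0 = np.array([[0.0, 0.0, 1.0, 0.0],
                   [1.0, 1.0, 0.0, -1.0]])
    traj = rollout(constant_velocity, x0, n_steps=10)
    save_trajectory(traj, out_dir=tempfile.mkdtemp())
    print(traj.shape)  # (11, 2, 4)
```

Feeding predictions back in means errors compound over long horizons, which is one reason predicted and ground-truth trajectories can drift apart over time.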
Visualization of `run_encoder.py` output is not yet available.
Figure: ground truth trajectory (left) and prediction trajectory (right).