longyuanli / GRIN_NeurIPS21

GRIN: Generative Relation and Intention Network for Multi-agent Trajectory Prediction

Official implementation of the Generative Relation and Intention Network (GRIN) in PyTorch and DGL.

Dependencies

  • torch==1.8.1
  • numpy==1.19.2
  • scipy==1.6.1
  • dgl_cu110==0.6.1
  • dgl==0.6.1
  • tensorboardX==2.2
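
To check that an environment matches the pinned versions, the comparison can be scripted. This is a minimal sketch, not part of the repository. It assumes Python 3.8+ (for importlib.metadata) and that the distribution names match the list above. PINNED and check_versions are hypothetical names introduced here.

```python
"""Compare installed package versions against the pinned dependency list.

Sketch only: PINNED and check_versions are hypothetical names, and the
distribution names are assumed to match the README's Dependencies list.
"""
from importlib.metadata import PackageNotFoundError, version
from typing import Callable, Dict

# Pinned versions copied from the Dependencies list.
PINNED: Dict[str, str] = {
    "torch": "1.8.1",
    "numpy": "1.19.2",
    "scipy": "1.6.1",
    "dgl_cu110": "0.6.1",
    "dgl": "0.6.1",
    "tensorboardX": "2.2",
}


def check_versions(
    pinned: Dict[str, str],
    lookup: Callable[[str], str] = version,
) -> Dict[str, str]:
    """Return {package: problem}; an empty dict means every pin matches."""
    problems: Dict[str, str] = {}
    for name, wanted in pinned.items():
        try:
            found = lookup(name)
        except PackageNotFoundError:
            problems[name] = "not installed"
            continue
        # Drop a local build label such as "+cu111" before comparing.
        if found.split("+")[0] != wanted:
            problems[name] = f"found {found}, expected {wanted}"
    return problems


if __name__ == "__main__":
    for pkg, issue in check_versions(PINNED).items():
        print(f"{pkg}: {issue}")
```

Running it prints one line per missing or mismatched package. Because both dgl and dgl_cu110 are listed, one of the two may be reported as not installed if only one DGL build is present.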

Running the code

  1. Install all the dependencies listed above

  2. Generate the charged dataset splits for training, testing and validation (the NBA dataset is available from [44])

python simulator.py --seed 0 --num_sample 5000 --filename train.npz
python simulator.py --seed 1 --num_sample 1000 --filename test.npz
python simulator.py --seed 2 --num_sample 1000 --filename valid.npz
  3. Train the model
bash train.sh
  4. Evaluate the model
bash eval.sh
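
The steps above can also be chained from one script. This is a hedged sketch rather than a file shipped with the repository. It assumes it is run from the repository root, next to simulator.py, train.sh and eval.sh. build_commands, run_all and the --run switch are hypothetical names introduced here. The seeds, sample counts and filenames are copied from step 2.

```python
"""Run the README steps in order: generate the three splits, train, evaluate.

Sketch only: assumes the working directory is the repository root.
build_commands, run_all and the --run switch are hypothetical names.
"""
import subprocess
import sys
from typing import List

# (seed, num_sample, filename) for each split, as listed in step 2.
SPLITS = [
    (0, 5000, "train.npz"),
    (1, 1000, "test.npz"),
    (2, 1000, "valid.npz"),
]


def build_commands(python: str = "python") -> List[List[str]]:
    """Return every README command as an argv list, in execution order."""
    commands = [
        [python, "simulator.py", "--seed", str(seed),
         "--num_sample", str(n), "--filename", name]
        for seed, n, name in SPLITS
    ]
    commands.append(["bash", "train.sh"])  # step 3: train
    commands.append(["bash", "eval.sh"])   # step 4: evaluate
    return commands


def run_all(dry_run: bool = True) -> None:
    """Print each command; execute it only when dry_run is False."""
    for argv in build_commands(sys.executable):
        print(" ".join(argv))
        if not dry_run:
            # check=True raises CalledProcessError at the first failing step.
            subprocess.run(argv, check=True)


if __name__ == "__main__":
    run_all(dry_run="--run" not in sys.argv)
```

By default the script only prints the commands (a dry run). Passing --run executes them in order, and check=True stops the pipeline at the first step that exits with an error, so training never starts on a partially generated dataset.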

Languages

Python 99.3%, Shell 0.7%