YuxuanIAIR / MSRL-master

Official code for AAAI 2023 paper "Multi-stream Representation Learning for Pedestrian Trajectory Prediction"

Multi-Stream Representation Learning for Pedestrian Trajectory Prediction

Framework

Figure: Multi-stream Representation Learning

Figure: CVAE for Multi-modal Prediction

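For readers new to CVAE-based multi-modal prediction, here is a minimal sketch of the generic reparameterization step. This step lets a CVAE draw several different latent codes from one predicted Gaussian, and a decoder can turn each code into a different candidate future. The sketch also includes the closed-form KL term commonly used when training a CVAE. It is a textbook CVAE ingredient written with the standard library only. It is not the authors' model, and the function names (`reparameterize`, `kl_to_standard_normal`, `sample_latents`) are hypothetical.

```python
import math
import random
from typing import List, Sequence


def reparameterize(mu: Sequence[float], log_var: Sequence[float], rng: random.Random) -> List[float]:
    """Draw z = mu + sigma * eps with eps ~ N(0, I), where sigma = exp(0.5 * log_var)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]


def kl_to_standard_normal(mu: Sequence[float], log_var: Sequence[float]) -> float:
    """Closed-form KL(N(mu, diag(exp(log_var))) || N(0, I)), summed over dimensions."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, log_var))


def sample_latents(mu: Sequence[float], log_var: Sequence[float], k: int, seed: int = 0) -> List[List[float]]:
    """Draw k latent codes; a (hypothetical) decoder would map each to one candidate trajectory."""
    rng = random.Random(seed)
    return [reparameterize(mu, log_var, rng) for _ in range(k)]
```

Because the noise `eps` is drawn separately from `mu` and `log_var`, gradients can flow through the mean and variance during training. Drawing `k` samples at test time is what makes the prediction multi-modal.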
Installation

Environment

  • Tested on: Ubuntu 18.04 LTS with an NVIDIA RTX 3090 GPU

  • Python >= 3.7

  • PyTorch == 1.8.0
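The following is a minimal sketch for checking a local environment against the versions listed above. It assumes that versions are compared on their numeric components and that local build tags such as `+cu111` can be ignored. The helpers `parse_version` and `check_environment` are hypothetical and are not part of this repo.

```python
import sys
from typing import List, Optional, Tuple


def parse_version(text: str) -> Tuple[int, ...]:
    """Turn '1.8.0+cu111' into (1, 8, 0); anything after '+' and non-digit suffixes are dropped."""
    parts = []
    for piece in text.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def check_environment(python_version: Tuple[int, ...], torch_version: Optional[str]) -> List[str]:
    """Return a list of problems; an empty list means the tested versions are matched."""
    problems = []
    if tuple(python_version[:2]) < (3, 7):
        problems.append("Python >= 3.7 is required")
    if torch_version is None:
        problems.append("PyTorch is not installed")
    elif parse_version(torch_version)[:3] != (1, 8, 0):
        problems.append(f"PyTorch 1.8.0 expected, found {torch_version}")
    return problems


if __name__ == "__main__":
    try:
        import torch  # only needed to read the installed version
        installed = str(torch.__version__)
    except ImportError:
        installed = None
    issues = check_environment(tuple(sys.version_info[:3]), installed)
    print("\n".join(issues) if issues else "Environment matches the tested versions.")
```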

Dependencies

  1. Install PyTorch 1.8.0 built for the CUDA version installed on your machine.

  2. Install the dependencies:

    pip install -r requirements.txt
    

Evaluation

Download the pre-trained models from Google Drive, then unzip the archive and place its contents under the project folder.

Run the following command to reproduce the main results reported in our paper.

<dataset_name> can be one of eth, hotel, univ, zara1, zara2, or sdd, and <gpu_id> is the ID of the GPU to run on.

    python test.py --dataset <dataset_name> --gpu <gpu_id>
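To evaluate every dataset in one go, you can wrap the command above in a small driver script. The following is a minimal sketch that assumes it runs from the project root, where `test.py` lives, and that `test.py` needs only the `--dataset` and `--gpu` arguments shown. The names `build_test_command` and `evaluate_all` are hypothetical.

```python
import subprocess
import sys
from typing import List

# Dataset names accepted by --dataset, as listed above.
DATASETS = ["eth", "hotel", "univ", "zara1", "zara2", "sdd"]


def build_test_command(dataset: str, gpu_id: int = 0) -> List[str]:
    """Build the test.py invocation for one dataset, rejecting unknown names early."""
    if dataset not in DATASETS:
        raise ValueError(f"unknown dataset {dataset!r}; expected one of {DATASETS}")
    return [sys.executable, "test.py", "--dataset", dataset, "--gpu", str(gpu_id)]


def evaluate_all(gpu_id: int = 0) -> None:
    """Evaluate each dataset in turn; check=True raises on the first failing run."""
    for dataset in DATASETS:
        subprocess.run(build_test_command(dataset, gpu_id), check=True)
```

Call `evaluate_all(gpu_id=0)` from the project root after the pre-trained models are in place.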

Training

The model is trained in two stages:

  1. Train the CVAE model based on Multi-stream Representation Learning:

    python trainvae.py --dataset <dataset_name> --gpu <gpu_id>
    
  2. Train the sampler model:

    python trainsampler.py --dataset <dataset_name> --gpu <gpu_id>
    

You can modify the configuration by passing different command-line arguments.
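The two training stages above can be chained as shown below. This is a minimal sketch that assumes you run it from the project root and that stage 2 should not start if stage 1 fails. It passes only the `--dataset` and `--gpu` arguments shown above. The names `build_stage_commands` and `train` are hypothetical.

```python
import subprocess
import sys
from typing import List

# The two training stages, in the order listed above.
STAGES = ["trainvae.py", "trainsampler.py"]


def build_stage_commands(dataset: str, gpu_id: int = 0) -> List[List[str]]:
    """Return one command per stage, each with the same --dataset/--gpu arguments."""
    return [
        [sys.executable, script, "--dataset", dataset, "--gpu", str(gpu_id)]
        for script in STAGES
    ]


def train(dataset: str, gpu_id: int = 0) -> None:
    """Run the stages sequentially; check=True aborts before stage 2 if stage 1 fails."""
    for command in build_stage_commands(dataset, gpu_id):
        subprocess.run(command, check=True)
```

For example, `train("hotel", gpu_id=0)` runs both stages for the hotel split.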

Acknowledgement

We thank the authors of SGCN for the ETH-UCY data processing and the authors of PECNet for providing the SDD data.

Citation

If you find this repo helpful, please consider citing our paper:

@inproceedings{wu2023multi,
  title={Multi-stream representation learning for pedestrian trajectory prediction},
  author={Wu, Yuxuan and Wang, Le and Zhou, Sanping and Duan, Jinghai and Hua, Gang and Tang, Wei},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  volume={37},
  number={3},
  pages={2875--2882},
  year={2023}
}
