vita-epfl / social-nce

[ICCV21] Official implementation of the "Social NCE: Contrastive Learning of Socially-aware Motion Representations" in PyTorch.

Social-NCE + CrowdNav

Website | Paper | Video | Social NCE + Trajectron | Social NCE + STGCNN

This is the official implementation of
Social NCE: Contrastive Learning of Socially-aware Motion Representations
Yuejiang Liu, Qi Yan, Alexandre Alahi, ICCV 2021

TL;DR: Contrastive Representation Learning + Negative Data Augmentations 🡲 Robust Neural Motion Models

[New] Our more recent work on this topic:
Towards Robust and Adaptive Motion Forecasting: A Causal Representation Perspective, CVPR 2022.
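
The contrastive objective named in the TL;DR can be illustrated with a minimal InfoNCE-style loss. This is a sketch in plain Python, not the repository's code: the function names, the use of cosine similarity, and the way the two losses are combined are assumptions, and the vectors stand in for whatever embeddings the model produces. The default temperature of 0.2 and weight of 2.0 mirror the values that appear in the commands and output paths of this README.

```python
import math


def _dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def _normalize(v):
    # Guard against a zero vector so the division is always defined.
    norm = math.sqrt(_dot(v, v)) or 1.0
    return [a / norm for a in v]


def info_nce(query, positive, negatives, temperature=0.2):
    """InfoNCE loss for one query, one positive key and a list of negative keys.

    Hypothetical helper: similarity is cosine similarity divided by the temperature.
    """
    q = _normalize(query)
    keys = [positive] + list(negatives)
    logits = [_dot(q, _normalize(k)) / temperature for k in keys]
    # Numerically stable log-sum-exp over all keys.
    top = max(logits)
    log_denominator = top + math.log(sum(math.exp(x - top) for x in logits))
    # Negative log-probability of picking the positive key.
    return log_denominator - logits[0]


def total_loss(task_loss, nce_loss, contrast_weight=2.0):
    # Assumed combination: a weight of 0.0 reduces to the vanilla objective.
    return task_loss + contrast_weight * nce_loss
```

A query that is close to its positive and far from its negatives yields a loss near zero, while a query that is closer to a negative is penalized. How negatives are chosen (for example the 'local' and 'event' sampling modes below) is outside the scope of this sketch.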

Preparation

Set up the environment following SETUP.md.

Training & Evaluation

  • Behavioral Cloning (Vanilla)
    python imitate.py --contrast_weight=0.0 --gpu
    python test.py --policy='sail' --circle --model_file=data/output/imitate-baseline-data-0.50/policy_net.pth
    
  • Social-NCE + Conventional Negative Sampling (Local)
    python imitate.py --contrast_weight=2.0 --contrast_sampling='local' --gpu
    python test.py --policy='sail' --circle --model_file=data/output/imitate-local-data-0.50-weight-2.0-horizon-4-temperature-0.20-nboundary-0-range-2.00/policy_net.pth
    
  • Social-NCE + Safety-driven Negative Sampling (Ours)
    python imitate.py --contrast_weight=2.0 --contrast_sampling='event' --gpu
    python test.py --policy='sail' --circle --model_file=data/output/imitate-event-data-0.50-weight-2.0-horizon-4-temperature-0.20-nboundary-0/policy_net.pth
    
  • Method Comparison
    bash script/run_vanilla.sh && bash script/run_local.sh && bash script/run_snce.sh
    python utils/compare.py
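
The `--model_file` paths in the test commands above appear to encode the training hyperparameters. The helper below is a hypothetical convenience, not part of the repository: its name, its parameter names, and the naming pattern are inferred only from the three paths shown in this README, so the actual output directory may differ if other training options are changed.

```python
def guess_model_file(sampling, data=0.5, weight=2.0, horizon=4,
                     temperature=0.2, nboundary=0, sample_range=2.0,
                     root="data/output"):
    """Rebuild a policy_net.pth path in the pattern seen in this README.

    `sampling` is 'baseline', 'local' or 'event'. The meaning of the `data`
    field is not stated in this README; it is copied from the example paths.
    """
    if sampling == "baseline":
        name = f"imitate-baseline-data-{data:.2f}"
    elif sampling in ("local", "event"):
        name = (f"imitate-{sampling}-data-{data:.2f}-weight-{weight:.1f}"
                f"-horizon-{horizon}-temperature-{temperature:.2f}"
                f"-nboundary-{nboundary}")
        if sampling == "local":
            # Only the 'local' example path carries a range suffix.
            name += f"-range-{sample_range:.2f}"
    else:
        raise ValueError(f"unknown sampling mode: {sampling!r}")
    return f"{root}/{name}/policy_net.pth"


if __name__ == "__main__":
    for mode in ("baseline", "local", "event"):
        print(guess_model_file(mode))
```

If the printed path does not exist after training, check the directory that `imitate.py` actually created under `data/output/` and pass that path to `test.py` instead.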
    

Basic Results

Results of behavioral cloning with different methods.

Averaged results from the 150th to 200th epochs.

Method    Collision rate    Reward
Vanilla   12.7% ± 3.8%      0.274 ± 0.019
Local     19.3% ± 4.2%      0.240 ± 0.021
Ours      2.0% ± 0.6%       0.331 ± 0.003
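
As a rough illustration of how a "mean ± spread over epochs 150 to 200" figure can be computed, the sketch below averages a per-epoch metric over an epoch window. It is not the repository's `utils/compare.py`: the function name, the inclusive window bounds, and the use of the population standard deviation are assumptions, and the README does not state how the ± values in the table were obtained.

```python
from statistics import mean, pstdev


def summarize_window(metric_by_epoch, first_epoch=150, last_epoch=200):
    """Return (mean, standard deviation) of a metric over an inclusive epoch window.

    `metric_by_epoch` maps an epoch number to a metric value, for example a
    collision rate or a reward logged at evaluation time.
    """
    window = [value for epoch, value in metric_by_epoch.items()
              if first_epoch <= epoch <= last_epoch]
    if not window:
        raise ValueError("no epochs fall inside the requested window")
    return mean(window), pstdev(window)


if __name__ == "__main__":
    # Placeholder values for illustration only, not results from the paper.
    placeholder = {149: 1.0, 150: 0.2, 151: 0.4, 200: 0.6, 201: 9.0}
    avg, spread = summarize_window(placeholder)
    print(f"{avg:.3f} ± {spread:.3f}")
```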

Citation

If you find this code useful for your research, please cite our papers:

@inproceedings{liu2021social,
  title={Social nce: Contrastive learning of socially-aware motion representations},
  author={Liu, Yuejiang and Yan, Qi and Alahi, Alexandre},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
  pages={15118--15129},
  year={2021}
}
@inproceedings{chen2019crowd,
  title={Crowd-robot interaction: Crowd-aware robot navigation with attention-based deep reinforcement learning},
  author={Chen, Changan and Liu, Yuejiang and Kreiss, Sven and Alahi, Alexandre},
  booktitle={International Conference on Robotics and Automation (ICRA)},
  pages={6015--6022},
  year={2019}
}

License:BSD 2-Clause "Simplified" License

