sunfanyunn / IMMA

Official code for "Interaction Modeling with Multiplex Attention" (NeurIPS 2022)

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

Interaction Modeling with Multiplex Attention

Authors: Fan-Yun Sun, Isaac Kauvar, Ruohan Zhang, Jiachen Li, Mykel Kochenderfer, Jiajun Wu, Nick Haber

Abstract: Modeling multi-agent systems requires understanding how agents interact. Such systems are often difficult to model because they can involve a variety of types of interactions that layer together to drive rich social behavioral dynamics. Here we introduce a method for accurately modeling multi-agent systems. We present Interaction Modeling with Multiplex Attention (IMMA), a forward prediction model that uses a multiplex latent graph to represent multiple independent types of interactions and attention to account for relations of different strengths. We also introduce Progressive Layer Training, a training strategy for this architecture. We show that our approach outperforms state-of-the-art models in trajectory forecasting and relation inference, spanning three multi-agent scenarios: social navigation, cooperative task achievement, and team sports. We further demonstrate that our approach can improve zero-shot generalization and allows us to probe how different interactions impact agent behavior.

This repository contains the codes for our paper, which is accepted at NeurIPs 2022. For more details, please refer to the paper (arxiv, openreview).

Environment Setup

  1. Install Python-RVO2 library
  2. Install socialforce library
  3. Install necessary packages with pip
pip install -r requirements.txt

Data Setup

Social Navigation Environment

This is a simulated environment inspired by https://github.com/vita-epfl/CrowdNav. After installing necessary dependencies, refer to the following sample commands to start the simulation.

randomseed=17
dataset_size=100000
obs_frames=24
rollouts=10

cd data_utils/socialnav
python generate_dataset.py --dataset_size $dataset_size \
                           --randomseed $randomseed \
                           --obs_frames ${obs_frames} \
                           --rollouts ${rollouts}

The resulting dataset will be stored at datasets/*.tensor. You can make modifications to the config file dat_utils/socialnav/configs/default.py to change the simulation setting.

To inspect and interact with the environment (control the embodied agent with your arrow keys):

cd data_utils/socialnav
python human_play.py

PHASE

The preprocessed dataset is under datasets/phase/collab. To load the dataset, refer to the function prepare_dataset in data_utils/load_dataset.py.

NBA dataset

Download the preprocessed dataset here (or run gdown 1eJbDHy3fOHfzOStf-jSuYCz_YQloQU3s) and place it under datasets. Alternatively, you can create your own dataset from raw sportVU logs (refer to this repository or the code under data_utils/bball) To load the dataset, refer to the function prepare_dataset in data_utils/load_dataset.py.

Training and Evaluation

Find sample commands at run_socialnav.sh, run_phase.sh and run_bball.sh.

Progressive Layered Training

Loss curve visualized over the course of training IMMA with PLT on the NBA dataset. New layers are added after the model "converges". Teaser image

Citation

If you find the code or paper useful for your research, please cite our paper:

@article{sun2022interaction,
  title={Interaction Modeling with Multiplex Attention},
  author={Sun, Fan-Yun and Kauvar, Isaac and Zhang, Ruohan and Li, Jiachen and Kochenderfer, Mykel and Wu, Jiajun and Haber, Nick},
  journal={Advances in Neural Information Processing Systems},
  year={2022}
}

Acknwledgement

In htis project we use (parts of) the implementations from the following works:

We thank the authors for open sourcing their methods.

About

Official code for "Interaction Modeling with Multiplex Attention" (NeurIPS 2022)


Languages

Language:Python 97.1%Language:HTML 2.5%Language:Shell 0.3%