Marchetz / MANTRA-CVPR20

Official PyTorch code for MANTRA - Memory Augmented Neural Trajectory Predictor (CVPR 2020)

MANTRA: Memory Augmented Networks for Multiple Trajectory Prediction

Official PyTorch code for "MANTRA: Memory Augmented Networks for Multiple Trajectory Prediction" (CVPR 2020) by Francesco Marchetti, Federico Becattini, Lorenzo Seidenari, and Alberto Del Bimbo.

Figure: multiple trajectory prediction. Blue: past trajectory; red: predicted futures.

Installation

To install the required packages in a Python 3.6 environment, run:

pip install -r requirements.txt

Dataset

We provide a dataloader for the KITTI dataset in dataset_invariance.py. The dataloader yields samples of (past, future) trajectories paired with a semantic map of the surrounding scene.
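
The past/future pairing can be illustrated with a stdlib-only sketch. It assumes each track is a list of (x, y) positions sampled at a fixed rate, cuts it into windows of past_len + future_len points (the --past_len and --future_len defaults listed under Command line arguments), and expresses every window relative to the last past position. The helper name split_past_future and the translation step are assumptions for illustration; dataset_invariance.py may normalize differently (for example, by also rotating trajectories), and it additionally builds the semantic map, which is omitted here.

```python
from typing import List, Tuple

Point = Tuple[float, float]
Sample = Tuple[List[Point], List[Point]]


def split_past_future(track: List[Point], past_len: int = 20,
                      future_len: int = 40) -> List[Sample]:
    """Hypothetical helper: slide a window over one track and return
    (past, future) pairs translated so the present position is (0, 0)."""
    window = past_len + future_len
    samples: List[Sample] = []
    for start in range(len(track) - window + 1):
        chunk = track[start:start + window]
        # The last past point is treated as the "present" and used as origin
        # (an assumption about the normalization, not taken from the repo).
        ox, oy = chunk[past_len - 1]
        rel = [(x - ox, y - oy) for x, y in chunk]
        samples.append((rel[:past_len], rel[past_len:]))
    return samples


# Usage: a synthetic straight-line track of 65 points yields 6 windows.
track = [(float(i), 2.0 * i) for i in range(65)]
pairs = split_past_future(track)
print(len(pairs), len(pairs[0][0]), len(pairs[0][1]))  # 6 20 40
```

Tracks shorter than past_len + future_len produce no samples.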

Training

To train MANTRA, first train the autoencoder, then the writing controller, and finally the Iterative Refinement Module (IRM). Training can be monitored with TensorBoard; logs are stored in the runs/ folder, with one subfolder per stage (runs-pretrain, runs-createMem, runs-IRM). The pretrained_models folder contains pretrained models of the different components (autoencoder, writing controller, MANTRA).
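
Because each stage consumes the model produced by the previous one, the order matters. The sketch below only assembles the three commands shown in the following subsections; the two model paths are placeholders you would fill in with the folders written by the earlier stages (for example, test/[current_date]). The function names are hypothetical, and by default the sketch only prints the commands instead of running them.

```python
import subprocess
from typing import List


def training_commands(ae_model: str, controller_model: str) -> List[List[str]]:
    """Hypothetical helper: the three MANTRA training stages, in order."""
    return [
        ["python", "train_ae.py"],                                  # 1. autoencoder
        ["python", "train_controllerMem.py", "--model", ae_model],  # 2. writing controller
        ["python", "train_IRM.py", "--model", controller_model],    # 3. IRM
    ]


def run_training(ae_model: str, controller_model: str, dry_run: bool = True) -> None:
    """Print (dry_run=True) or execute the stages one after another."""
    for cmd in training_commands(ae_model, controller_model):
        if dry_run:
            print(" ".join(cmd))
        else:
            # check=True stops the pipeline at the first stage that fails.
            subprocess.run(cmd, check=True)


# Placeholder paths: replace them with the models saved by stages 1 and 2.
run_training("path/to/autoencoder_model", "path/to/autoencoder+controller_model")
```

Running with dry_run=False only makes sense once the paths point to existing models, since stage 2 and stage 3 cannot start before the earlier stages have saved their output.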

Training encoder-decoder model (autoencoder)

python train_ae.py

The autoencoder is trained with the train_ae.py script, which calls trainer_ae.py. The model is saved in the folder test/[current_date]. A pretrained model can be found in pretrained_models/model_AE/.

Training writing controller

python train_controllerMem.py --model pretrained_autoencoder_model_path

The memory writing controller is trained on top of the autoencoder with train_controllerMem.py, which calls trainer_controllerMem.py. The path of a pretrained autoencoder model must be passed to the script with --model (it defaults to the pretrained model we provide). A pretrained model (autoencoder + writing controller) can be found in pretrained_models/model_controller/.
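
The writing controller decides which past-future pairs are stored in memory, keeping the memory compact. The sketch below replaces the learned controller with a fixed error threshold, which is a simplification and not the trained component: a sample is written only if none of the futures already in memory ends within the threshold of its true endpoint. The class name ThresholdMemory and the threshold value are hypothetical, and in MANTRA the stored entries would be encodings produced by the autoencoder rather than raw points.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, float]


def endpoint_error(pred: List[Point], gt: List[Point]) -> float:
    """Distance between the final points of two trajectories."""
    (px, py), (gx, gy) = pred[-1], gt[-1]
    return math.hypot(px - gx, py - gy)


class ThresholdMemory:
    """Simplified memory with a fixed-threshold writing rule (an assumption,
    standing in for MANTRA's learned writing controller)."""

    def __init__(self, threshold: float) -> None:
        self.threshold = threshold
        self.keys: List[Sequence[float]] = []   # past encodings (here: any vector)
        self.futures: List[List[Point]] = []    # stored future trajectories

    def best_error(self, future: List[Point]) -> float:
        """Smallest endpoint error any stored future achieves on `future`."""
        if not self.futures:
            return math.inf
        return min(endpoint_error(f, future) for f in self.futures)

    def maybe_write(self, past_key: Sequence[float], future: List[Point]) -> bool:
        """Write the pair only if memory cannot already reproduce this future."""
        if self.best_error(future) > self.threshold:
            self.keys.append(past_key)
            self.futures.append(future)
            return True
        return False


mem = ThresholdMemory(threshold=1.0)
straight = [(float(t), 0.0) for t in range(1, 41)]
turn = [(float(t), float(t)) for t in range(1, 41)]
print(mem.maybe_write((1.0, 0.0), straight))  # True: memory is empty
print(mem.maybe_write((1.0, 0.1), straight))  # False: already covered
print(mem.maybe_write((0.0, 1.0), turn))      # True: new behaviour
```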

Training Iterative Refinement Module (IRM)

python train_IRM.py --model pretrained_autoencoder+controller_model_path

train_IRM.py calls trainer_IRM.py. The script trains the IRM, which generates the final prediction from the decoded trajectory and the context map. The paths of a pretrained autoencoder + writing controller model and of the populated memories must be passed to the script (they default to the pretrained models we provide). A pretrained MANTRA model can be found in pretrained_models/model_complete/.
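
The IRM refines trajectories decoded from entries read out of memory. To our understanding of the MANTRA paper, reading compares the encoded past of the query with the stored past encodings by cosine similarity and keeps the K most similar entries, with K corresponding to --preds (default 5). The sketch below shows only that top-K retrieval on plain Python vectors; decoding and the IRM itself are not shown, and the function names are hypothetical.

```python
import math
from typing import List, Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0  # treat zero vectors as unrelated
    return dot / (norm_a * norm_b)


def read_top_k(query: Sequence[float], keys: List[Sequence[float]], k: int = 5) -> List[int]:
    """Indices of the k memory keys most similar to `query`, best first."""
    ranked = sorted(range(len(keys)),
                    key=lambda i: cosine_similarity(query, keys[i]),
                    reverse=True)
    return ranked[:k]


keys = [(1.0, 0.0), (0.0, 1.0), (1.0, 1.0), (-1.0, 0.0)]
print(read_top_k((1.0, 0.1), keys, k=2))  # [0, 2]
```

If memory holds fewer than k entries, all of them are returned.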

Test

python test.py --model pretrained_complete_model_path --withIRM True/False --saved_memory True/False

test.py calls evaluate_MemNet.py, which computes metrics on the KITTI dataset using a trained model. We compute Average Displacement Error (ADE) and Final Displacement Error (FDE, also referred to as Error@K or Horizon Error).
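
For reference, ADE averages the Euclidean distance between predicted and ground-truth positions over all future time steps, while FDE is the distance at a single horizon, by default the last one. With multiple predictions, a common convention is to report the best of the K predictions; the sketch below takes the minimum of each metric independently, which is an assumption, since evaluate_MemNet.py may instead select one best prediction and report both of its errors. The function names are hypothetical.

```python
import math
from typing import List, Optional, Tuple

Point = Tuple[float, float]


def displacement(p: Point, q: Point) -> float:
    return math.hypot(p[0] - q[0], p[1] - q[1])


def ade(pred: List[Point], gt: List[Point]) -> float:
    """Average Displacement Error over all future steps."""
    return sum(displacement(p, g) for p, g in zip(pred, gt)) / len(gt)


def error_at(pred: List[Point], gt: List[Point], step: Optional[int] = None) -> float:
    """Displacement at one horizon; step=None means the last step (FDE)."""
    idx = len(gt) - 1 if step is None else step
    return displacement(pred[idx], gt[idx])


def best_of_k(preds: List[List[Point]], gt: List[Point]) -> Tuple[float, float]:
    """(min ADE, min FDE) over K predictions, each minimised independently."""
    return min(ade(p, gt) for p in preds), min(error_at(p, gt) for p in preds)


gt = [(0.0, 0.0), (1.0, 0.0), (2.0, 0.0)]
pred_a = [(0.0, 0.0), (1.0, 1.0), (2.0, 2.0)]  # errors 0, 1, 2
pred_b = [(0.0, 1.0), (1.0, 1.0), (2.0, 1.0)]  # errors 1, 1, 1
print(ade(pred_a, gt), error_at(pred_a, gt))   # 1.0 2.0
print(best_of_k([pred_a, pred_b], gt))         # (1.0, 1.0)
```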

Command line arguments

    --cuda                         Enable/Disable GPU device (default=True).
    --batch_size                   Number of samples fed to MANTRA in one iteration (default=32).
    --past_len                     Length of the past trajectory, in time steps (default=20).
    --future_len                   Length of the future trajectory, in time steps (default=40).
    --preds                        Number of predictions generated by the MANTRA model (default=5).
    --model                        Path of the pretrained model used for evaluation (default='pretrained_models/MANTRA/model_MANTRA').
    --visualize_dataset            Save all dataset examples in folder_test/dataset_train and
                                   folder_test/dataset_test.
    --saved_memory                 Select which memories are used in evaluation.
                                   If True, memories are loaded from the folder given by --memories_path.
                                   If False, new memories are generated; the past-future pairs to store
                                   are chosen by the model's writing controller.
    --memories_path                Folder of saved memories; used only if --saved_memory is True.
    --withIRM                      Generate predictions with (True) or without (False) the Iterative Refinement Module.
    --saveImages                   Save dataset examples with the predictions generated by MANTRA in the test folder.
                                   If None, no qualitative examples are saved, only quantitative results.
                                   If 'All', all examples are saved.
                                   If 'Subset', only the examples listed in index_qualitative.py
                                   (hand-picked, most significant samples) are saved.
                                   (default=None)
    --dataset_file                 Name of the JSON file containing the dataset (default='kitti_dataset.json').
    --info                         Name of the evaluation, used as the name of the test folder (default='').
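
The boolean options above are passed as strings (for example, --withIRM True/False). With argparse, type=bool would turn any non-empty string, including "False", into True, so a string-to-bool converter is needed. The sketch below is a hypothetical parser that mirrors the documented options and defaults; it is not the parser in test.py, and options whose default is not stated in the list (--visualize_dataset, --saved_memory, --memories_path, --withIRM) are left as None.

```python
import argparse


def str2bool(value: str) -> bool:
    """Convert 'True'/'False'-style strings to bool (type=bool would not)."""
    text = value.strip().lower()
    if text in ("true", "1", "yes"):
        return True
    if text in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError("expected True or False, got %r" % value)


def build_parser() -> argparse.ArgumentParser:
    p = argparse.ArgumentParser(description="Hypothetical sketch of the MANTRA test options")
    p.add_argument("--cuda", type=str2bool, default=True)
    p.add_argument("--batch_size", type=int, default=32)
    p.add_argument("--past_len", type=int, default=20)
    p.add_argument("--future_len", type=int, default=40)
    p.add_argument("--preds", type=int, default=5)
    p.add_argument("--model", default="pretrained_models/MANTRA/model_MANTRA")
    # No default is documented for the next four options, so None is used.
    p.add_argument("--visualize_dataset", type=str2bool)
    p.add_argument("--saved_memory", type=str2bool)
    p.add_argument("--memories_path")
    p.add_argument("--withIRM", type=str2bool)
    p.add_argument("--saveImages", choices=["All", "Subset"], default=None)
    p.add_argument("--dataset_file", default="kitti_dataset.json")
    p.add_argument("--info", default="")
    return p


args = build_parser().parse_args(["--withIRM", "False", "--saved_memory", "True"])
print(args.withIRM, args.saved_memory, args.preds)  # False True 5
```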

Citation

If you use our code or find it useful in your research, please cite the following paper:

@inproceedings{cvpr_2020,
  author = {Marchetti, Francesco and Becattini, Federico and Seidenari, Lorenzo and Del Bimbo, Alberto},
  booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  publisher = {IEEE},
  title = {MANTRA: Memory Augmented Networks for Multiple Trajectory Prediction},
  year = {2020}
}

@article{Geiger2013IJRR,
  author = {Andreas Geiger and Philip Lenz and Christoph Stiller and Raquel Urtasun},
  title = {Vision meets Robotics: The KITTI Dataset},
  journal = {International Journal of Robotics Research (IJRR)},
  year = {2013}
}

License

This source code is shared under the CC BY-NC-SA license; please refer to the LICENSE file for more information.

This source code is shared only for R&D purposes or for evaluating this model on the user's own database.

Any commercial use is strictly forbidden.

For any use with a commercial goal, please contact contact_cs or bendahan.
