henrymartin1 / dsml20_attention


How to reproduce results

  • Download the data: Go to the IARAI website to download the traffic4cast 2019 competition data and unpack it.

  • Train or download models:

    • If necessary, update the data paths in model_training/unet_config.py and model_training/graphnets_config.py

    • Run model_training/unet_training.py, model_training/kipfdepth1_training.py, model_training/kipfdepth2_training.py and model_training/graphresnet_training.py from within the main directory (e.g., python ./model_training/unet_training.py) to train the models (pretrained models are available here).

  • Run the generalization experiment:

    • If necessary, update the data paths in experiment/generalization_config.py
    • If you want to test your own models, update config['model_tuple_list']
    • Run experiment/generalization.py to test all models on all cities.
  • Plot:

    • To reproduce the plot output/performance_nb_params.pdf, run experiment/plot_performance_vs_nbparams.py.
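Unpacking the download into the repository's default raw-data folder (data) can be scripted with Python's standard library. This is a sketch under stated assumptions: unpack_archives is a hypothetical helper, not part of the repository, and it assumes the competition data arrives as one or more .zip archives in a single download folder. It does not check archive names or their internal layout.

```python
import zipfile
from pathlib import Path


def unpack_archives(download_dir, target_dir="data"):
    """Extract every .zip archive in download_dir into target_dir.

    Hypothetical helper: assumes the traffic4cast 2019 download consists of
    .zip files; archive names and internal layout are not validated.
    Returns the paths of the archives that were extracted, in sorted order.
    """
    target = Path(target_dir)
    target.mkdir(parents=True, exist_ok=True)
    extracted = []
    for archive in sorted(Path(download_dir).glob("*.zip")):
        with zipfile.ZipFile(archive) as zf:
            # Keeps whatever folder structure is stored inside the archive.
            zf.extractall(target)
        extracted.append(archive)
    return extracted
```

For example, unpack_archives("/path/to/downloads") run from the repository root extracts every archive in that (hypothetical) folder into data. If the raw data is stored elsewhere, pass that location as target_dir and update the data paths in the config files accordingly.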

All scripts should be executed from the root folder of the repository, e.g., python experiment/generalization.py
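The steps above can be chained by a small driver, sketched here under the assumption that it is given the path of the repository root. The script paths are the ones listed in this README; the helper names pipeline_commands and run_pipeline are hypothetical. Passing cwd as the repository root mirrors running python <script> from the root folder, and check=True stops at the first script that fails. Use train=False if you downloaded the pretrained models instead of training them.

```python
import subprocess
import sys
from pathlib import Path

# Script paths from the reproduction steps, relative to the repository root.
TRAINING_SCRIPTS = [
    "model_training/unet_training.py",
    "model_training/kipfdepth1_training.py",
    "model_training/kipfdepth2_training.py",
    "model_training/graphresnet_training.py",
]
EXPERIMENT_SCRIPT = "experiment/generalization.py"
PLOT_SCRIPT = "experiment/plot_performance_vs_nbparams.py"


def pipeline_commands(train=True):
    """Return the commands to run, in order, as argument lists."""
    scripts = (TRAINING_SCRIPTS if train else []) + [EXPERIMENT_SCRIPT, PLOT_SCRIPT]
    # Use the current interpreter so the same environment is reused.
    return [[sys.executable, script] for script in scripts]


def run_pipeline(repo_root, train=True, dry_run=False):
    """Run each script with the repository root as the working directory.

    With dry_run=True the scripts are only listed, not executed.
    Raises FileNotFoundError if a script is missing and
    subprocess.CalledProcessError if a script exits with a non-zero status.
    """
    root = Path(repo_root).resolve()
    for cmd in pipeline_commands(train):
        print("running:", cmd[1])
        if dry_run:
            continue
        if not (root / cmd[1]).is_file():
            raise FileNotFoundError(f"{cmd[1]} not found under {root}")
        subprocess.run(cmd, cwd=root, check=True)
```

For example, run_pipeline(".", train=False) from the repository root skips training and only runs the generalization experiment and the plot script.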

Folder structure

  • Code:
    • experiment

      generalization.py: This script runs the main experiment and calculates the loss of all Moscow-trained models on Istanbul and Berlin. The results are stored as a dictionary in output/data_generalization.p.

    • model_training: All scripts and configuration files necessary for the model training

      • *_training.py files are used to train the corresponding networks
      • graphnets_config.py has all necessary configurations for the training of the different graph networks
      • unet_config.py has all necessary configurations for the training of the different U-Nets
      • All training results are stored in runs/graphnets or runs/unets, respectively
    • models

      • graph_models.py and unet.py contain the definitions for the different models used in this paper.
    • utils: Helper functions for transformations between images and graphs, and for neural network training

  • Data and results:
    • data: Default directory for the raw data. The raw data has to be downloaded from the IARAI website.

    • images: Figures used in the paper

    • output

      • data_generalization.p: Pickle file with the results of the generalization experiment
      • performance_nb_params.pdf: Figure 4 from the paper "Graph-ResNets for short-term traffic forecasts in almost unknown cities".
    • runs: Folder that stores the TensorBoard logs and the corresponding trained models. The trained models used in the paper are stored in PMLR_nets; all newly trained U-Nets are stored in the unets folder, and all newly trained graph networks are stored in the graphnets folder. Pretrained networks have to be downloaded from here.
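The results file of the generalization experiment can be inspected with Python's pickle module. This is only a sketch: the README states that the file holds a dictionary but does not document its keys, so the hypothetical helpers below just check the type and list each top-level key with the type of its value. Unpickling may require the libraries whose objects were stored (for example NumPy, if any values are arrays), and pickle files should only be loaded from trusted sources.

```python
import pickle
from pathlib import Path


def load_generalization_results(path="output/data_generalization.p"):
    """Load the results dictionary written by experiment/generalization.py."""
    with open(path, "rb") as f:
        results = pickle.load(f)
    if not isinstance(results, dict):
        raise TypeError(f"expected a dict, got {type(results).__name__}")
    return results


def describe(results):
    """Return one line per top-level key with the type name of its value."""
    return [f"{key!r}: {type(value).__name__}" for key, value in results.items()]


if __name__ == "__main__":
    default = Path("output/data_generalization.p")
    if default.exists():  # expects to be run from the repository root
        for line in describe(load_generalization_results(default)):
            print(line)
```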

Notes

  • We use PyTorch Geometric to implement the graph neural networks
  • For reliable results, the batch size for graph networks must be set to 1
