SYLan2019 / MH-ASTIGCN

MULTI HEAD SELF-ATTENTION BASED SPATIAL-TEMPORAL INFORMATION GRAPH CONVOLUTIONAL NETWORKS FOR TRAFFIC FLOW FORECASTING

MH-ASTIGCN

Note: The latest version is on the latest branch: https://github.com/SYLan2019/MH-ASTIGCN/tree/latest

Multi-Head Self-Attention Based Spatial-Temporal Information Graph Convolutional Networks for Traffic Flow Forecasting

Figure: model architecture
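The title refers to multi-head self-attention. As a hedged illustration only (not the repository's module), the sketch below applies standard scaled dot-product multi-head self-attention along the time axis of one node's feature sequence, using NumPy. The argument names n_heads and d_k mirror the Training options listed in the Configuration section; the random projection matrices are an assumption standing in for learned weights, and the sizes in the usage line are illustrative.

```python
import numpy as np

def multi_head_self_attention(x, n_heads, d_k, rng=None):
    """Scaled dot-product multi-head self-attention over the time axis.

    x: array of shape (T, d_model), one node's feature sequence.
    Returns an array of shape (T, d_model).
    Weights are random here (assumption); a trained model learns them.
    """
    rng = np.random.default_rng(0) if rng is None else rng
    T, d_model = x.shape
    # Per-head projections for queries, keys and values: (n_heads, d_model, d_k)
    w_q = rng.standard_normal((n_heads, d_model, d_k)) / np.sqrt(d_model)
    w_k = rng.standard_normal((n_heads, d_model, d_k)) / np.sqrt(d_model)
    w_v = rng.standard_normal((n_heads, d_model, d_k)) / np.sqrt(d_model)
    # Output projection mapping concatenated heads back to d_model
    w_o = rng.standard_normal((n_heads * d_k, d_model)) / np.sqrt(n_heads * d_k)

    q = np.einsum("td,hdk->htk", x, w_q)  # (n_heads, T, d_k)
    k = np.einsum("td,hdk->htk", x, w_k)
    v = np.einsum("td,hdk->htk", x, w_v)

    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d_k)  # (n_heads, T, T)
    scores -= scores.max(axis=-1, keepdims=True)       # numerical stability
    attn = np.exp(scores)
    attn /= attn.sum(axis=-1, keepdims=True)           # softmax over keys

    heads = attn @ v                                   # (n_heads, T, d_k)
    concat = heads.transpose(1, 0, 2).reshape(T, n_heads * d_k)
    return concat @ w_o                                # (T, d_model)

x = np.random.default_rng(1).standard_normal((12, 64))  # 12 time steps, d_model = 64
out = multi_head_self_attention(x, n_heads=8, d_k=8)
print(out.shape)  # (12, 64)
```

Each head attends over all 12 time steps independently; concatenating the heads and projecting back keeps the output the same width as the input, so such layers can be stacked.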

Requirements

  • python >= 3.5
  • scipy
  • tensorboard
  • pytorch

Datasets

Step 1: Obtain the public traffic datasets that MH-ASTIGCN is implemented on (e.g. PEMS04 and PEMS08, which are used as examples below).

Step 2: Process the dataset

  • on the PEMS04 dataset:

    python prepareData.py --config configurations/PEMS04.conf
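The Training options num_of_weeks, num_of_days and num_of_hours suggest that each sample combines weekly, daily and recent-hour history, as in ASTGCN-style preprocessing. The sketch below is a minimal, hypothetical version of that slicing for a (Total_Time_Steps, Node_Number) array; the helper names segment_indices and make_sample are illustrative and are not taken from prepareData.py, which may differ in detail.

```python
import numpy as np

def segment_indices(label_start, num_of_depend, units, points_per_hour, num_for_predict):
    """Return (start, end) index pairs for num_of_depend past segments.

    units: 1 for hours, 24 for days, 24 * 7 for weeks.
    Returns None if any segment would start before index 0.
    """
    segments = []
    for i in range(num_of_depend, 0, -1):  # oldest segment first
        start = label_start - points_per_hour * units * i
        if start < 0:
            return None
        segments.append((start, start + num_for_predict))
    return segments

def make_sample(data, label_start, num_of_weeks, num_of_days, num_of_hours,
                points_per_hour=12, num_for_predict=12):
    """Slice week/day/hour inputs and the prediction target from data (T, N)."""
    if label_start + num_for_predict > data.shape[0]:
        return None
    parts = {}
    for name, depend, units in (("week", num_of_weeks, 24 * 7),
                                ("day", num_of_days, 24),
                                ("hour", num_of_hours, 1)):
        if depend == 0:
            continue
        idx = segment_indices(label_start, depend, units, points_per_hour, num_for_predict)
        if idx is None:
            return None  # not enough history before label_start
        parts[name] = np.concatenate([data[s:e] for s, e in idx], axis=0)
    target = data[label_start:label_start + num_for_predict]
    return parts, target

data = np.arange(62 * 288 * 3, dtype=float).reshape(62 * 288, 3)  # toy data: 3 nodes
parts, target = make_sample(data, label_start=7 * 288,
                            num_of_weeks=1, num_of_days=1, num_of_hours=1)
print({k: v.shape for k, v in parts.items()}, target.shape)
```

Samples whose history would reach before the first time step are skipped (None), so the first usable label position is one full week into the data when num_of_weeks = 1.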

Temporal Information Graph Construction

If the traffic data is available, the Temporal Information Graph (TIG) can also be generated with the following script:

cd ./data/
python Temporal_Graph_gen.py

The input traffic data must have the shape (Total_Time_Steps, Node_Number). For example, the PEMS08 dataset covers 170 roads over 62 days, with 288 five-minute steps per day, so its shape is (62*288, 170) = (17856, 170).
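One way to produce that two-dimensional array is sketched below. It assumes (this README does not say so) that the raw data is a PEMS-style .npz archive storing an array of shape (T, N, F) under the key "data", with feature 0 being traffic flow, as in the PEMS04/PEMS08 files distributed with ASTGCN. The function name load_flow_matrix and its defaults are hypothetical.

```python
import numpy as np

def load_flow_matrix(path, key="data", feature=0):
    """Return a (Total_Time_Steps, Node_Number) array from a PEMS-style .npz.

    Assumes the archive holds an array of shape (T, N, F) under `key`;
    `feature` selects the channel (0 is assumed to be flow).
    """
    with np.load(path) as archive:
        raw = archive[key]
    if raw.ndim == 3:
        raw = raw[..., feature]  # drop the feature axis
    return raw

# Expected number of rows for PEMS08: 62 days at 12 points per hour
steps_per_day = 24 * 12
print(62 * steps_per_day)  # 17856
```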

The computation runs on the CPU, so make sure sufficient CPU resources are available.
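This README does not state which similarity measure Temporal_Graph_gen.py uses. Purely as a stand-in, the sketch below builds a temporal similarity graph from a (Total_Time_Steps, Node_Number) array: each road is summarised by its mean daily profile, pairwise cosine similarity is computed, and each node is linked to its top_k most similar nodes. The function name and the top_k parameter are hypothetical. Measures such as dynamic time warping are considerably more expensive than the cosine similarity used here, which is one reason graph generation can be CPU-heavy.

```python
import numpy as np

def temporal_similarity_graph(data, steps_per_day=288, top_k=10):
    """Build a symmetric 0/1 adjacency matrix from (T, N) traffic data.

    Each node is summarised by its mean daily profile; nodes are linked to
    their top_k most cosine-similar nodes (stand-in for the real TIG metric).
    """
    T, N = data.shape
    days = T // steps_per_day
    # (days, steps_per_day, N) -> mean daily profile per node: (N, steps_per_day)
    profiles = data[: days * steps_per_day].reshape(days, steps_per_day, N).mean(axis=0).T
    norms = np.linalg.norm(profiles, axis=1, keepdims=True)
    unit = profiles / np.maximum(norms, 1e-12)
    sim = unit @ unit.T                    # (N, N) cosine similarity
    np.fill_diagonal(sim, -np.inf)         # exclude self-matches
    k = min(top_k, N - 1)
    neighbours = np.argsort(-sim, axis=1)[:, :k]
    adj = np.zeros((N, N))
    rows = np.repeat(np.arange(N), k)
    adj[rows, neighbours.ravel()] = 1.0
    return np.maximum(adj, adj.T)          # symmetrise

rng = np.random.default_rng(0)
toy = rng.random((3 * 288, 6))             # 3 days, 6 roads
A = temporal_similarity_graph(toy, top_k=2)
print(A.shape, A.sum(axis=1))
```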

Test

  • on the PEMS04 dataset (using our trained network parameters):

    python test.py --config configurations/PEMS04.conf

Configuration

The configuration file (config.conf) contains two parts, Data and Training:
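As a hedged illustration, the sketch below assumes the .conf files are INI-style with [Data] and [Training] sections (the section names come from this README; the INI syntax is an assumption), which Python's standard configparser can read. All values are illustrative except points_per_hour and num_for_predict, which the README gives as 12, and num_of_vertices, set to the 307 sensors of PEMS04; the file paths are hypothetical. The helper ctx_to_device is a hypothetical mapping of the ctx option to a PyTorch-style device string.

```python
import configparser

# Illustrative config; paths and training values are hypothetical.
EXAMPLE_CONF = """
[Data]
adj_filename = ./data/PEMS04/PEMS04.csv
graph_signal_matrix_filename = ./data/PEMS04/PEMS04.npz
num_of_vertices = 307
points_per_hour = 12
num_for_predict = 12

[Training]
graph = STIG
ctx = gpu-0
epochs = 80
learning_rate = 0.0001
batch_size = 32
"""

def ctx_to_device(ctx):
    """Map 'cpu' / 'gpu-<i>' to a device string such as 'cuda:<i>'."""
    if ctx == "cpu":
        return "cpu"
    if ctx.startswith("gpu-"):
        return "cuda:" + ctx.split("-", 1)[1]
    raise ValueError(f"unrecognised ctx value: {ctx!r}")

config = configparser.ConfigParser()
config.read_string(EXAMPLE_CONF)
data_cfg, train_cfg = config["Data"], config["Training"]

num_of_vertices = data_cfg.getint("num_of_vertices")
learning_rate = train_cfg.getfloat("learning_rate")
device = ctx_to_device(train_cfg["ctx"])
print(num_of_vertices, learning_rate, device)  # 307 0.0001 cuda:0
```

Typed getters such as getint and getfloat keep the parsing in one place, so the rest of the code never handles raw strings.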

Data

  • adj_filename: path of the adjacency matrix file
  • graph_signal_matrix_filename: path of the graph signal matrix file
  • STIG_filename: path of the Spatial-Temporal Information Graph file
  • TSG_filename: path of the Temporal Similarity Graph file
  • num_of_vertices: number of vertices
  • points_per_hour: number of points per hour; 12 in our datasets
  • num_for_predict: number of points to predict; 12 in our model
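For adj_filename, a common convention in the PEMS04/PEMS08 releases used by ASTGCN-style code is a CSV edge list with a header row and the columns from, to and cost (distance). Assuming that format (not stated in this README), the sketch below turns the edge list into a binary, symmetric num_of_vertices × num_of_vertices matrix. The function name load_adjacency is hypothetical, and treating edges as undirected is a design choice of this sketch.

```python
import csv
import io
import numpy as np

def load_adjacency(csv_file, num_of_vertices):
    """Read a 'from,to,cost' edge list into a 0/1 symmetric adjacency matrix.

    csv_file: an open text file (or file-like object) with a header row.
    """
    adj = np.zeros((num_of_vertices, num_of_vertices), dtype=np.float32)
    reader = csv.reader(csv_file)
    next(reader)                       # skip the header row
    for row in reader:
        if len(row) < 2:
            continue                   # tolerate blank lines
        i, j = int(row[0]), int(row[1])
        adj[i, j] = adj[j, i] = 1.0    # treat edges as undirected
    return adj

edges = io.StringIO("from,to,cost\n0,1,350.0\n1,2,120.5\n")
A = load_adjacency(edges, num_of_vertices=3)
print(A)
```

With a real file, pass open(path, newline="") instead of the StringIO object.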

Training

  • graph: graph structure to use, either G (adjacency graph) or STIG (Spatial-Temporal Information Graph)
  • ctx: cpu, or gpu-0 to use the first GPU device
  • epochs: int, number of training epochs
  • learning_rate: float, e.g. 0.0001
  • batch_size: int, batch size
  • num_of_weeks: int, number of weeks of historical data to use
  • num_of_days: int, number of days of historical data to use
  • num_of_hours: int, number of hours of historical data to use
  • n_heads: int, number of temporal attention heads
  • d_k: int, dimension of the Q, K, and V vectors
  • d_model: int, model dimension d_E
  • K: int, order of the Chebyshev polynomials (also the number of spatial attention heads)
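The K option refers to K-order Chebyshev polynomials. In the standard ChebNet-style construction, the normalised graph Laplacian L is rescaled to L̃ = 2L/λ_max − I, and the polynomials follow the recurrence T_0 = I, T_1 = L̃, T_k = 2 L̃ T_{k−1} − T_{k−2}. The NumPy sketch below computes T_0, ..., T_{K−1} for a given adjacency matrix (G or STIG). It is a generic illustration assuming the symmetric normalised Laplacian; the repository may use a different normalisation.

```python
import numpy as np

def scaled_laplacian(adj):
    """Return 2 L / lambda_max - I for L = I - D^{-1/2} A D^{-1/2}."""
    n = adj.shape[0]
    deg = adj.sum(axis=1)
    d_inv_sqrt = np.zeros_like(deg, dtype=float)
    nz = deg > 0
    d_inv_sqrt[nz] = deg[nz] ** -0.5   # isolated nodes keep 0
    lap = np.eye(n) - d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]
    lambda_max = np.linalg.eigvalsh(lap).max()  # lap is symmetric
    return 2.0 * lap / lambda_max - np.eye(n)

def chebyshev_polynomials(adj, K):
    """Return [T_0, ..., T_{K-1}] evaluated at the scaled Laplacian."""
    l_tilde = scaled_laplacian(adj)
    n = adj.shape[0]
    polys = [np.eye(n), l_tilde.copy()]
    for _ in range(2, K):
        polys.append(2.0 * l_tilde @ polys[-1] - polys[-2])
    return polys[:K]

adj = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)  # 3-node path graph
cheb = chebyshev_polynomials(adj, K=3)
print(len(cheb), cheb[0].shape)
```

Rescaling puts the eigenvalues of L̃ in [−1, 1], the interval on which Chebyshev polynomials are well behaved; T_k is nonzero only between nodes at most k hops apart, so a larger K widens each graph convolution's receptive field.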
