This is a PyTorch implementation of T-GCN in the following paper: T-GCN: A Temporal Graph Convolutional Network for Traffic Prediction.
A stable version of this code is available in the official repository.
Note that the original implementation is written in TensorFlow and currently performs slightly better than this implementation.
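T-GCN combines a graph convolutional network (GCN), which captures spatial dependencies between road segments, with a gated recurrent unit (GRU), which captures temporal dynamics. The NumPy sketch below shows that idea for one time step. It is not this repository's code. The class `TGCNCellSketch`, the helper `normalized_adjacency`, the random weight initialization, and applying the graph convolution to the concatenation of the input and the previous hidden state are all assumptions made for illustration.

```python
import numpy as np


def normalized_adjacency(adj: np.ndarray) -> np.ndarray:
    """Return D^-1/2 (A + I) D^-1/2, the symmetric GCN normalization."""
    a_tilde = adj + np.eye(adj.shape[0])  # add self-loops
    d_inv_sqrt = 1.0 / np.sqrt(a_tilde.sum(axis=1))
    return a_tilde * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]


def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))


class TGCNCellSketch:
    """Hypothetical T-GCN cell: a GRU whose linear maps are graph convolutions."""

    def __init__(self, adj: np.ndarray, input_dim: int, hidden_dim: int, seed: int = 0):
        rng = np.random.default_rng(seed)
        self.a_hat = normalized_adjacency(adj)
        in_dim = input_dim + hidden_dim
        # Reset (r) and update (u) gates share one graph convolution with 2 * hidden_dim outputs.
        self.w_gates = rng.normal(scale=0.1, size=(in_dim, 2 * hidden_dim))
        self.b_gates = np.zeros(2 * hidden_dim)
        # The candidate state c uses a separate graph convolution.
        self.w_cand = rng.normal(scale=0.1, size=(in_dim, hidden_dim))
        self.b_cand = np.zeros(hidden_dim)

    def graph_conv(self, x: np.ndarray, w: np.ndarray, b: np.ndarray) -> np.ndarray:
        # One GCN layer without activation: A_hat @ X @ W + b.
        return self.a_hat @ x @ w + b

    def step(self, x_t: np.ndarray, h_prev: np.ndarray) -> np.ndarray:
        """x_t: (num_nodes, input_dim); h_prev: (num_nodes, hidden_dim)."""
        gates = sigmoid(self.graph_conv(np.concatenate([x_t, h_prev], axis=1),
                                        self.w_gates, self.b_gates))
        r, u = np.split(gates, 2, axis=1)
        c = np.tanh(self.graph_conv(np.concatenate([x_t, r * h_prev], axis=1),
                                    self.w_cand, self.b_cand))
        # GRU state update: keep a fraction u of the old state, (1 - u) of the candidate.
        return u * h_prev + (1.0 - u) * c


# Usage: a 3-node path graph, one traffic feature per node, 12 time steps of made-up data.
adj = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)
cell = TGCNCellSketch(adj, input_dim=1, hidden_dim=4)
speeds = np.random.default_rng(1).random((12, 3))
h = np.zeros((3, 4))
for x in speeds:
    h = cell.step(x[:, None], h)
print(h.shape)  # (3, 4)
```

Because every gate is a graph convolution, each node's hidden state mixes information from its neighbours at every step, while the recurrence carries that information forward in time.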
Requirements:
- numpy
- matplotlib
- pandas
- torch
- pytorch-lightning>=1.3.0
- torchmetrics>=0.3.0
- python-dotenv
Train a model with one of the following commands:
```bash
# GCN
python main.py --model_name GCN --max_epochs 3000 --learning_rate 0.001 --weight_decay 0 --batch_size 64 --hidden_dim 100 --settings supervised --gpus 1
# GRU
python main.py --model_name GRU --max_epochs 3000 --learning_rate 0.001 --weight_decay 1.5e-3 --batch_size 64 --hidden_dim 100 --settings supervised --gpus 1
# T-GCN
python main.py --model_name TGCN --max_epochs 3000 --learning_rate 0.001 --weight_decay 0 --batch_size 32 --hidden_dim 64 --loss mse_with_regularizer --settings supervised --gpus 1
```
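The T-GCN command selects `--loss mse_with_regularizer`. The T-GCN paper describes its training loss, roughly, as a prediction error term plus a λ-weighted L2 regularization term on the trainable parameters. The sketch below shows one way such a loss can be computed. The function name `mse_with_l2_regularizer`, the use of a mean rather than a sum for the error term, and the default `lam` are assumptions for illustration, not this repository's implementation.

```python
import numpy as np


def mse_with_l2_regularizer(pred, target, params, lam=1.5e-3):
    """Hypothetical loss: mean squared error plus an L2 penalty on model weights.

    pred, target: arrays of identical shape, e.g. (batch, pre_len, num_nodes).
    params: iterable of weight arrays to penalize.
    lam: regularization strength (an assumed default, not read from the repository).
    """
    mse = np.mean((pred - target) ** 2)
    # Half the sum of squared weights, the same convention as tf.nn.l2_loss.
    reg = lam * sum(np.sum(w ** 2) / 2.0 for w in params)
    return mse + reg


pred = np.array([1.0, 2.0])
target = np.array([1.0, 0.0])
weights = [np.array([1.0, 1.0])]
# MSE is (0 + 4) / 2 = 2.0; penalty is 1.5e-3 * (1 + 1) / 2 = 1.5e-3.
print(mse_with_l2_regularizer(pred, target, weights))  # ≈ 2.0015
```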
You can also adjust the `--data`, `--seq_len` and `--pre_len` parameters.
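Traffic forecasting models of this kind are usually trained on sliding windows cut from a (time steps × nodes) series. Assuming `--seq_len` is the number of past steps fed to the model and `--pre_len` the number of future steps to predict, the windowing can be sketched as follows. The helper `make_windows` is hypothetical and may differ from how this repository builds its datasets (for example, in how it treats the final window).

```python
import numpy as np


def make_windows(series: np.ndarray, seq_len: int, pre_len: int):
    """Split a (time, nodes) array into input/target windows.

    X[i] = series[i : i + seq_len]
    Y[i] = series[i + seq_len : i + seq_len + pre_len]
    """
    num_samples = series.shape[0] - seq_len - pre_len + 1
    if num_samples <= 0:
        raise ValueError("series is too short for the requested seq_len + pre_len")
    x = np.stack([series[i : i + seq_len] for i in range(num_samples)])
    y = np.stack([series[i + seq_len : i + seq_len + pre_len] for i in range(num_samples)])
    return x, y


# Toy series: 10 time steps, 2 nodes; row t is [2t, 2t + 1].
series = np.arange(20, dtype=float).reshape(10, 2)
x, y = make_windows(series, seq_len=4, pre_len=2)
print(x.shape, y.shape)  # (5, 4, 2) (5, 2, 2)
```

Each sample uses the previous `seq_len` steps of every node to predict the next `pre_len` steps, so increasing `pre_len` makes the task a longer-horizon forecast.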
Run `tensorboard --logdir lightning_logs/version_0` to monitor the training progress and view the prediction results.