ForestsKing / GraFITi

PyTorch implementation of GraFITi (GraFITi: Graphs for Forecasting Irregularly Sampled Time Series)

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

GraFITi

方式一

  1. 借助工具下载 GraFITi 官方代码./GraFITi/

  2. 安装 tsdm

    1. 下载 tsdm 官方代码./tsdm/
    2. 创建 conda 虚拟环境,注意 python=3.11
    3. ./GraFITi/tsdm 替换 ./tsdm/src/tsdm
    4. ./tsdm/src/tsdm/viz/_config.py 中的 USE_TEX: Final[bool] = matplotlib.checkdep_usetex(True) 改为 USE_TEX: Final[bool] = False
    5. 进入 ./tsdm/ 目录,执行 pip install -e .
  3. 修改 ./GraFITi/train_grafiti.py

    1. 创建模型存储目录

      if not os.path.exists('saved_models/'):
          os.makedirs('saved_models/')
    2. 修改优化器配置

      OPTIMIZER_CONFIG = {
          "lr": ARGS.learn_rate,
          "betas": ARGS.betas,
          "weight_decay": ARGS.weight_decay,
      }
    3. 如果需要,添加 tqdm 打印进度条

  4. 进入 ./GraFITi/ 目录,运行如下命令运行官方示例,如果提示缺包自行安装即可

python train_grafiti.py --epochs 200 --learn-rate 0.001 --batch-size 128 --attn-head 1 --latent-dim 128 --nlayers 4 --dataset physionet2012 --fold 0 -ct 36 -ft 12

方式二

  1. 下载本项目

  2. 创建 conda 虚拟环境,注意 python=3.11

  3. 进入 tsdm-main 目录,执行 pip install -e .

  4. 进入 ./GraFITi/ 目录,运行如下命令运行官方示例,如果提示缺包自行安装即可

    python train_grafiti.py --epochs 200 --learn-rate 0.001 --batch-size 128 --attn-head 1 --latent-dim 128 --nlayers 4 --dataset physionet2012 --fold 0 -ct 36 -ft 12

运行结果

About

PyTorch implementation of GraFITi (GraFITi: Graphs for Forecasting Irregularly Sampled Time Series)


Languages

Language:Jupyter Notebook 59.3%Language:Python 40.5%Language:C++ 0.1%Language:Shell 0.1%Language:Makefile 0.0%Language:Batchfile 0.0%Language:CMake 0.0%