zezhishao / STEP

Code for our SIGKDD'22 paper Pre-training-Enhanced Spatial-Temporal Graph Neural Network For Multivariate Time Series Forecasting.

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

用自己的数据集来 预训练 模型时,如何配置

emanlee opened this issue · comments

作者您好!

我们想用自己的数据集来 预训练 模型。
我们看了一下 TSFormer_METR-LA.py 文件,里面有些配置不明白。因此,请教一下。

CFG.DATASET_INPUT_LEN = 288 * 7 这个288和7分别表示什么,谢谢!

CFG.MODEL.PARAM = {
"patch_size":12, ############ 请问这个12表示输出时间步吗?
"in_channel":1,
"embed_dim":96,
"num_heads":4,
"mlp_ratio":4,
"dropout":0.1,
"num_token":288 * 7 / 12, ############ 请问这个地方为什么除以12?
"mask_ratio":0.75,
"encoder_depth":4,
"decoder_depth":1,
"mode":"pre-train"
}

从 TSFormer_METR-LA.py 文件看不出原始的输入数据的文件名以及文件扩展名,请问输入数据应该放到哪个文件夹,并且,文件取名有什么要求吗? 是不是要这样取名 METR-LA.h5 ? 还是要类似于 scaler_in2016_out12.pkl ?

谢谢!

commented
  1. 2887是历史时间步长。288个时间片在METR-LA中涵盖了一天的数据,因此2887代表着用历史七天数据进行预训练。这样方便显式捕捉“天”周期性和“周”周期性。
  2. patch_size不是输出步长,预训练阶段是一个重构任务,输出步长不参与运算。patch_size是切片大小。具体细节可以看论文呢,有详细描述如何做切片(patchify)
  3. num_token是token的数量。这个在论文中有专门提到,使用patch作为基本输入单元,而不是常用的point。因此,真正输入到模型的token的序列长度是288*7/12(序列长度/patch长度=patch数量=token数量)。
  4. 添加新的数据集的话,这个比较困难。你需要学习一下generate_training_data.py,仿照它的逻辑产生符合BasicTS规范的自己的数据集。产生完之后就比较简单了,在配置文件中改一下数据集名称,他就会自动去寻找然后读取了。

十分感谢您的解释!