TSFormer pre-training question

Jimmy-7664 opened this issue

In the file, the `test` method is used to evaluate the model during training. However, I found that it only uses the last batch of the `test_data_loader`. Is there something wrong with the for loop, or am I wrong?

    @torch.no_grad()
    def test(self):
        """Evaluate the model.

        train_epoch (int, optional): current epoch if in training process.
        """
        for _, data in enumerate(self.test_data_loader):
            forward_return = self.forward(data=data, epoch=None, iter_num=None, train=False)
        # re-scale data
        prediction_rescaled = SCALER_REGISTRY.get(self.scaler["func"])(forward_return[0], **self.scaler["args"])
        real_value_rescaled = SCALER_REGISTRY.get(self.scaler["func"])(forward_return[1], **self.scaler["args"])
        # metrics
        for metric_name, metric_func in self.metrics.items():
            metric_item = metric_func(prediction_rescaled, real_value_rescaled, null_val=self.null_val)
            self.update_epoch_meter("test_"+metric_name, metric_item.item())
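
A minimal plain-Python sketch of the problem, for illustration only: a variable assigned inside a `for` loop keeps its last value once the loop ends, so the dedented re-scaling and metric code only ever sees the final batch. The names `fake_forward`, `mae`, `buggy_test`, `fixed_test` and the toy batches are hypothetical stand-ins for `self.forward`, the metric functions and `self.test_data_loader`; they are not part of the repository.

```python
def mae(pred, real):
    """Mean absolute error between two equal-length lists."""
    return sum(abs(p - r) for p, r in zip(pred, real)) / len(pred)


def fake_forward(batch):
    """Hypothetical stand-in for self.forward: returns (prediction, real_value).

    The toy "model" doubles its input, so each sample's error equals its value.
    """
    return [2.0 * x for x in batch], list(batch)


test_batches = [[1.0, 1.0], [2.0, 2.0], [6.0]]


def buggy_test(batches):
    for data in batches:
        forward_return = fake_forward(data)
    # Dedented: runs once, after the loop, on the LAST batch only.
    return [mae(forward_return[0], forward_return[1])]


def fixed_test(batches):
    results = []
    for data in batches:
        forward_return = fake_forward(data)
        # Indented: one metric value per batch.
        results.append(mae(forward_return[0], forward_return[1]))
    return results


print(buggy_test(test_batches))  # [6.0]
print(fixed_test(test_batches))  # [1.0, 2.0, 6.0]
```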

Thanks for your question.
I have checked, and this is indeed a bug.
However, it doesn't affect the final result, because this metric is not used for anything.
Thanks again for reporting this bug, I'll fix it right away!

Thanks for your reply. So the correct test method is the `test` method in the file "":

    @torch.no_grad()
    def test(self):
        """Evaluate the model.

        train_epoch (int, optional): current epoch if in training process.
        """
        # test loop
        prediction = []
        real_value = []
        for _, data in enumerate(self.test_data_loader):
            forward_return = self.forward(data, epoch=None, iter_num=None, train=False)
            prediction.append(forward_return[0])        # preds = forward_return[0]
            real_value.append(forward_return[1])        # testy = forward_return[1]
        prediction = torch.cat(prediction, dim=0)
        real_value = torch.cat(real_value, dim=0)
        # re-scale data
        prediction = SCALER_REGISTRY.get(self.scaler["func"])(
            prediction, **self.scaler["args"])
        real_value = SCALER_REGISTRY.get(self.scaler["func"])(
            real_value, **self.scaler["args"])
        # summarize the results.
        # test performance of different horizon
        for i in self.evaluation_horizons:
            # For horizon i, only calculate the metrics **at that time** slice here.
            pred = prediction[:, i, :]
            real = real_value[:, i, :]
            # metrics
            metric_results = {}
            for metric_name, metric_func in self.metrics.items():
                metric_item = self.metric_forward(metric_func, [pred, real])
                metric_results[metric_name] = metric_item.item()
            log = "Evaluate best model on test data for horizon " + \
                "{:d}, Test MAE: {:.4f}, Test RMSE: {:.4f}, Test MAPE: {:.4f}"
            log = log.format(
                i+1, metric_results["MAE"], metric_results["RMSE"], metric_results["MAPE"])
            self.logger.info(log)
        # test performance overall
        for metric_name, metric_func in self.metrics.items():
            metric_item = self.metric_forward(metric_func, [prediction, real_value])
            self.update_epoch_meter("test_"+metric_name, metric_item.item())
            metric_results[metric_name] = metric_item.item()

Am I right?
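
For reference, here is a small plain-Python sketch of what the forecasting `test` above computes, assuming each batch's prediction has shape `[batch_size, horizon, num_nodes]` (nested lists here instead of tensors): all batches are first concatenated along the sample dimension, like `torch.cat(..., dim=0)`, and the metric for horizon `i` is then taken over the slice `prediction[:, i, :]` only. `concat_batches` and `horizon_mae` are hypothetical helpers, not BasicTS functions, and the null-value masking done by the real metrics is left out.

```python
def concat_batches(batches):
    """Join a list of batches along the sample dimension (like torch.cat(..., dim=0))."""
    samples = []
    for batch in batches:
        samples.extend(batch)
    return samples


def horizon_mae(prediction, real_value, i):
    """MAE at horizon step i only, i.e. over prediction[:, i, :] vs real_value[:, i, :]."""
    errors = [
        abs(p - r)
        for pred_sample, real_sample in zip(prediction, real_value)
        for p, r in zip(pred_sample[i], real_sample[i])
    ]
    return sum(errors) / len(errors)


# Two batches, each of shape [1 sample, 2 horizon steps, 2 nodes].
pred_batches = [[[[1.0, 2.0], [3.0, 4.0]]], [[[1.0, 2.0], [5.0, 6.0]]]]
real_batches = [[[[1.0, 1.0], [3.0, 3.0]]], [[[1.0, 1.0], [3.0, 3.0]]]]

prediction = concat_batches(pred_batches)  # shape [2, 2, 2]
real_value = concat_batches(real_batches)

for i in range(2):
    print("horizon", i + 1, "MAE:", horizon_mae(prediction, real_value, i))
# horizon 1 MAE: 0.5
# horizon 2 MAE: 1.5
```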


No. The `test` function in `base_tsf_runner` is designed for the Time Series Forecasting (TSF) problem, which is not compatible with the reconstruction task in the pre-training stage.
Actually, I think you can fix this by indenting lines 84~90 by one more level (one Tab), like this:

    @torch.no_grad()
    def test(self):
        """Evaluate the model.

        train_epoch (int, optional): current epoch if in training process.
        """
        for _, data in enumerate(self.test_data_loader):
            forward_return = self.forward(data=data, epoch=None, iter_num=None, train=False)
            # re-scale data
            prediction_rescaled = SCALER_REGISTRY.get(self.scaler["func"])(forward_return[0], **self.scaler["args"])
            real_value_rescaled = SCALER_REGISTRY.get(self.scaler["func"])(forward_return[1], **self.scaler["args"])
            # metrics
            for metric_name, metric_func in self.metrics.items():
                metric_item = metric_func(prediction_rescaled, real_value_rescaled, null_val=self.null_val)
                self.update_epoch_meter("test_"+metric_name, metric_item.item())
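
One thing to keep in mind with this per-batch version: `update_epoch_meter` is called once per batch, so the reported test metric is an average of per-batch metrics. If the meter weights every update equally (an assumption here; check the meter implementation), this can differ slightly from a metric computed over all test samples at once when the last batch is smaller. A plain-Python sketch with a hypothetical `AverageMeter` class and toy per-sample errors:

```python
class AverageMeter:
    """Hypothetical meter: running mean where every update has the same weight."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.total += value
        self.count += 1

    @property
    def avg(self):
        return self.total / self.count


# Toy absolute errors per test batch; the last batch is smaller.
batch_errors = [[1.0, 1.0], [2.0, 2.0], [6.0]]

meter = AverageMeter()
for errors in batch_errors:
    meter.update(sum(errors) / len(errors))  # per-batch MAE, as in the fixed loop

all_errors = [e for errors in batch_errors for e in errors]
global_mae = sum(all_errors) / len(all_errors)

print(meter.avg)   # 3.0  (mean of 1.0, 2.0, 6.0)
print(global_mae)  # 2.4  (12.0 / 5 samples)
```

Weighting each update by its batch size would make the two numbers agree for MAE.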

Thanks for your answer, I got your idea. :)


This bug is now fixed. Thanks again for your report!