unit8co / darts

A python library for user-friendly forecasting and anomaly detection on time series.

Home Page: https://unit8co.github.io/darts/

[BUG] TFTModel raises StopIteration on decoder_vsn() call in forward()

giacomoguiduzzi opened this issue.

Describe the bug
When I try to train a TFTModel instance, the epochs log shows the model doing nothing: all epochs complete immediately with 0 batches computed. Debugging forward(), I found that the call to self.decoder_vsn() raises a StopIteration exception, which ends training for that epoch. This repeats for every epoch until training finishes. I'm passing a custom MixedCovariatesSequentialDataset instance to fit_from_dataset() to override the default behaviour: its __getitem__() method returns my own sliding windows.

To Reproduce
My current configuration is a TFTModel instance with input_chunk_length=64, output_chunk_length=1, batch_size=32, and a custom Trainer object defined as follows:

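from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# devices_indexes and model_n_epochs are defined elsewhere in my script
early_stopping = EarlyStopping(
    monitor="val_loss",  # requires "val_loss" to be logged during validation
    patience=10,
    min_delta=0.05,
    verbose=True,
    mode="min",
)
trainer = Trainer(
    devices=devices_indexes,
    callbacks=[early_stopping],
    max_epochs=model_n_epochs,
)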

Expected behavior
TFTModel trains correctly, as the other models do.

System:

  • Python version: 3.9.18
  • Darts version: 0.28.0

Additional context
I initially thought this problem came from the EarlyStopping callback, since it raised a RuntimeError reporting that it could not evaluate "val_loss" because that metric was not logged. After removing the EarlyStopping callback, I noticed the behaviour described above.

Hi @giacomoguiduzzi, could you provide a minimal reproducible example?
We would need to see how you defined the custom Dataset; you can then use some toy data to run the model and reproduce the issue.

Hi @dennisbader, I opened the issue in the meantime in case you had seen something similar. I'm working on the example and will paste it here as soon as I have it.
I also forgot to link the line where the StopIteration occurs:

embeddings_varying_decoder, decoder_sparse_weights = self.decoder_vsn(...)

Thanks in advance!

Hi @dennisbader, as with the other issue, I have created a repo: https://github.com/giacomoguiduzzi/tftmodel_bug_example. Here the script does not terminate correctly because the model does not output any prediction. If you run the debugger and evaluate the forward() function at pl_forecasting_module.py:49, you'll see that the result is a StopIteration raised from the line I linked in the previous comment:

embeddings_varying_decoder, decoder_sparse_weights = self.decoder_vsn(
    x=embeddings_varying_decoder,
    context=static_context_expanded[:, encoder_length:],
)
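Why would the epochs end silently rather than crash? One plausible mechanism, sketched below with hypothetical stand-ins (`select_variables` and `run_epoch` are not Darts or PyTorch Lightning code): if a variable selection network with at most one input picks that input with `next(iter(...))`, then with zero decoder inputs (no future covariates) the call raises `StopIteration`. A training loop that catches `StopIteration` to detect the end of its data may then mistake that exception for "no more batches" and end the epoch after 0 batches.

```python
def select_variables(decoder_inputs: dict):
    """Hypothetical stand-in for a variable selection step.

    With more than one input it combines them; with a single input it skips
    selection and picks that input via next(iter(...)). With zero inputs the
    dict is empty, so next() raises StopIteration instead of a clear error.
    """
    if len(decoder_inputs) > 1:
        return sum(decoder_inputs.values())
    name = next(iter(decoder_inputs))  # StopIteration if decoder_inputs is empty
    return decoder_inputs[name]


def run_epoch(batches, step):
    """Hypothetical epoch loop that treats StopIteration as 'out of data'."""
    it = iter(batches)
    done = 0
    while True:
        try:
            batch = next(it)
            step(batch)  # a StopIteration raised in here is caught as well
            done += 1
        except StopIteration:
            break
    return done


# With a future covariate present, every batch is processed.
ok = run_epoch(range(3), lambda b: select_variables({"future_cov": b}))
# Without any decoder (future) inputs, the first step raises StopIteration,
# which the loop mistakes for the end of the data: 0 batches, no error.
empty = run_epoch(range(3), lambda b: select_variables({}))
print(ok, empty)  # 3 0
```

The design point is that `StopIteration` doubles as a control-flow signal in Python, so letting it escape from ordinary computation can be silently misread by any caller that uses it to stop iterating.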

Let me know what you think. Looking forward to your response.

EDIT: I forgot to mention that in the example script I had to comment out the EarlyStopping callback I normally use, because it can't find the 'val_loss' metric. I believe this is because forward() never completes, so nothing gets logged.
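The reasoning in the EDIT (no completed steps, so nothing is logged, so the monitor fails) can be modelled in a few lines. This is a hypothetical simplification: PyTorch Lightning's `EarlyStopping` performs a comparable check and, by default, raises a `RuntimeError` when the monitored metric is missing (matching the error reported at the top of this thread), but `check_early_stopping` below is not its code.

```python
def check_early_stopping(logged_metrics: dict, monitor: str = "val_loss", strict: bool = True):
    """Hypothetical monitor check: return the metric value, or fail if absent."""
    if monitor not in logged_metrics:
        if strict:
            available = ", ".join(sorted(logged_metrics)) or "<none>"
            raise RuntimeError(
                f"Early stopping conditioned on metric {monitor!r}, "
                f"which is not available. Available metrics: {available}"
            )
        return None  # non-strict: silently skip the check
    return logged_metrics[monitor]


# A validation epoch that processed batches would have logged val_loss.
print(check_early_stopping({"val_loss": 0.42}))  # 0.42

# If every epoch ends after 0 batches, nothing is logged and the check fails.
try:
    check_early_stopping({})
except RuntimeError as err:
    print(err)
```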

Hi @giacomoguiduzzi, and sorry for the late response. The problem is that TFTModel can only be used with future covariates, i.e. the __getitem__ of your datasets must return the historic_future_covariates and future_covariates.

We have the option add_relative_index, which generates placeholder future covariates for you in case you don't have any future information.

model = TFTModel(..., add_relative_index=True) should work.
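If you supply your own future covariates instead of relying on add_relative_index, a custom dataset can fail fast rather than letting training stop silently. Below is a minimal sketch of such a guard; `check_tft_sample` and the positional constants are hypothetical (not part of Darts), and the tuple positions are taken from the `__getitem__` return statement shown in this thread.

```python
import numpy as np

# Tuple positions of the future covariates in a sample. They are the same in
# the pre-0.30.0 layout and in the 0.30.0 layout with the extra sample weight,
# because the sample weight is inserted after these fields.
HISTORIC_FUTURE_COV, FUTURE_COV = 2, 3


def check_tft_sample(sample: tuple) -> tuple:
    """Hypothetical fail-fast check for a custom dataset's __getitem__ output
    when training a TFTModel without add_relative_index."""
    if sample[HISTORIC_FUTURE_COV] is None or sample[FUTURE_COV] is None:
        raise ValueError(
            "TFTModel needs future covariates: return historic_future_covariate "
            "and future_covariate from __getitem__"
        )
    return sample


past_target = np.zeros((64, 1), dtype=np.float32)
future_target = np.zeros((1, 1), dtype=np.float32)

# The kind of sample that led to the silent failure: no future covariates.
bad = (past_target, None, None, None, None, None, future_target)
# Same sample with toy covariates for the input (64 steps) and output (1 step) chunks.
good = (past_target, None, np.zeros((64, 1), np.float32), np.zeros((1, 1), np.float32),
        None, None, future_target)

check_tft_sample(good)  # passes
try:
    check_tft_sample(bad)
except ValueError as err:
    print(err)
```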

P.S. If you're already using Darts 0.30.0: we have added sample weights to our training datasets, which requires returning an additional value from the __getitem__ method.

You can simply add a line to the return statement:

        return (
            past_target,
            past_covariate,
            historic_future_covariate,
            future_covariate,
            static_covariate,
            None,  # <---- add this line for the sample weights
            future_target,
        )
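To see where the extra `None` sits in context, here is a hypothetical, NumPy-only sketch of a sliding-window `__getitem__` that returns the seven-element layout above with future covariates filled in. The class name, the single covariate array split into historic and future parts, and the array shapes are assumptions for illustration; a real implementation would build on the appropriate Darts training dataset rather than stand alone.

```python
import numpy as np


class ToyMixedCovariatesWindows:
    """Hypothetical dataset returning the 7-element sample layout (Darts >= 0.30.0)."""

    def __init__(self, target, future_covs, input_chunk_length, output_chunk_length):
        # target: (n_steps, n_target_dims); future_covs: (n_steps, n_cov_dims),
        # assumed known in advance for every step (e.g. calendar features).
        assert len(target) == len(future_covs)
        self.target, self.future_covs = target, future_covs
        self.icl, self.ocl = input_chunk_length, output_chunk_length

    def __len__(self):
        # Number of complete (input, output) windows that fit in the series.
        return max(0, len(self.target) - self.icl - self.ocl + 1)

    def __getitem__(self, idx):
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        split = idx + self.icl
        end = split + self.ocl
        return (
            self.target[idx:split],        # past_target
            None,                          # past_covariate (none in this toy)
            self.future_covs[idx:split],   # historic_future_covariate
            self.future_covs[split:end],   # future_covariate
            None,                          # static_covariate
            None,                          # sample weight (new in 0.30.0)
            self.target[split:end],        # future_target
        )


n = 100
target = np.sin(np.arange(n, dtype=np.float32))[:, None]
# Position in the series as a toy, always-known future covariate.
future_covs = (np.arange(n, dtype=np.float32) / n)[:, None]
ds = ToyMixedCovariatesWindows(target, future_covs, input_chunk_length=64, output_chunk_length=1)
sample = ds[0]
print(len(sample), sample[2].shape, sample[3].shape)  # 7 (64, 1) (1, 1)
```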

Hi @dennisbader, I have just tested your solution and it works great! By the way, I also updated to 0.30.0 and updated the TrainingDataset class accordingly. The only thing I noticed is that the EarlyStopping callback seemed to completely ignore the model's output. I'm going to investigate, and if that turns out to be the case I'll open a separate issue, as it's a different topic.

Thanks a lot!

Glad that it helped 🚀 I'm closing the issue, then.