damitkwr / ESRNN-GPU

PyTorch GPU implementation of the ES-RNN model for time series forecasting

max_loss is not being updated

junhyk opened this issue

Hi,

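It seems that max_loss in train_epochs() in esrnn/trainer.py is never updated. It is set to 1e8 once and never reassigned, so for any loss below 1e8 the check epoch_loss < max_loss is true and self.save() runs after every epoch, not only when the loss improves: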

def train_epochs(self):
    max_loss = 1e8
    start_time = time.time()
    for e in range(self.max_epochs):
        self.scheduler.step()
        epoch_loss = self.train()
        if epoch_loss < max_loss:
            self.save()
        epoch_val_loss = self.val()
        if e == 0:
            file_path = os.path.join(self.csv_save_path, 'validation_losses.csv')
            with open(file_path, 'w') as f:
                f.write('epoch,training_loss,validation_loss\n')
        with open(file_path, 'a') as f:
            f.write(','.join([str(e), str(epoch_loss), str(epoch_val_loss)]) + '\n')
    print('Total Training Mins: %5.2f' % ((time.time()-start_time)/60))
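
One possible fix is to record the new best loss whenever it improves. The following is only a minimal sketch of that idea, not the repository's actual patch: train_one_epoch and save are hypothetical stand-ins for self.train() and self.save(), and the returned list of epochs exists only to make the behaviour visible.

```python
def train_epochs(train_one_epoch, save, max_epochs):
    """Run max_epochs epochs, saving only when the training loss improves.

    train_one_epoch and save are hypothetical stand-ins for self.train()
    and self.save() in the snippet above.
    """
    best_loss = float("inf")  # lowest loss seen so far
    saved_epochs = []
    for e in range(max_epochs):
        epoch_loss = train_one_epoch()
        if epoch_loss < best_loss:
            best_loss = epoch_loss  # the update missing from the original
            save()
            saved_epochs.append(e)
    return saved_epochs


# Toy usage: fake per-epoch losses instead of a real model.
losses = iter([5.0, 4.0, 4.5, 3.0, 3.2])
checkpoints = []
saved = train_epochs(lambda: next(losses), lambda: checkpoints.append("ckpt"), 5)
print(saved)  # [0, 1, 3]
```

This sketch keeps the comparison on the training loss, as in the snippet above. Comparing against the validation loss instead would be a separate design choice.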

Thanks!

Right, this code is not production-quality.