ashleve / lightning-hydra-template

PyTorch Lightning + Hydra. A very user-friendly template for ML experimentation. ⚡🔥⚡


Logging step being called twice?

cmvcordova opened this issue

Hey guys,

I made a project based on this template (amazing work!). With past versions it ran flawlessly, but lately I've been facing the following issue:

lightning.fabric.utilities.exceptions.MisconfigurationException: You called `self.log(val/loss, ...)` twice in `validation_step` with different arguments. This is not allowed.
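For context, as I understand it Lightning raises this exception when the same key is logged more than once within a single hook using different keyword arguments. A minimal, hypothetical module (not from the template, just my own sketch) that should trigger the same error under a Trainer:

```python
import torch
from lightning import LightningModule


class DuplicateLogModule(LightningModule):
    """Hypothetical sketch: logging the same key twice in one hook with
    different arguments raises MisconfigurationException under a Trainer."""

    def validation_step(self, batch, batch_idx):
        loss = torch.tensor(0.0)
        # first call registers "val/loss" with prog_bar=True ...
        self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        # ... second call reuses the key with prog_bar=False, i.e.
        # "called twice with different arguments"
        self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
```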

My `validation_step` and `on_validation_epoch_end` look like this:

```python
def validation_step(self, batch: Any, batch_idx: int):
    loss, preds, targets = self.model_step(batch)
    # update and log metrics
    self.val_loss(loss)
    self.val_metric(preds, targets)
    self.log("val/loss", self.val_loss,
             on_step=False,
             on_epoch=True,
             prog_bar=True)
    self.log(f"val/{self.metric_name}", self.val_metric,
             on_step=False,
             on_epoch=True,
             prog_bar=True)

def on_validation_epoch_end(self):
    metric = self.val_metric.compute()  # get current val metric
    self.val_metric_best(metric)  # update best so far val metric
    # log `val_metric_best` as a value through `.compute()` method, instead of as a metric object
    # otherwise metric would be reset by lightning after each epoch
    self.log(f"val/{self.metric_name}_best", self.val_metric_best.compute(),
             on_step=False,
             on_epoch=True,
             sync_dist=True,
             prog_bar=True)
```

My modified `self.val_metric` is simply assigned dynamically based on the task given to the module, in the following fashion:

```python
if task == "regression":
    if isinstance(criterion, RMSELoss):
        self.metric_name = 'rmse'
        self.train_metric = MeanSquaredError(squared=False)
        self.val_metric = MeanSquaredError(squared=False)
        self.test_metric = MeanSquaredError(squared=False)
```

The rest of the script is kept mostly the same as its source at `src/models/mnist_module.py`.
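For completeness, here is a self-contained sketch of how that branch sits alongside the rest of my metric setup. `setup_metrics` is just an illustrative helper name, `RMSELoss` is my own criterion class defined elsewhere, and using `MinMetric` to track the best RMSE is my assumption rather than code from the template:

```python
from torchmetrics import MeanMetric, MeanSquaredError, MinMetric


def setup_metrics(self, task: str, criterion) -> None:
    """Hypothetical helper called from __init__ (a sketch, not the template's code)."""
    if task == "regression" and isinstance(criterion, RMSELoss):
        self.metric_name = "rmse"
        # squared=False makes MeanSquaredError return the root of the MSE
        self.train_metric = MeanSquaredError(squared=False)
        self.val_metric = MeanSquaredError(squared=False)
        self.test_metric = MeanSquaredError(squared=False)
        # lower RMSE is better, so track the best value with MinMetric
        self.val_metric_best = MinMetric()
        # per-epoch loss averaging, as in the template's MNIST module
        self.train_loss = MeanMetric()
        self.val_loss = MeanMetric()
        self.test_loss = MeanMetric()
```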

I was not able to reproduce the issue with the template as-is. However, could you shed a little light on what's going on with the `self.log` calls? I'm not sure why Lightning considers the call made "twice". When I turned off validation logging entirely, the same error came from `training_step` instead. `validation_step` does seem to execute its `self.log` calls once, but fails after a single iteration.

I'd appreciate any and all pointers that could direct me toward solving this. If that's not possible, is there a way to refactor the logging within the module into something more rudimentary that you would encourage?
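What I had in mind by "more rudimentary" is something like the sketch below (my own assumption, not code from the template): accumulate plain tensors during the epoch and log each key exactly once at epoch end, instead of passing `Metric` objects to `self.log`. It assumes the module already defines `model_step`, `val_metric`, and `metric_name`, and that the class is mixed into a `LightningModule`:

```python
import torch


class ManualLoggingMixin:
    """Hypothetical mixin sketching manual epoch-level logging."""

    def on_validation_epoch_start(self):
        self._val_losses = []

    def validation_step(self, batch, batch_idx):
        loss, preds, targets = self.model_step(batch)
        # keep detached scalars instead of handing Metric objects to self.log
        self._val_losses.append(loss.detach())
        self.val_metric.update(preds, targets)
        return loss

    def on_validation_epoch_end(self):
        # exactly one self.log call per key, each with a plain tensor value
        self.log("val/loss", torch.stack(self._val_losses).mean(),
                 prog_bar=True, sync_dist=True)
        self.log(f"val/{self.metric_name}", self.val_metric.compute(),
                 prog_bar=True, sync_dist=True)
        self.val_metric.reset()
```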

Thank you for your time.

I had to overhaul a few configs; the error as described above is not accurate.