Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.

Home Page: https://lightning.ai

MLFlowLogger fails when logging hyperparameters as Trainer already does automatically

CristoJV opened this issue

Bug description

I encountered an MLflow exception when logging my model hyperparameters in the on_fit_start or on_train_start hooks:

# inside my LightningModule; assumes `import copy` and
# `from lightning.pytorch.loggers import MLFlowLogger, TensorBoardLogger`
def on_fit_start(self):
    if self.trainer.is_global_zero:
        hparams = copy.deepcopy(self.hparams)
        hparams = self.clean_hparams(hparams)  # custom helper that converts tensors to plain lists
        # both logger types are handled identically, so one isinstance check suffices
        if isinstance(self.logger, (MLFlowLogger, TensorBoardLogger)):
            self.logger.log_hyperparams(hparams)

The on_fit_start hook successfully logs the hyperparameters after cleaning them (verified with the MLflow client). Immediately afterwards, however, the following exception is raised:

mlflow.exceptions.RestException: INVALID_PARAMETER_VALUE: Changing param values is not allowed.
Param with key='loss_params/alpha' was already logged with value='[0.15253213047981262,
0.170266255736351, 0.15075302124023438, 0.28234609961509705, 0.24410249292850494]'
for run ID='f74bdff19b6c4e9aa3abf5fd054f9c1c'.
Attempted logging new value 'tensor([0.1525, 0.1703, 0.1508, 0.2823, 0.2441])'.

This exception is raised because MLflow does not allow a parameter's value to change once it has been logged for a run. Note the two values in the message: the already-logged value is the cleaned list written in on_fit_start, while the rejected new value is a raw tensor repr, which suggested a second, automatic logging pass. The stack trace (after the short sketch below) confirms it: the trainer internally calls _log_hyperparams within _run, so the hyperparameters are logged twice.
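
As a side note, the immutability rule can be reproduced with MLflow alone; a minimal sketch, independent of Lightning (assumes a reachable tracking store; values are illustrative):

import mlflow

# MLflow stores params as strings and rejects a second write of the same key
# with a different value within one run (INVALID_PARAMETER_VALUE)
with mlflow.start_run():
    mlflow.log_param("loss_params/alpha", [0.1525, 0.1703])    # stored as the string '[0.1525, 0.1703]'
    mlflow.log_param("loss_params/alpha", "tensor([0.1525])")  # raises mlflow.exceptions.MlflowException

The stack trace of the failing run: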

face_sequence.py 405 main
trainer.fit(

trainer.py 544 fit
call._call_and_handle_interrupt(

call.py 43 _call_and_handle_interrupt
return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)

subprocess_script.py 102 launch
return function(*args, **kwargs)

trainer.py 580 _fit_impl
self._run(model, ckpt_path=ckpt_path)

trainer.py 972 _run
_log_hyperparams(self)

utilities.py 93 _log_hyperparams
logger.log_hyperparams(hparams_initial)

rank_zero.py 42 wrapped_fn
return fn(*args, **kwargs)

mlflow.py 233 log_hyperparams
self.experiment.log_batch(run_id=self.run_id, params=params_list[idx : idx + 100])

client.py 1093 log_batch
return self._tracking_client.log_batch(

client.py 444 log_batch
self.store.log_batch(

rest_store.py 323 log_batch
self._call_endpoint(LogBatch, req_body)

rest_store.py 59 _call_endpoint
return call_endpoint(self.get_host_creds(), endpoint, method, json_body, response_proto)

rest_utils.py 219 call_endpoint
response = verify_rest_response(response, endpoint)

rest_utils.py 151 verify_rest_response
raise RestException(json.loads(response.text))

Here is the relevant code extracted from trainer.py (reduced):

def _run(
    self, model: "pl.LightningModule", ckpt_path: Optional[_PATH] = None
) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
    if self.state.fn == TrainerFn.FITTING:
        min_epochs, max_epochs = _parse_loop_limits(
            self.min_steps, self.max_steps, self.min_epochs, self.max_epochs, self
        )
        self.fit_loop.min_epochs = min_epochs
        self.fit_loop.max_epochs = max_epochs

    _log_hyperparams(self)
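    # ^ the trainer's automatic hyperparameter-logging pass; see the
    # _log_hyperparams function below for the flags that can disable it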

    if self.strategy.restore_checkpoint_after_setup:
        log.debug(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
        self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path)

    # ... rest of _run elided ...
    return results
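
One detail worth noting: hparams_initial is the snapshot taken by save_hyperparameters() at init time, so it still holds raw tensors; that is why the rejected second value in the error is a tensor repr rather than a cleaned list. A minimal sketch of this behavior (hypothetical module; the lightning.pytorch import path is an assumption):

import torch
import lightning.pytorch as pl

class Model(pl.LightningModule):
    def __init__(self, alpha: torch.Tensor):
        super().__init__()
        self.save_hyperparameters()  # snapshots the init args as-is into hparams_initial

model = Model(alpha=torch.tensor([0.1525, 0.1703]))
print(model.hparams_initial["alpha"])  # tensor([0.1525, 0.1703]) -- a tensor, not a list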

And the _log_hyperparams function:

def _log_hyperparams(trainer: "pl.Trainer") -> None:
    if not trainer.loggers:
        return

    pl_module = trainer.lightning_module
    datamodule_log_hyperparams = trainer.datamodule._log_hyperparams if trainer.datamodule is not None else False

    hparams_initial = None
    if pl_module._log_hyperparams and datamodule_log_hyperparams:
        datamodule_hparams = trainer.datamodule.hparams_initial
        lightning_hparams = pl_module.hparams_initial
        inconsistent_keys = []
        for key in lightning_hparams.keys() & datamodule_hparams.keys():
            lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key]
            if (
                type(lm_val) != type(dm_val)
                or (isinstance(lm_val, Tensor) and id(lm_val) != id(dm_val))
                or lm_val != dm_val
            ):
                inconsistent_keys.append(key)
        if inconsistent_keys:
            raise RuntimeError(
                f"Error while merging hparams: the keys {inconsistent_keys} are present "
                "in both the LightningModule's and LightningDataModule's hparams "
                "but have different values."
            )
        hparams_initial = {**lightning_hparams, **datamodule_hparams}
    elif pl_module._log_hyperparams:
        hparams_initial = pl_module.hparams_initial
    elif datamodule_log_hyperparams:
        hparams_initial = trainer.datamodule.hparams_initial

    for logger in trainer.loggers:
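        # hparams_initial still holds the raw values (e.g. tensors), and this
        # call is the second write of the same params to the MLflow run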
        if hparams_initial is not None:
            logger.log_hyperparams(hparams_initial)
        logger.log_graph(pl_module)
        logger.save()

Is there any workaround to prevent the trainer from forcibly logging the hyperparameters? The best option I have found so far is sketched below.
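
A minimal sketch of that workaround, based on the pl_module._log_hyperparams guard above: passing logger=False to save_hyperparameters() sets the flag to False, so the trainer's automatic pass is skipped and the manual call in on_fit_start becomes the only write (clean_tensors is a hypothetical stand-in for my clean_hparams helper):

import copy

import torch
import lightning.pytorch as pl
from lightning.pytorch.loggers import MLFlowLogger


def clean_tensors(hparams: dict) -> dict:
    # hypothetical stand-in for clean_hparams: turn tensors into plain lists
    # so MLflow stores a stable string value
    return {k: v.tolist() if isinstance(v, torch.Tensor) else v for k, v in hparams.items()}


class MyModel(pl.LightningModule):
    def __init__(self, loss_params: dict):
        super().__init__()
        # logger=False keeps the values in self.hparams (and in checkpoints)
        # but tells the Trainer not to log them automatically in _run()
        self.save_hyperparameters(logger=False)

    def on_fit_start(self):
        if self.trainer.is_global_zero and isinstance(self.logger, MLFlowLogger):
            self.logger.log_hyperparams(clean_tensors(copy.deepcopy(dict(self.hparams))))

The downside is that TensorBoard runs would also lose the automatic logging, so the manual hook has to cover every logger I care about.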

PyTorch Lightning version: 2.1.4

What version are you seeing the problem on?

v2.1

How to reproduce the bug

No response

Error messages and logs

See the RestException and stack trace in the bug description above.

Environment

Current environment
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow): Trainer, MLFlowLogger
#- PyTorch Lightning Version (e.g., 1.5.0): 2.1.4
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):

More info

No response