MLFlowLogger fails when logging hyperparameters manually because the Trainer already logs them automatically
CristoJV opened this issue
Bug description
I encountered an MLflow exception when logging my model hyperparameters from the `on_fit_start` or `on_train_start` hooks:
```python
import copy

from pytorch_lightning.loggers import MLFlowLogger, TensorBoardLogger


# inside my LightningModule subclass
def on_fit_start(self):
    if self.trainer.is_global_zero:
        hparams = copy.deepcopy(self.hparams)
        hparams = self.clean_hparams(hparams)
        # both loggers expose the same log_hyperparams() interface
        if isinstance(self.logger, (MLFlowLogger, TensorBoardLogger)):
            self.logger.log_hyperparams(hparams)
```
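`clean_hparams` is not shown here; judging from the exception below, where a plain Python list was stored first and the raw tensor's string form was attempted second, it presumably converts tensors into plain Python types. A hypothetical sketch of such a helper (an assumption, not the author's actual code):

```python
import torch


def clean_hparams(self, hparams):
    """Hypothetical helper (not shown in the issue): make hparams loggable."""
    cleaned = {}
    for key, value in hparams.items():
        if isinstance(value, torch.Tensor):
            # tolist() stringifies as '[0.1525, ...]' rather than 'tensor([...])'
            cleaned[key] = value.tolist()
        else:
            cleaned[key] = value
    return cleaned
```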
The `on_fit_start` hook successfully logs the hyperparameters after cleaning them (verified with the MLflow client). Immediately afterwards, however, the following exception is raised:
```
mlflow.exceptions.RestException: INVALID_PARAMETER_VALUE: Changing param values is not allowed.
Param with key='loss_params/alpha' was already logged with value='[0.15253213047981262,
0.170266255736351, 0.15075302124023438, 0.28234609961509705, 0.24410249292850494]' for run
ID='f74bdff19b6c4e9aa3abf5fd054f9c1c'. Attempted logging new value
'tensor([0.1525, 0.1703, 0.1508, 0.2823, 0.2441])'.
```
This exception is raised because MLflow does not allow changing a parameter's value once it has been logged. Note the two values in the error: the first, manual call stored the cleaned list, while the second attempt passed the raw tensor, whose string form differs. This led me to investigate whether the hyperparameters were being logged twice. As the stack trace shows, the Trainer internally calls `_log_hyperparams` within `_run`, so the hyperparameters are logged a second time:
```
face_sequence.py 405 main
    trainer.fit(
trainer.py 544 fit
    call._call_and_handle_interrupt(
call.py 43 _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
subprocess_script.py 102 launch
    return function(*args, **kwargs)
trainer.py 580 _fit_impl
    self._run(model, ckpt_path=ckpt_path)
trainer.py 972 _run
    _log_hyperparams(self)
utilities.py 93 _log_hyperparams
    logger.log_hyperparams(hparams_initial)
rank_zero.py 42 wrapped_fn
    return fn(*args, **kwargs)
mlflow.py 233 log_hyperparams
    self.experiment.log_batch(run_id=self.run_id, params=params_list[idx : idx + 100])
client.py 1093 log_batch
    return self._tracking_client.log_batch(
client.py 444 log_batch
    self.store.log_batch(
rest_store.py 323 log_batch
    self._call_endpoint(LogBatch, req_body)
rest_store.py 59 _call_endpoint
    return call_endpoint(self.get_host_creds(), endpoint, method, json_body, response_proto)
rest_utils.py 219 call_endpoint
    response = verify_rest_response(response, endpoint)
rest_utils.py 151 verify_rest_response
    raise RestException(json.loads(response.text))
```
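For context, the immutability itself is enforced by MLflow, not by Lightning: once a param is logged for a run, re-logging the same key with a different (stringified) value is rejected. A minimal standalone sketch of this behavior (assumes a recent MLflow version and default local tracking):

```python
import mlflow

with mlflow.start_run():
    # Params are stored as strings; this first call stores '[0.15, 0.17]'.
    mlflow.log_param("loss_params/alpha", [0.15, 0.17])
    # Re-logging the same key with a different value raises
    # mlflow.exceptions.MlflowException (a RestException when logging
    # against a tracking server), matching the error above.
    mlflow.log_param("loss_params/alpha", "tensor([0.15, 0.17])")
```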
Here is the relevant code from `trainer.py` (reduced version):
```python
def _run(
    self, model: "pl.LightningModule", ckpt_path: Optional[_PATH] = None
) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
    ...
    if self.state.fn == TrainerFn.FITTING:
        min_epochs, max_epochs = _parse_loop_limits(
            self.min_steps, self.max_steps, self.min_epochs, self.max_epochs, self
        )
        self.fit_loop.min_epochs = min_epochs
        self.fit_loop.max_epochs = max_epochs

    # hyperparameters are logged here for every configured logger,
    # regardless of what the LightningModule already logged itself
    _log_hyperparams(self)

    if self.strategy.restore_checkpoint_after_setup:
        log.debug(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
        self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path)
    ...
    return results
```
And the `_log_hyperparams` function:
```python
def _log_hyperparams(trainer: "pl.Trainer") -> None:
    if not trainer.loggers:
        return

    pl_module = trainer.lightning_module
    datamodule_log_hyperparams = trainer.datamodule._log_hyperparams if trainer.datamodule is not None else False

    hparams_initial = None
    if pl_module._log_hyperparams and datamodule_log_hyperparams:
        datamodule_hparams = trainer.datamodule.hparams_initial
        lightning_hparams = pl_module.hparams_initial
        inconsistent_keys = []
        for key in lightning_hparams.keys() & datamodule_hparams.keys():
            lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key]
            if (
                type(lm_val) != type(dm_val)
                or (isinstance(lm_val, Tensor) and id(lm_val) != id(dm_val))
                or lm_val != dm_val
            ):
                inconsistent_keys.append(key)
        if inconsistent_keys:
            raise RuntimeError(
                f"Error while merging hparams: the keys {inconsistent_keys} are present "
                "in both the LightningModule's and LightningDataModule's hparams "
                "but have different values."
            )
        hparams_initial = {**lightning_hparams, **datamodule_hparams}
    elif pl_module._log_hyperparams:
        hparams_initial = pl_module.hparams_initial
    elif datamodule_log_hyperparams:
        hparams_initial = trainer.datamodule.hparams_initial

    for logger in trainer.loggers:
        if hparams_initial is not None:
            logger.log_hyperparams(hparams_initial)
        logger.log_graph(pl_module)
        logger.save()
```
Is there a workaround to prevent the Trainer from forcibly logging the hyperparameters?
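One likely workaround, based on the gating visible in `_log_hyperparams` above: the Trainer only auto-logs when `pl_module._log_hyperparams` is true, and that flag is controlled by the `logger` argument of `save_hyperparameters`. A sketch (`MyModel` is an illustrative name, and `clean_hparams` is the helper from above; assumes hyperparameters are captured via `save_hyperparameters`):

```python
import copy

import pytorch_lightning as pl
from pytorch_lightning.loggers import MLFlowLogger, TensorBoardLogger


class MyModel(pl.LightningModule):
    def __init__(self, loss_params):
        super().__init__()
        # logger=False sets self._log_hyperparams to False, so the Trainer's
        # automatic _log_hyperparams() call in _run() skips this module.
        self.save_hyperparameters(logger=False)

    def on_fit_start(self):
        # Manual logging is now the only call, avoiding the duplicate-param error.
        if self.trainer.is_global_zero and isinstance(
            self.logger, (MLFlowLogger, TensorBoardLogger)
        ):
            self.logger.log_hyperparams(self.clean_hparams(copy.deepcopy(self.hparams)))
```

The same `logger=False` switch exists on `LightningDataModule.save_hyperparameters`, which controls the `datamodule_log_hyperparams` branch shown above.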
PyTorch Lightning version: 2.1.4
What version are you seeing the problem on?
v2.1
How to reproduce the bug
No response
Error messages and logs
# Error messages and logs here please
Environment
Current environment
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):
More info
No response