Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.

Home Page: https://lightning.ai

Continuing training with `ckpt_path="last"` and MLFlowLogger fails in distributed setting

selflein opened this issue

Bug description

Summary

When continuing a training run with

  • trainer.fit(..., ckpt_path="last") under DDP (more than one device)
  • MLFlowLogger (also reproduced with AimLogger and possibly affecting other loggers; TensorBoardLogger works fine)
  • ModelCheckpoint(dirpath=None, save_last=True, save_top_k=n) (n > 0)

the resumed training fails with a cryptic error message or hangs.


High-level course of events that led to the issue

At a high level, the following happens in the continuation run:

  • ModelCheckpoint tries to recover the checkpoint directory from the logger when dirpath=None. With ckpt_path="last", the trainer collects last-checkpoint candidates from every ModelCheckpoint callback:
        elif ckpt_path == "last":
            candidates = {getattr(ft, "ckpt_path", None) for ft in ft_checkpoints}
            for callback in self.trainer.checkpoint_callbacks:
                if isinstance(callback, ModelCheckpoint):
                    candidates |= callback._find_last_checkpoints(self.trainer)
  • Some loggers (e.g., MLFlow) expose the attributes required to reconstruct the checkpoint path (e.g., version) only on rank 0. The logic for reconstructing the checkpoint path from a logger is:
        if len(trainer.loggers) > 0:
            if trainer.loggers[0].save_dir is not None:
                save_dir = trainer.loggers[0].save_dir
            else:
                save_dir = trainer.default_root_dir
            name = trainer.loggers[0].name
            version = trainer.loggers[0].version
            version = version if isinstance(version, str) else f"version_{version}"
            ckpt_path = os.path.join(save_dir, str(name), version, "checkpoints")
  • The last checkpoint is not found on non-rank-0 workers.
  • As a result, the ModelCheckpoint state (e.g., the best metric) is not restored on non-rank-0 workers.
  • Rank 0 and the other ranks disagree about which checkpoint is best.
  • Training fails or hangs because the ranks end up waiting on each other.
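The two mechanisms in the list above can be reproduced without GPUs or Lightning in a small pure-Python sketch. `LoggerView`, `TopKState`, `resolve_ckpt_dir` and `would_save` are hypothetical stand-ins, not Lightning APIs: `resolve_ckpt_dir` copies the path-reconstruction snippet quoted above, and `would_save` is an assumed, simplified version of a `mode="min"` top-k check that always saves while fewer than `save_top_k` checkpoints are tracked. The identifiers come from the logs; the epoch-1 score is made up.

```python
import os
from dataclasses import dataclass, field
from typing import Dict, Optional, Union


@dataclass
class LoggerView:
    """Hypothetical stand-in for the logger attributes ModelCheckpoint reads."""
    save_dir: Optional[str]
    name: Optional[str]
    version: Optional[Union[int, str]]


def resolve_ckpt_dir(logger: LoggerView, default_root_dir: str) -> str:
    # Same logic as the path-reconstruction snippet (first logger only).
    save_dir = logger.save_dir if logger.save_dir is not None else default_root_dir
    version = logger.version if isinstance(logger.version, str) else f"version_{logger.version}"
    return os.path.join(save_dir, str(logger.name), version, "checkpoints")


@dataclass
class TopKState:
    """Hypothetical, simplified stand-in for restored ModelCheckpoint state (mode="min")."""
    save_top_k: int
    best_k_models: Dict[str, float] = field(default_factory=dict)

    def would_save(self, current: float) -> bool:
        # Assumption: always save while the top-k list is not yet full.
        if len(self.best_k_models) < self.save_top_k:
            return True
        # Otherwise save only if current beats the worst kept score (lower is better).
        return current < max(self.best_k_models.values())


# Rank 0 sees the real MLflow identifiers; per the report, another rank may not.
rank0_logger = LoggerView("/outputs/mlflow", "720298163187032648", "08e70fdc56194db2bea9226298fa3006")
rank1_logger = LoggerView("/outputs/mlflow", "720298163187032648", None)
print(resolve_ckpt_dir(rank0_logger, "/tmp"))
print(resolve_ckpt_dir(rank1_logger, "/tmp"))  # .../version_None/checkpoints on POSIX

# Rank 0 restored its top-k state from last.ckpt; the other rank starts empty.
rank0_state = TopKState(save_top_k=1, best_k_models={"epoch=0-step=1.ckpt": -0.5909})
rank1_state = TopKState(save_top_k=1)
current = -0.3  # hypothetical epoch-1 score, worse than the restored best
print(rank0_state.would_save(current), rank1_state.would_save(current))  # False True
```

With `version=None`, the non-zero rank resolves `.../version_None/checkpoints` instead of the MLflow run directory, which matches the report that neither the last checkpoint nor the ModelCheckpoint state is recovered there. For any score worse than the restored best, rank 0 then skips saving while the other rank tries to save, so the ranks run different code paths; the traceback in the logs shows rank 1 failing inside the `broadcast` issued from `file_exists`.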

What version are you seeing the problem on?

v2.2

How to reproduce the bug

import os
import sys

import torch
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import MLFlowLogger
from torch.distributed.elastic.multiprocessing.errors import record
from torch.utils.data import DataLoader, Dataset


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss, sync_dist=True)

    def on_train_epoch_start(self) -> None:
        rank = os.environ["RANK"]
        print(f"RANK {rank}", self.trainer.checkpoint_callback.state_dict(), flush=True)

    def on_validation_epoch_end(self) -> None:
        # Just for demonstration purposes to exit after one epoch
        self.trainer.should_stop = True

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


@record
def run():
    # Optionally pass an existing MLflow run ID as the only CLI argument to resume that run.
    run_id = sys.argv[1] if len(sys.argv) == 2 else None

    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    ckpt_callback = ModelCheckpoint(
        monitor="valid_loss",
        mode="min",
        save_top_k=1,
        save_last=True,
    )
    logger = MLFlowLogger(save_dir="/outputs/mlflow", run_id=run_id)
    print("LOGGER VERSION:", logger.version)

    model = BoringModel()
    trainer = Trainer(
        callbacks=[ckpt_callback],
        logger=logger,
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=2,
        devices=2,
        enable_model_summary=False,
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data, ckpt_path="last")


if __name__ == "__main__":
    run()

Error messages and logs

Reproduction

Starting the run (no error here yet):

NCCL_DEBUG=WARN torchrun --nproc_per_node 2 lightning_restore_from_last_bug.py
[2024-06-02 06:52:48,558] torch.distributed.run: [WARNING]
[2024-06-02 06:52:48,558] torch.distributed.run: [WARNING] *****************************************
[2024-06-02 06:52:48,558] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
[2024-06-02 06:52:48,558] torch.distributed.run: [WARNING] *****************************************
[2024-06-02 06:52:48,558] torch.distributed.launcher.api: [INFO] Starting elastic_operator with launch configs:
[2024-06-02 06:52:48,558] torch.distributed.launcher.api: [INFO]   entrypoint       : lightning_restore_from_last_bug.py
[2024-06-02 06:52:48,558] torch.distributed.launcher.api: [INFO]   min_nodes        : 1
[2024-06-02 06:52:48,558] torch.distributed.launcher.api: [INFO]   max_nodes        : 1
[2024-06-02 06:52:48,558] torch.distributed.launcher.api: [INFO]   nproc_per_node   : 2
[2024-06-02 06:52:48,558] torch.distributed.launcher.api: [INFO]   run_id           : none
[2024-06-02 06:52:48,558] torch.distributed.launcher.api: [INFO]   rdzv_backend     : static
[2024-06-02 06:52:48,558] torch.distributed.launcher.api: [INFO]   rdzv_endpoint    : 127.0.0.1:29500
[2024-06-02 06:52:48,558] torch.distributed.launcher.api: [INFO]   rdzv_configs     : {'rank': 0, 'timeout': 900}
[2024-06-02 06:52:48,558] torch.distributed.launcher.api: [INFO]   max_restarts     : 0
[2024-06-02 06:52:48,558] torch.distributed.launcher.api: [INFO]   monitor_interval : 5
[2024-06-02 06:52:48,558] torch.distributed.launcher.api: [INFO]   log_dir          : None
[2024-06-02 06:52:48,558] torch.distributed.launcher.api: [INFO]   metrics_cfg      : {}
[2024-06-02 06:52:48,558] torch.distributed.launcher.api: [INFO]
[2024-06-02 06:52:48,559] torch.distributed.elastic.agent.server.local_elastic_agent: [INFO] log directory set to: /tmp/torchelastic_w2lymixt/none_6_8qvfee
[2024-06-02 06:52:48,559] torch.distributed.elastic.agent.server.api: [INFO] [default] starting workers for entrypoint: python
[2024-06-02 06:52:48,559] torch.distributed.elastic.agent.server.api: [INFO] [default] Rendezvous'ing worker group
[W socket.cpp:464] [c10d] The server socket cannot be initialized on [::]:29500 (errno: 97 - Address family not supported by protocol).
[W socket.cpp:697] [c10d] The client socket cannot be initialized to connect to [::ffff:127.0.0.1]:29500 (errno: 97 - Address family not supported by protocol).
[2024-06-02 06:52:48,567] torch.distributed.elastic.agent.server.api: [INFO] [default] Rendezvous complete for workers. Result:
[2024-06-02 06:52:48,567] torch.distributed.elastic.agent.server.api: [INFO]   restart_count=0
[2024-06-02 06:52:48,567] torch.distributed.elastic.agent.server.api: [INFO]   master_addr=127.0.0.1
[2024-06-02 06:52:48,567] torch.distributed.elastic.agent.server.api: [INFO]   master_port=29500
[2024-06-02 06:52:48,567] torch.distributed.elastic.agent.server.api: [INFO]   group_rank=0
[2024-06-02 06:52:48,567] torch.distributed.elastic.agent.server.api: [INFO]   group_world_size=1
[2024-06-02 06:52:48,567] torch.distributed.elastic.agent.server.api: [INFO]   local_ranks=[0, 1]
[2024-06-02 06:52:48,567] torch.distributed.elastic.agent.server.api: [INFO]   role_ranks=[0, 1]
[2024-06-02 06:52:48,567] torch.distributed.elastic.agent.server.api: [INFO]   global_ranks=[0, 1]
[2024-06-02 06:52:48,567] torch.distributed.elastic.agent.server.api: [INFO]   role_world_sizes=[2, 2]
[2024-06-02 06:52:48,567] torch.distributed.elastic.agent.server.api: [INFO]   global_world_sizes=[2, 2]
[2024-06-02 06:52:48,567] torch.distributed.elastic.agent.server.api: [INFO]
[2024-06-02 06:52:48,567] torch.distributed.elastic.agent.server.api: [INFO] [default] Starting worker group
[2024-06-02 06:52:48,568] torch.distributed.elastic.agent.server.local_elastic_agent: [INFO] Environment variable 'TORCHELASTIC_ENABLE_FILE_TIMER' not found. Do not start FileTimerServer.
[2024-06-02 06:52:48,568] torch.distributed.elastic.multiprocessing: [INFO] Setting worker0 reply file to: /tmp/torchelastic_w2lymixt/none_6_8qvfee/attempt_0/0/error.json
[2024-06-02 06:52:48,568] torch.distributed.elastic.multiprocessing: [INFO] Setting worker1 reply file to: /tmp/torchelastic_w2lymixt/none_6_8qvfee/attempt_0/1/error.json
LOGGER VERSION: None
LOGGER VERSION: 08e70fdc56194db2bea9226298fa3006
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
`Trainer(limit_test_batches=1)` was configured so 1 batch will be used.
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/2
[W socket.cpp:697] [c10d] The client socket cannot be initialized to connect to [::ffff:127.0.0.1]:29500 (errno: 97 - Address family not supported by protocol).
/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:186: .fit(ckpt_path="last") is set, but there is no last checkpoint available. No checkpoint will be loaded. HINT: Set `ModelCheckpoint(..., save_last=True)`.
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
[W socket.cpp:697] [c10d] The client socket cannot be initialized to connect to [::ffff:127.0.0.1]:29500 (errno: 97 - Address family not supported by protocol).
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 2 processes
----------------------------------------------------------------------------------------------------

NCCL version 2.19.3+cuda11.8
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1]
/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
Training:   0%|          | 0/1 [00:00<?, ?it/s]
{'monitor': 'valid_loss', 'best_model_score': None, 'best_model_path': '', 'current_score': None, 'dirpath': '/outputs/mlflow/720298163187032648/08e70fdc56194db2bea9226298fa3006/checkpoints', 'best_k_models': {}, 'kth_best_model_path': '', 'kth_value': tensor(inf), 'last_model_path': ''}
Epoch 0:   0%|          | 0/1 [00:00<?, ?it/s]
RANK 0 {'monitor': 'valid_loss', 'best_model_score': None, 'best_model_path': '', 'current_score': None, 'dirpath': '/outputs/mlflow/720298163187032648/08e70fdc56194db2bea9226298fa3006/checkpoints', 'best_k_models': {}, 'kth_best_model_path': '', 'kth_value': tensor(inf), 'last_model_path': ''}
Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  5.77it/s, v_num=3006]
[2024-06-02 06:52:53,570] torch.distributed.elastic.agent.server.api: [INFO] [default] worker group successfully finished. Waiting 300 seconds for other agents to finish.
[2024-06-02 06:52:53,571] torch.distributed.elastic.agent.server.api: [INFO] Local worker group finished (WorkerState.SUCCEEDED). Waiting 300 seconds for other agents to finish
[2024-06-02 06:52:53,571] torch.distributed.elastic.agent.server.api: [INFO] Done waiting for other agents. Elapsed: 0.00045299530029296875 seconds

Error when continuing the run by providing its run_id:

NCCL_DEBUG=WARN torchrun --nproc_per_node 2 lightning_restore_from_last_bug.py 08e70fdc56194db2bea9226298fa3006
[2024-06-02 06:53:19,070] torch.distributed.run: [WARNING]
[2024-06-02 06:53:19,070] torch.distributed.run: [WARNING] *****************************************
[2024-06-02 06:53:19,070] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
[2024-06-02 06:53:19,070] torch.distributed.run: [WARNING] *****************************************
[2024-06-02 06:53:19,070] torch.distributed.launcher.api: [INFO] Starting elastic_operator with launch configs:
[2024-06-02 06:53:19,070] torch.distributed.launcher.api: [INFO]   entrypoint       : lightning_restore_from_last_bug.py
[2024-06-02 06:53:19,070] torch.distributed.launcher.api: [INFO]   min_nodes        : 1
[2024-06-02 06:53:19,070] torch.distributed.launcher.api: [INFO]   max_nodes        : 1
[2024-06-02 06:53:19,070] torch.distributed.launcher.api: [INFO]   nproc_per_node   : 2
[2024-06-02 06:53:19,070] torch.distributed.launcher.api: [INFO]   run_id           : none
[2024-06-02 06:53:19,070] torch.distributed.launcher.api: [INFO]   rdzv_backend     : static
[2024-06-02 06:53:19,070] torch.distributed.launcher.api: [INFO]   rdzv_endpoint    : 127.0.0.1:29500
[2024-06-02 06:53:19,070] torch.distributed.launcher.api: [INFO]   rdzv_configs     : {'rank': 0, 'timeout': 900}
[2024-06-02 06:53:19,070] torch.distributed.launcher.api: [INFO]   max_restarts     : 0
[2024-06-02 06:53:19,070] torch.distributed.launcher.api: [INFO]   monitor_interval : 5
[2024-06-02 06:53:19,070] torch.distributed.launcher.api: [INFO]   log_dir          : None
[2024-06-02 06:53:19,070] torch.distributed.launcher.api: [INFO]   metrics_cfg      : {}
[2024-06-02 06:53:19,070] torch.distributed.launcher.api: [INFO]
[2024-06-02 06:53:19,071] torch.distributed.elastic.agent.server.local_elastic_agent: [INFO] log directory set to: /tmp/torchelastic_9mvoyktm/none_hovgmn4x
[2024-06-02 06:53:19,071] torch.distributed.elastic.agent.server.api: [INFO] [default] starting workers for entrypoint: python
[2024-06-02 06:53:19,071] torch.distributed.elastic.agent.server.api: [INFO] [default] Rendezvous'ing worker group
[W socket.cpp:464] [c10d] The server socket cannot be initialized on [::]:29500 (errno: 97 - Address family not supported by protocol).
[W socket.cpp:697] [c10d] The client socket cannot be initialized to connect to [::ffff:127.0.0.1]:29500 (errno: 97 - Address family not supported by protocol).
[2024-06-02 06:53:19,079] torch.distributed.elastic.agent.server.api: [INFO] [default] Rendezvous complete for workers. Result:
[2024-06-02 06:53:19,079] torch.distributed.elastic.agent.server.api: [INFO]   restart_count=0
[2024-06-02 06:53:19,079] torch.distributed.elastic.agent.server.api: [INFO]   master_addr=127.0.0.1
[2024-06-02 06:53:19,079] torch.distributed.elastic.agent.server.api: [INFO]   master_port=29500
[2024-06-02 06:53:19,079] torch.distributed.elastic.agent.server.api: [INFO]   group_rank=0
[2024-06-02 06:53:19,079] torch.distributed.elastic.agent.server.api: [INFO]   group_world_size=1
[2024-06-02 06:53:19,079] torch.distributed.elastic.agent.server.api: [INFO]   local_ranks=[0, 1]
[2024-06-02 06:53:19,079] torch.distributed.elastic.agent.server.api: [INFO]   role_ranks=[0, 1]
[2024-06-02 06:53:19,079] torch.distributed.elastic.agent.server.api: [INFO]   global_ranks=[0, 1]
[2024-06-02 06:53:19,079] torch.distributed.elastic.agent.server.api: [INFO]   role_world_sizes=[2, 2]
[2024-06-02 06:53:19,079] torch.distributed.elastic.agent.server.api: [INFO]   global_world_sizes=[2, 2]
[2024-06-02 06:53:19,079] torch.distributed.elastic.agent.server.api: [INFO]
[2024-06-02 06:53:19,079] torch.distributed.elastic.agent.server.api: [INFO] [default] Starting worker group
[2024-06-02 06:53:19,080] torch.distributed.elastic.agent.server.local_elastic_agent: [INFO] Environment variable 'TORCHELASTIC_ENABLE_FILE_TIMER' not found. Do not start FileTimerServer.
[2024-06-02 06:53:19,080] torch.distributed.elastic.multiprocessing: [INFO] Setting worker0 reply file to: /tmp/torchelastic_9mvoyktm/none_hovgmn4x/attempt_0/0/error.json
[2024-06-02 06:53:19,080] torch.distributed.elastic.multiprocessing: [INFO] Setting worker1 reply file to: /tmp/torchelastic_9mvoyktm/none_hovgmn4x/attempt_0/1/error.json
LOGGER VERSION: 08e70fdc56194db2bea9226298fa3006
LOGGER VERSION: 08e70fdc56194db2bea9226298fa3006
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
`Trainer(limit_test_batches=1)` was configured so 1 batch will be used.
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/2
[W socket.cpp:697] [c10d] The client socket cannot be initialized to connect to [::ffff:127.0.0.1]:29500 (errno: 97 - Address family not supported by protocol).
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
[W socket.cpp:697] [c10d] The client socket cannot be initialized to connect to [::ffff:127.0.0.1]:29500 (errno: 97 - Address family not supported by protocol).
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 2 processes
----------------------------------------------------------------------------------------------------

NCCL version 2.19.3+cuda11.8
/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:653: Checkpoint directory /outputs/mlflow/720298163187032648/08e70fdc56194db2bea9226298fa3006/checkpoints exists and is not empty.
Restoring states from the checkpoint path at /outputs/mlflow/720298163187032648/08e70fdc56194db2bea9226298fa3006/checkpoints/last.ckpt
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Restored all states from the checkpoint at /outputs/mlflow/720298163187032648/08e70fdc56194db2bea9226298fa3006/checkpoints/last.ckpt
/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
Training: |          | 0/? [00:00<?, ?it/s]
{'monitor': 'valid_loss', 'best_model_score': None, 'best_model_path': '', 'current_score': None, 'dirpath': '/outputs/mlflow/720298163187032648/08e70fdc56194db2bea9226298fa3006/checkpoints', 'best_k_models': {}, 'kth_best_model_path': '', 'kth_value': tensor(inf), 'last_model_path': ''}
Epoch 1:   0%|          | 0/1 [00:00<?, ?it/s]
RANK 0 {'monitor': 'valid_loss', 'best_model_score': tensor(-0.5909), 'best_model_path': '/outputs/mlflow/720298163187032648/08e70fdc56194db2bea9226298fa3006/checkpoints/epoch=0-step=1.ckpt', 'current_score': None, 'dirpath': '/outputs/mlflow/720298163187032648/08e70fdc56194db2bea9226298fa3006/checkpoints', 'best_k_models': {'/outputs/mlflow/720298163187032648/08e70fdc56194db2bea9226298fa3006/checkpoints/epoch=0-step=1.ckpt': tensor(-0.5909)}, 'kth_best_model_path': '/outputs/mlflow/720298163187032648/08e70fdc56194db2bea9226298fa3006/checkpoints/epoch=0-step=1.ckpt', 'kth_value': tensor(-0.5909), 'last_model_path': '/outputs/mlflow/720298163187032648/08e70fdc56194db2bea9226298fa3006/checkpoints/last.ckpt'}
Epoch 1: 100%|██████████| 1/1 [00:00<00:00,  5.56it/s, v_num=3006]
Traceback (most recent call last):
  File "/home/selflein/Developer/pgnn/lightning_restore_from_last_bug.py", line 92, in <module>
    run()
  File "/opt/micromamba/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/home/selflein/Developer/pgnn/lightning_restore_from_last_bug.py", line 88, in run
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data, ckpt_path="last")
  File "/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 544, in fit
    call._call_and_handle_interrupt(
  File "/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 43, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 580, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 987, in _run
    results = self._run_stage()
              ^^^^^^^^^^^^^^^^^
  File "/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1033, in _run_stage
    self.fit_loop.run()
  File "/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py", line 206, in run
    self.on_advance_end()
  File "/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py", line 378, in on_advance_end
    call._call_callback_hooks(trainer, "on_train_epoch_end", monitoring_callbacks=True)
  File "/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 208, in _call_callback_hooks
    fn(trainer, trainer.lightning_module, *args, **kwargs)
  File "/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 324, in on_train_epoch_end
    self._save_topk_checkpoint(trainer, monitor_candidates)
  File "/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 384, in _save_topk_checkpoint
    self._save_monitor_checkpoint(trainer, monitor_candidates)
  File "/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 704, in _save_monitor_checkpoint
    self._update_best_and_save(current, trainer, monitor_candidates)
  File "/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 733, in _update_best_and_save
    filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer, del_filepath)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 662, in _get_metric_interpolated_filepath_name
    while self.file_exists(filepath, trainer) and filepath != del_filepath:
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 775, in file_exists
    return trainer.strategy.broadcast(exists)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/strategies/ddp.py", line 307, in broadcast
    torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
  File "/opt/micromamba/lib/python3.11/site-packages/torch/distributed/c10d_logger.py", line 72, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/micromamba/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py", line 2438, in broadcast_object_list
    object_list[i] = _tensor_to_object(obj_view, obj_size)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/micromamba/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py", line 2123, in _tensor_to_object
    return _unpickler(io.BytesIO(buf)).load()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
EOFError: Ran out of input
[2024-06-02 06:53:29,083] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 1165018 closing signal SIGTERM
[2024-06-02 06:53:59,084] torch.distributed.elastic.multiprocessing.api: [WARNING] Unable to shutdown process 1165018 via 15, forcefully exiting via 9
[2024-06-02 06:53:59,153] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 1 (pid: 1165019) of binary: /opt/micromamba/bin/python
[2024-06-02 06:53:59,165] torch.distributed.elastic.multiprocessing.errors.error_handler: [ERROR] no error file defined for parent, to copy child error file (/tmp/torchelastic_9mvoyktm/none_hovgmn4x/attempt_0/1/error.json)
Traceback (most recent call last):
  File "/opt/micromamba/bin/torchrun", line 33, in <module>
    sys.exit(load_entry_point('torch==2.2.2', 'console_scripts', 'torchrun')())
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/micromamba/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/opt/micromamba/lib/python3.11/site-packages/torch/distributed/run.py", line 812, in main
    run(args)
  File "/opt/micromamba/lib/python3.11/site-packages/torch/distributed/run.py", line 803, in run
    elastic_launch(
  File "/opt/micromamba/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 135, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/micromamba/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 268, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
lightning_restore_from_last_bug.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-06-02_06:53:22
  host      : rno1-m03-g03-dgx1-015.draco-rno.nvidia.com
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 1165019)
  error_file: /tmp/torchelastic_9mvoyktm/none_hovgmn4x/attempt_0/1/error.json
  traceback : Traceback (most recent call last):
    File "/opt/micromamba/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347, in wrapper
      return f(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^
    File "/home/selflein/Developer/pgnn/lightning_restore_from_last_bug.py", line 88, in run
      trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data, ckpt_path="last")
    File "/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 544, in fit
      call._call_and_handle_interrupt(
    File "/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 43, in _call_and_handle_interrupt
      return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
      return function(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 580, in _fit_impl
      self._run(model, ckpt_path=ckpt_path)
    File "/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 987, in _run
      results = self._run_stage()
                ^^^^^^^^^^^^^^^^^
    File "/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1033, in _run_stage
      self.fit_loop.run()
    File "/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py", line 206, in run
      self.on_advance_end()
    File "/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py", line 378, in on_advance_end
      call._call_callback_hooks(trainer, "on_train_epoch_end", monitoring_callbacks=True)
    File "/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 208, in _call_callback_hooks
      fn(trainer, trainer.lightning_module, *args, **kwargs)
    File "/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 324, in on_train_epoch_end
      self._save_topk_checkpoint(trainer, monitor_candidates)
    File "/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 384, in _save_topk_checkpoint
      self._save_monitor_checkpoint(trainer, monitor_candidates)
    File "/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 704, in _save_monitor_checkpoint
      self._update_best_and_save(current, trainer, monitor_candidates)
    File "/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 733, in _update_best_and_save
      filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer, del_filepath)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 662, in _get_metric_interpolated_filepath_name
      while self.file_exists(filepath, trainer) and filepath != del_filepath:
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 775, in file_exists
      return trainer.strategy.broadcast(exists)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/opt/micromamba/lib/python3.11/site-packages/lightning/pytorch/strategies/ddp.py", line 307, in broadcast
      torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
    File "/opt/micromamba/lib/python3.11/site-packages/torch/distributed/c10d_logger.py", line 72, in wrapper
      return func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^
    File "/opt/micromamba/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py", line 2438, in broadcast_object_list
      object_list[i] = _tensor_to_object(obj_view, obj_size)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/opt/micromamba/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py", line 2123, in _tensor_to_object
      return _unpickler(io.BytesIO(buf)).load()
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  EOFError: Ran out of input

============================================================

Key issue

The key thing to notice is that, in the continuation run, the ModelCheckpoint state is restored inconsistently across ranks (cf. the state printed in on_train_epoch_start in the sample code):

Rank 0

RANK 0 {'monitor': 'valid_loss', 'best_model_score': tensor(-0.5909), 'best_model_path': '/outputs/mlflow/720298163187032648/08e70fdc56194db2bea9226298fa3006/checkpoints/epoch=0-step=1.ckpt', 'current_score': None, 'dirpath': '/outputs/mlflow/720298163187032648/08e70fdc56194db2bea9226298fa3006/checkpoints', 'best_k_models': {'/outputs/mlflow/720298163187032648/08e70fdc56194db2bea9226298fa3006/checkpoints/epoch=0-step=1.ckpt': tensor(-0.5909)}, 'kth_best_model_path': '/outputs/mlflow/720298163187032648/08e70fdc56194db2bea9226298fa3006/checkpoints/epoch=0-step=1.ckpt', 'kth_value': tensor(-0.5909), 'last_model_path': '/outputs/mlflow/720298163187032648/08e70fdc56194db2bea9226298fa3006/checkpoints/last.ckpt'}

Rank 1

{'monitor': 'valid_loss', 'best_model_score': None, 'best_model_path': '', 'current_score': None, 'dirpath': '/outputs/mlflow/720298163187032648/08e70fdc56194db2bea9226298fa3006/checkpoints', 'best_k_models': {}, 'kth_best_model_path': '', 'kth_value': tensor(inf), 'last_model_path': ''}
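Comparing the two dumps field by field shows exactly which parts of the state diverge. The snippet below copies the printed values (tensors written as plain floats) and diffs them with a small helper, diverging_fields, a hypothetical name introduced here for illustration. In a real DDP run, the per-rank dicts could instead be collected with torch.distributed.all_gather_object before comparing.

```python
# Values copied from the two dumps above; tensors are written as plain floats.
CKPT_DIR = "/outputs/mlflow/720298163187032648/08e70fdc56194db2bea9226298fa3006/checkpoints"
BEST = f"{CKPT_DIR}/epoch=0-step=1.ckpt"

rank0 = {
    "monitor": "valid_loss",
    "best_model_score": -0.5909,
    "best_model_path": BEST,
    "current_score": None,
    "dirpath": CKPT_DIR,
    "best_k_models": {BEST: -0.5909},
    "kth_best_model_path": BEST,
    "kth_value": -0.5909,
    "last_model_path": f"{CKPT_DIR}/last.ckpt",
}
rank1 = {
    "monitor": "valid_loss",
    "best_model_score": None,
    "best_model_path": "",
    "current_score": None,
    "dirpath": CKPT_DIR,
    "best_k_models": {},
    "kth_best_model_path": "",
    "kth_value": float("inf"),
    "last_model_path": "",
}


def diverging_fields(a: dict, b: dict) -> list:
    """Hypothetical helper: return the sorted keys whose values differ between two state dicts."""
    return sorted(k for k in a.keys() | b.keys() if a.get(k) != b.get(k))


print(diverging_fields(rank0, rank1))
# ['best_k_models', 'best_model_path', 'best_model_score', 'kth_best_model_path', 'kth_value', 'last_model_path']
```

Only monitor, current_score and dirpath agree; every field that drives the top-k bookkeeping differs between the ranks.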

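As a result, ModelCheckpoint works from inconsistent state on each rank when deciding which checkpoint to save, and the run errors out. In the traceback above, the failure on rank 1 surfaces inside the broadcast issued by ModelCheckpoint.file_exists.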
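A simplified model of the top-k decision helps show why divergent state can break a collective call. This is not Lightning's code: TopKState and would_save are hypothetical names, and the sketch assumes save_top_k=1 and mode="min" (the issue only states save_top_k=n with n > 0). It further assumes that a rank tracking fewer than save_top_k checkpoints saves unconditionally, while a rank with a full top-k list compares the current score against its k-th best score.

```python
import math
from dataclasses import dataclass, field


@dataclass
class TopKState:
    """Hypothetical stand-in for the ModelCheckpoint fields printed above."""

    best_k_models: dict = field(default_factory=dict)  # checkpoint path -> score
    kth_value: float = math.inf


def would_save(state: TopKState, current: float, save_top_k: int, mode: str = "min") -> bool:
    """Simplified top-k decision; an assumption, not Lightning's actual implementation."""
    if save_top_k == -1:
        return True  # keep every checkpoint
    if len(state.best_k_models) < save_top_k:
        return True  # top-k list not full yet, so any score qualifies
    if mode == "min":
        return current < state.kth_value
    return current > state.kth_value


BEST = "/outputs/mlflow/720298163187032648/08e70fdc56194db2bea9226298fa3006/checkpoints/epoch=0-step=1.ckpt"
rank0 = TopKState(best_k_models={BEST: -0.5909}, kth_value=-0.5909)  # restored state
rank1 = TopKState()  # nothing restored, kth_value stays at inf

for current in (-0.5, -0.7):  # hypothetical valid_loss values: one worse, one better
    decisions = {rank: would_save(s, current, save_top_k=1) for rank, s in ((0, rank0), (1, rank1))}
    print(current, decisions)
# -0.5 {0: False, 1: True}
# -0.7 {0: True, 1: True}
```

In this model, a non-improving score sends rank 1 into the save path (and, per the traceback, the broadcast in file_exists) while rank 0 skips it. That broadcast is then plausibly paired with an unrelated collective on rank 0 (the EOFError above) or never matched at all (the hang mentioned in the summary). With an improving score both ranks agree, so the mismatch would only appear once validation stops improving.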

Environment

Current environment
  • CUDA:
    • GPU:
      • Tesla V100-SXM2-32GB-LS
      • Tesla V100-SXM2-32GB-LS
    • available: True
    • version: 11.8
  • Lightning:
    • lightning: 2.2.5
    • lightning-utilities: 0.11.2
    • pytorch-lightning: 2.2.5
    • torch: 2.2.2
    • torch-cluster: 1.6.3+pt22cu118
    • torch-geometric: 2.5.3
    • torch-scatter: 2.1.2+pt22cu118
    • torch-sparse: 0.6.18+pt22cu118
    • torch-spline-conv: 1.2.2+pt22cu118
    • torchmetrics: 1.4.0.post0
    • torchvision: 0.17.2
  • Packages:
    • aim: 3.19.3
    • aim-ui: 3.19.3
    • aimrecords: 0.0.7
    • aimrocks: 0.4.0
    • aiofiles: 23.2.1
    • aiohttp: 3.9.5
    • aiosignal: 1.3.1
    • albumentations: 1.4.7
    • alembic: 1.13.1
    • annotated-types: 0.7.0
    • antlr4-python3-runtime: 4.9.3
    • anyio: 4.4.0
    • asttokens: 2.4.1
    • attrs: 23.2.0
    • base58: 2.0.1
    • boto3: 1.34.113
    • botocore: 1.34.113
    • braceexpand: 0.1.7
    • brotli: 1.1.0
    • cachetools: 5.3.3
    • certifi: 2024.2.2
    • cffi: 1.16.0
    • charset-normalizer: 3.3.2
    • click: 8.1.7
    • cloudpickle: 3.0.0
    • comm: 0.2.2
    • cryptography: 42.0.7
    • debugpy: 1.8.1
    • decorator: 5.1.1
    • deprecated: 1.2.14
    • dnspython: 2.6.1
    • einops: 0.8.0
    • email-validator: 2.1.1
    • entrypoints: 0.4
    • executing: 2.0.1
    • fastapi: 0.111.0
    • fastapi-cli: 0.0.4
    • filelock: 3.14.0
    • frozenlist: 1.4.1
    • fsspec: 2024.5.0
    • gitdb: 4.0.11
    • gitpython: 3.1.43
    • gmpy2: 2.1.5
    • greenlet: 3.0.3
    • h11: 0.14.0
    • httpcore: 1.0.5
    • httptools: 0.6.1
    • httpx: 0.27.0
    • huggingface-hub: 0.23.2
    • hydra-core: 1.3.2
    • idna: 3.7
    • imageio: 2.34.1
    • importlib-metadata: 7.1.0
    • ipykernel: 6.29.4
    • ipython: 8.24.0
    • jedi: 0.19.1
    • jinja2: 3.1.4
    • jmespath: 1.0.1
    • joblib: 1.4.2
    • jupyter-client: 8.6.2
    • jupyter-core: 5.7.2
    • lazy-loader: 0.4
    • lightning: 2.2.5
    • lightning-utilities: 0.11.2
    • mako: 1.3.5
    • markdown-it-py: 3.0.0
    • markupsafe: 2.1.5
    • matplotlib-inline: 0.1.7
    • mdurl: 0.1.2
    • mlflow-skinny: 2.13.1
    • mpmath: 1.3.0
    • multidict: 6.0.5
    • nest-asyncio: 1.6.0
    • networkx: 3.3
    • numpy: 1.26.4
    • omegaconf: 2.3.0
    • opencv-python-headless: 4.9.0.80
    • opentelemetry-api: 1.25.0
    • opentelemetry-sdk: 1.25.0
    • opentelemetry-semantic-conventions: 0.46b0
    • orjson: 3.10.3
    • packaging: 24.0
    • parso: 0.8.4
    • pexpect: 4.9.0
    • pillow: 9.4.0
    • pip: 24.0
    • platformdirs: 4.2.2
    • prompt-toolkit: 3.0.43
    • protobuf: 4.25.3
    • psutil: 5.9.8
    • ptyprocess: 0.7.0
    • pure-eval: 0.2.2
    • pycparser: 2.22
    • pydantic: 2.7.1
    • pydantic-core: 2.18.2
    • pyg-lib: 0.4.0+pt22cu118
    • pygments: 2.18.0
    • pyparsing: 3.1.2
    • pysocks: 1.7.1
    • python-dateutil: 2.9.0.post0
    • python-dotenv: 1.0.1
    • python-multipart: 0.0.9
    • pytorch-lightning: 2.2.5
    • pytz: 2024.1
    • pyyaml: 6.0.1
    • pyzmq: 26.0.3
    • requests: 2.32.2
    • restrictedpython: 7.1
    • rich: 13.7.1
    • ruff: 0.4.5
    • s3transfer: 0.10.1
    • safetensors: 0.4.3
    • scikit-image: 0.23.2
    • scikit-learn: 1.5.0
    • scipy: 1.13.1
    • setuptools: 70.0.0
    • shellingham: 1.5.4
    • six: 1.16.0
    • smmap: 5.0.1
    • sniffio: 1.3.1
    • sqlalchemy: 2.0.30
    • sqlparse: 0.5.0
    • stack-data: 0.6.3
    • starlette: 0.37.2
    • sympy: 1.12
    • threadpoolctl: 3.5.0
    • tifffile: 2024.5.22
    • torch: 2.2.2
    • torch-cluster: 1.6.3+pt22cu118
    • torch-geometric: 2.5.3
    • torch-scatter: 2.1.2+pt22cu118
    • torch-sparse: 0.6.18+pt22cu118
    • torch-spline-conv: 1.2.2+pt22cu118
    • torchmetrics: 1.4.0.post0
    • torchvision: 0.17.2
    • tornado: 6.4
    • tqdm: 4.66.4
    • traitlets: 5.14.3
    • transforms3d: 0.4.1
    • triton: 2.2.0
    • typer: 0.12.3
    • typing-extensions: 4.11.0
    • ujson: 5.10.0
    • urllib3: 2.2.1
    • uvicorn: 0.29.0
    • uvloop: 0.19.0
    • watchfiles: 0.22.0
    • wcwidth: 0.2.13
    • webdataset: 0.2.86
    • websockets: 12.0
    • wheel: 0.43.0
    • wrapt: 1.16.0
    • yarl: 1.9.4
    • zipp: 3.19.1
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.11.9
    • release: 5.4.0-80-generic
    • version: #90-Ubuntu SMP Fri Jul 9 22:49:44 UTC 2021
