Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.

Home Page: https://lightning.ai

Returning num_replicas=world_size when using distributed sampler in ddp

arjunagarwal899 opened this issue

Bug description

The default LightningEnvironment assumes that every node in a multi-node setup has the same number of GPUs, i.e. each node assumes that the world size equals the number of nodes multiplied by the number of (active) devices on that node.

However, implementing one's own environment can work around this limitation (example attached below). While the processes register successfully, the num_replicas argument passed to the DistributedSampler class is still computed independently of the environment, which leads to an error because some global ranks fall outside the sampler's world size.

Fix: Use num_replicas=self.world_size instead of estimating the world size again.
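
For illustration, the same failure can be reproduced in isolation (a minimal sketch, independent of Lightning), since DistributedSampler rejects any rank outside [0, num_replicas - 1]:

import torch
from torch.utils.data import DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(10))

# num_replicas mirrors num_nodes * num_processes = 4 from the strategy,
# but a process on the smaller node registered as global rank 4
sampler = DistributedSampler(dataset, num_replicas=4, rank=4)
# -> ValueError: Invalid rank 4, rank should be in the interval [0, 3]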

What version are you seeing the problem on?

v2.2

How to reproduce the bug

Config and custom Environment:

import os
import socket

from lightning.pytorch.plugins.environments import LightningEnvironment

# Code that sets some config parameters as follows
config.nodes = [
    ("NODE_NAME", ["NODE_IP_ADDRESS", "NUM_GPUS: int"]),  # Index 0 is master node
    ...
]
config.num_nodes = len(config.nodes)

# Set some global variables
MASTER_PORT = 10051 # Set port here
MASTER_ADDR = config.nodes[0][1][0] # Set address here
WORLD_SIZE = sum([node_info[1] for _, node_info in config.nodes]) # Set world size here

# Set config devices, NODE_RANK, and global rank starting point
NODE_RANK = ...
GLOBAL_RANK_OFFSET = 0
for i, (node, node_info) in enumerate(config.nodes):
    if node == socket.gethostname():
        config.devices = node_info[1]
        NODE_RANK = i
        break
    GLOBAL_RANK_OFFSET += node_info[1]

# Set environment variables
os.environ["MASTER_ADDR"] = str(MASTER_ADDR)
os.environ["MASTER_PORT"] = str(MASTER_PORT)
os.environ["WORLD_SIZE"] = str(WORLD_SIZE)
os.environ["NODE_RANK"] = str(NODE_RANK)


class MyClusterEnvironment(LightningEnvironment):
    def set_world_size(self, size: int):
        # Here, size = num_nodes * len(devices), which does not work for heterogeneous clusters
        self._world_size = WORLD_SIZE

    def set_global_rank(self, rank: int):
        # Here, global_rank = node_rank * len(devices) + local_rank, which does not work for heterogeneous clusters
        global_rank = GLOBAL_RANK_OFFSET + self.local_rank()
        self._global_rank = global_rank

config.cluster_environment = MyClusterEnvironment()

Trainer:

import lightning as L

trainer = L.Trainer(
    num_nodes=config.num_nodes,
    devices=config.devices,
    plugins=[config.cluster_environment],
    ... # Other arguments
)

Run on:

  • Node 0: 4xA100 80GB PCIe GPUs
  • Node 1: 2xA100 80GB PCIe GPUs
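
With this layout, WORLD_SIZE = 4 + 2 = 6; on node 1 the loop above sets GLOBAL_RANK_OFFSET = 4, so its two processes register as global ranks 4 and 5, while the stock DDP strategy still computes num_replicas = num_nodes * num_processes = 2 * 2 = 4.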

Error messages and logs

Error on node 1:

╭────────────────────────────────────────────── Traceback (most recent call last) ───────────────────────────────────────────────╮
│ /home/users/arjun.agarwal/projects/mock_training/distributed.py:99 in <module>                                                 │
│                                                                                                                                │
│    96 │   │   plugins=[config.cluster_environment],                                                                            │
│    97 │   )                                                                                                                    │
│    98 │                                                                                                                        │
│ ❱  99 │   trainer.fit(model, dm)                                                                                               │
│   100                                                                                                                          │
│                                                                                                                                │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:544 in fit               │
│                                                                                                                                │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/trainer/call.py:43 in                       │
│ _call_and_handle_interrupt                                                                                                     │
│                                                                                                                                │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/strategies/launchers/subprocess_script.py:1 │
│ 05 in launch                                                                                                                   │
│                                                                                                                                │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:580 in _fit_impl         │
│                                                                                                                                │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:987 in _run              │
│                                                                                                                                │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:1031 in _run_stage       │
│                                                                                                                                │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:1060 in                  │
│ _run_sanity_check                                                                                                              │
│                                                                                                                                │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/loops/utilities.py:182 in _decorator        │
│                                                                                                                                │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/loops/evaluation_loop.py:110 in run         │
│                                                                                                                                │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/loops/evaluation_loop.py:180 in setup_data  │
│                                                                                                                                │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:501 in │
│ _process_dataloader                                                                                                            │
│                                                                                                                                │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:206 in │
│ _prepare_dataloader                                                                                                            │
│                                                                                                                                │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:217 in │
│ _resolve_sampler                                                                                                               │
│                                                                                                                                │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:258 in │
│ _get_distributed_sampler                                                                                                       │
│                                                                                                                                │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/torch/utils/data/distributed.py:74 in __init__                │
╰────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
ValueError: Invalid rank 4, rank should be in the interval [0, 3]

num_replicas gets set to 4 here because num_nodes=2 and num_processes=2, whereas the world size defined in the environment is 6.

Environment

Current environment
* Lightning:
	- efficientnet-pytorch: 0.7.1
	- lightning:         2.2.5
	- lightning-cloud:   0.5.61
	- lightning-utilities: 0.10.0
	- pytorch-lightning: 2.1.2
	- pytorchvideo:      0.1.5
	- torch:             2.2.2
	- torchaudio:        2.2.2
	- torchmetrics:      1.2.1
	- torchsummary:      1.5.1
	- torchvision:       0.17.2

More info

The issue can be fixed by replacing ddp.py:L137

    @property
    @override
    def distributed_sampler_kwargs(self) -> Dict[str, Any]:
        return {"num_replicas": (self.num_nodes * self.num_processes), "rank": self.global_rank}

with

    @property
    @override
    def distributed_sampler_kwargs(self) -> Dict[str, Any]:
        return {"num_replicas": self.world_size, "rank": self.global_rank}