Returning num_replicas=world_size when using distributed sampler in ddp
arjunagarwal899 opened this issue
Bug description
The default LightningEnvironment assumes that every node in a multi-node setup has an equal number of GPUs, i.e. each node computes the world size as the number of nodes multiplied by the number of (active) devices on that node. Implementing a custom environment can bypass this limitation (example attached below). While the processes get registered successfully, the num_replicas attribute passed to the DistributedSampler class is still computed independently of the environment, which leads to an error because some global ranks fall outside the estimated world size.
Fix: use num_replicas=self.world_size() instead of estimating the world size again.
What version are you seeing the problem on?
v2.2
How to reproduce the bug
Config and custom Environment:
import os
import socket

# Code that sets some config parameters as follows
config.nodes = [
    ("NODE_NAME", ["NODE_IP_ADDRESS", "NUM_GPUS: int"]),  # Index 0 is the master node
    ...
]
config.num_nodes = len(config.nodes)

# Set some global variables
MASTER_PORT = 10051  # Set port here
MASTER_ADDR = config.nodes[0][1][0]  # Master node's IP address
WORLD_SIZE = sum(node_info[1] for _, node_info in config.nodes)  # Total GPUs across all nodes

# Set config.devices, NODE_RANK, and the global rank starting point for this node
NODE_RANK = ...
GLOBAL_RANK_OFFSET = 0
for i, (node, node_info) in enumerate(config.nodes):
    if node == socket.gethostname():
        config.devices = node_info[1]
        NODE_RANK = i
        break
    GLOBAL_RANK_OFFSET += node_info[1]

# Set environment variables
os.environ["MASTER_ADDR"] = str(MASTER_ADDR)
os.environ["MASTER_PORT"] = str(MASTER_PORT)
os.environ["WORLD_SIZE"] = str(WORLD_SIZE)
os.environ["NODE_RANK"] = str(NODE_RANK)
class MyClusterEnvironment(LightningEnvironment):
    def set_world_size(self, size: int) -> None:
        # Lightning passes size = num_nodes * len(devices), which is wrong for heterogeneous clusters
        self._world_size = WORLD_SIZE

    def set_global_rank(self, rank: int) -> None:
        # Lightning passes rank = node_rank * len(devices) + local_rank, which is wrong for heterogeneous clusters
        self._global_rank = GLOBAL_RANK_OFFSET + self.local_rank()

config.cluster_environment = MyClusterEnvironment()
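The offset arithmetic above can be checked in isolation. Below is a standalone sketch using the 4-GPU + 2-GPU layout from this report (hostnames and IP addresses are placeholders):

```python
# Standalone sketch of the world-size and global-rank-offset arithmetic,
# using the 4-GPU + 2-GPU layout from this report (names/IPs are placeholders).
nodes = [
    ("node-0", ["10.0.0.1", 4]),  # master node, 4 GPUs
    ("node-1", ["10.0.0.2", 2]),  # worker node, 2 GPUs
]

world_size = sum(node_info[1] for _, node_info in nodes)  # 4 + 2 = 6


def global_ranks(node_index):
    """Global ranks owned by a node: offset + local_rank for each local GPU."""
    offset = sum(info[1] for _, info in nodes[:node_index])
    return [offset + local_rank for local_rank in range(nodes[node_index][1][1])]


print(world_size)       # 6
print(global_ranks(0))  # [0, 1, 2, 3]
print(global_ranks(1))  # [4, 5]
```

Node 1's processes therefore get global ranks 4 and 5, which is exactly what the default `node_rank * len(devices) + local_rank` formula cannot produce.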
Trainer:
trainer = L.Trainer(
    num_nodes=config.num_nodes,
    devices=config.devices,
    plugins=[config.cluster_environment],
    ...  # Other arguments
)
Run on:
- Node 0: 4xA100 80GB PCIe GPUs
- Node 1: 2xA100 80GB PCIe GPUs
Error messages and logs
Error on node 1:
╭────────────────────────────────────────────── Traceback (most recent call last) ───────────────────────────────────────────────╮
│ /home/users/arjun.agarwal/projects/mock_training/distributed.py:99 in <module> │
│ │
│ 96 │ │ plugins=[config.cluster_environment], │
│ 97 │ ) │
│ 98 │ │
│ ❱ 99 │ trainer.fit(model, dm) │
│ 100 │
│ │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:544 in fit │
│ │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/trainer/call.py:43 in │
│ _call_and_handle_interrupt │
│ │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/strategies/launchers/subprocess_script.py:1 │
│ 05 in launch │
│ │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:580 in _fit_impl │
│ │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:987 in _run │
│ │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:1031 in _run_stage │
│ │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:1060 in │
│ _run_sanity_check │
│ │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/loops/utilities.py:182 in _decorator │
│ │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/loops/evaluation_loop.py:110 in run │
│ │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/loops/evaluation_loop.py:180 in setup_data │
│ │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:501 in │
│ _process_dataloader │
│ │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:206 in │
│ _prepare_dataloader │
│ │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:217 in │
│ _resolve_sampler │
│ │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:258 in │
│ _get_distributed_sampler │
│ │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/torch/utils/data/distributed.py:74 in __init__ │
╰────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
ValueError: Invalid rank 4, rank should be in the interval [0, 3]
num_replicas gets set to 4 here because num_nodes=2 and num_processes=2. However, the world size is 6 as defined in the environment.
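The mismatch can be reproduced without a multi-node run. Below is a minimal sketch of the rank check that DistributedSampler performs in its __init__ (mirrored in plain Python here rather than importing torch):

```python
# Minimal sketch of the rank validation DistributedSampler performs in
# __init__ (mirrored in plain Python rather than importing torch).
def check_sampler_args(num_replicas, rank):
    if rank >= num_replicas or rank < 0:
        raise ValueError(
            f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]"
        )


# On node 1, Lightning estimates num_replicas = num_nodes * num_processes = 2 * 2 = 4,
# but the custom environment assigns that node's processes global ranks 4 and 5:
try:
    check_sampler_args(num_replicas=4, rank=4)
except ValueError as e:
    print(e)  # Invalid rank 4, rank should be in the interval [0, 3]

# With num_replicas = world_size = 6, the same rank is valid:
check_sampler_args(num_replicas=6, rank=4)  # no error
```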
Environment
Current environment
* Lightning:
- efficientnet-pytorch: 0.7.1
- lightning: 2.2.5
- lightning-cloud: 0.5.61
- lightning-utilities: 0.10.0
- pytorch-lightning: 2.1.2
- pytorchvideo: 0.1.5
- torch: 2.2.2
- torchaudio: 2.2.2
- torchmetrics: 1.2.1
- torchsummary: 1.5.1
- torchvision: 0.17.2
More info
The issue can be fixed by replacing ddp.py:L137

@property
@override
def distributed_sampler_kwargs(self) -> Dict[str, Any]:
    return {"num_replicas": (self.num_nodes * self.num_processes), "rank": self.global_rank}

with

@property
@override
def distributed_sampler_kwargs(self) -> Dict[str, Any]:
    return {"num_replicas": self.world_size, "rank": self.global_rank}
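A quick arithmetic check shows why the original expression cannot be made correct per-node: on a heterogeneous cluster, num_nodes * num_processes evaluates differently on each node, and neither value equals the true world size (illustrated below for the 4+2 GPU setup above):

```python
# Why the estimate fails on a heterogeneous cluster: num_nodes * num_processes
# differs per node and never equals the true world size (4 + 2 = 6).
num_nodes = 2
world_size = 4 + 2  # 4 GPUs on node 0, 2 GPUs on node 1

estimates = {name: num_nodes * procs for name, procs in [("node 0", 4), ("node 1", 2)]}
print(estimates)   # {'node 0': 8, 'node 1': 4} -- neither equals world_size
print(world_size)  # 6
```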