pytorch / examples

A set of examples around pytorch in Vision, Text, Reinforcement Learning, etc.

Home Page: https://pytorch.org/examples

The GPU load is unbalanced

lianchengmingjue opened this issue

snapshot = torch.load(snapshot_path)

When I run the code and resume from an existing .pt file, the memory usage of GPU0 is significantly higher than that of the other GPUs.
This can be solved by passing the "map_location" argument:
snapshot = torch.load(snapshot_path, map_location=torch.device('cuda', int(os.environ["LOCAL_RANK"])))

My Environment

cudatoolkit 10.2
pytorch 12.1

@lianchengmingjue good catch! By default, torch.load() first deserializes the snapshot on the CPU and then moves each tensor to the device it was saved from (GPU0, I guess). In this case, every rank loads the snapshot onto GPU0. We should always pass "map_location" to torch.load() when loading files saved in another environment, because the file might have been saved from a GPUx that doesn't exist on your host, which would make loading fail. Please feel free to send a PR with the fix.
cc: @suraj813
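
For anyone landing here later, the fix discussed above could be sketched roughly as follows. This is only an illustration, not the code that was merged: the helper names local_device and load_snapshot are made up for this example, and falling back to rank 0 when LOCAL_RANK is unset and to the CPU when CUDA is unavailable are assumptions of the sketch. It relies on torchrun setting LOCAL_RANK for each worker and on torch.load() accepting a device string for map_location.

```python
import os


def local_device(cuda_available, env=os.environ):
    """Return the device string this rank should load the snapshot onto.

    Hypothetical helper for illustration; not part of the examples repo.
    """
    if not cuda_available:
        # Assumption: fall back to the CPU on hosts without CUDA.
        return "cpu"
    # torchrun sets LOCAL_RANK per worker; defaulting to 0 is an assumption
    # made here so the helper also works in a single-process run.
    return f"cuda:{int(env.get('LOCAL_RANK', 0))}"


def load_snapshot(snapshot_path):
    """Load a snapshot onto this rank's own device instead of the saved one."""
    import torch  # imported lazily so local_device() works without torch

    device = local_device(torch.cuda.is_available())
    # map_location overrides the device recorded in the file, so each rank
    # restores onto its own GPU rather than all of them using GPU0.
    return torch.load(snapshot_path, map_location=device)
```

Each rank would then call load_snapshot(snapshot_path) in place of the bare torch.load(snapshot_path). Passing map_location="cpu" and moving the model to the device afterwards is another common choice when the snapshot may come from a machine with a different GPU layout.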