hpcaitech / ColossalAI-Examples

Examples of training models with hybrid parallelism using ColossalAI

ZeRO without using shard_param

powermano opened this issue · comments

πŸ› Describe the bug

When I use ZeRO without shard_param, the following error occurs:

Traceback (most recent call last):
  File "train.py", line 175, in <module>
    main()
  File "train.py", line 39, in main
    with ZeroInitContext(target_device=torch.cuda.current_device(), shard_strategy=shard_strategy, shard_param=False):
  File "/usr/local/Python-3.8.6/lib/python3.8/site-packages/colossalai/zero/init_ctx/init_context.py", line 75, in __init__
    self.config = ZeroContextConfig(target_device=target_device, replicated=True, shard_param=shard_param)
  File "/usr/local/Python-3.8.6/lib/python3.8/site-packages/colossalai/zero/init_ctx/init_context.py", line 37, in __init__
    assert target_device.type == 'cuda', "Replicated no-shard paramters should locate in cuda."
AttributeError: 'int' object has no attribute 'type'

My init code is:

def main():
    parser = colossalai.get_default_parser()
    parser.add_argument('--use_trainer', action='store_true', help='whether to use trainer')
    args = parser.parse_args()

    colossalai.launch_from_torch(config='./config.py')

    logger = get_dist_logger()

    rank = int(os.environ['RANK'])
    # build resnet
    use_zero3 = hasattr(gpc.config, 'zero')
    if use_zero3:
        shard_strategy = TensorShardStrategy()
        with ZeroInitContext(target_device=torch.cuda.current_device(), shard_strategy=shard_strategy, shard_param=False):
            model = resnet34(num_classes=10)
    else:
        model = resnet34(num_classes=10)

My config is:

from colossalai.amp import AMP_TYPE
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.nn.optimizer import HybridAdam

zero = dict(
    model_config=dict(
        tensor_placement_policy='cuda',
        shard_strategy=TensorShardStrategy(),
        reuse_fp16_shard=False
    ),
    optimizer_config=dict()
)

optimizer = dict(
    type=HybridAdam,
    lr=0.001,
    # weight_decay=1e-2,
)

BATCH_SIZE = 64
NUM_EPOCHS = 20
LOGGING_FREQUNCE = 20
OUTPUT = './'

gradient_clipping = 5.0

Environment

pip install colossalai==0.1.5+torch1.10cu11.1 -f https://release.colossalai.org

ubuntu 18.04

If I modify the code as follows, it actually works:

    rank = int(os.environ['RANK'])
    # build resnet
    use_zero3 = hasattr(gpc.config, 'zero')
    if use_zero3:
        shard_strategy = TensorShardStrategy()

        # with ZeroInitContext(target_device=torch.cuda.current_device(), shard_strategy=shard_strategy, shard_param=False):
        #     model = resnet34(num_classes=10)
        with ZeroInitContext(target_device=torch.device('cuda', rank), shard_strategy=shard_strategy, shard_param=False):
            model = resnet34(num_classes=10)
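
The root cause, as far as I can tell, is that torch.cuda.current_device() returns a plain integer device index rather than a torch.device, so the assert in ZeroInitContext finds no .type attribute. A minimal illustration:

import torch

dev = torch.cuda.current_device()
print(type(dev))   # <class 'int'> -> no .type attribute, hence the AttributeError

dev = torch.device('cuda', torch.cuda.current_device())
print(dev.type)    # 'cuda' -> passes the assert in ZeroInitContext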

@fastalgo I do not know how to save the ZeRO model params. When using the save_checkpoint API, the saved file is pretty small.

I tested ZeRO using a private dataset and ir18 (which is slightly different from the original resnet18). The table below shows the specific results.
When I used PyTorch's native AMP, the GPU memory usage was much smaller than with ColossalAI. Why?
The config is:

from colossalai.amp import AMP_TYPE
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.nn.optimizer import HybridAdam

fp16 = dict(
    mode=AMP_TYPE.TORCH,
)

optimizer = dict(
    type=HybridAdam,
    lr=0.001,
    # weight_decay=1e-2,
)
| model | dataset | machine | batch | gradient accumulate size | ZeRO | speed | GPU memory | OPT | tensor_placement_policy | notes |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| ir18 | private dataset | 1 | 64 | 1 | no ZeRO | 2089/8549 [02:51<08:39, 12.43it/s] | 8703M | HybridAdam | | single machine + Engine |
| ir18 | private dataset | 1 | 64 | 1 | no ZeRO | 1599/8549 [02:24<10:21, 11.17it/s] | 5769M | HybridAdam | | single machine + w/o Engine + PyTorch native fp16 |
| ir18 | private dataset | 2 | 64 | 2 | no ZeRO | 1598/4274 [02:32<04:14, 10.50it/s] | 9011M | SGD | | common data parallel |
| ir18 | private dataset | 2 | 64 | 1 | ZeRO + no shard params | 606/4275 [01:25<08:27, 7.23it/s] | 9141M | HybridAdam | cuda | |
| ir18 | private dataset | 2 | 64 | 1 | ZeRO + shard params | 571/4275 [01:32<10:32, 5.85it/s] | 9073M | HybridAdam | cuda | |
| ir18 | private dataset | 2 | 64 | 1 | ZeRO + shard params | 217/4275 [01:37<29:16, 2.31it/s] | 6819M | HybridAdam | cpu | |

The code without using Engine is shown below:

model = ...
optimizer = ...
criterion = ...
amp = torch.cuda.amp.grad_scaler.GradScaler(growth_interval=1000)
global_step = 0
optimizer.zero_grad()
for epoch in range(gpc.config.NUM_EPOCHS):
    model.train()
    for idx, (img, label) in enumerate(train_dl):
        img = img.cuda()
        label = label.cuda()
        output, _ = model(img, label)
        train_loss = criterion(output, label)
        amp.scale(train_loss).backward()
        amp.unscale_(optimizer)  # unscale grads before clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
        amp.step(optimizer)
        amp.update()
        optimizer.zero_grad()
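
For comparison, the Engine path from the first row of the table looks roughly like this (a sketch following the standard ColossalAI Engine loop; it assumes the same model, optimizer, criterion and train_dl built above):

# sketch: assumes model / optimizer / criterion / train_dl are constructed as before
engine, train_dl, _, _ = colossalai.initialize(model, optimizer, criterion, train_dl)

for epoch in range(gpc.config.NUM_EPOCHS):
    engine.train()
    for idx, (img, label) in enumerate(train_dl):
        img = img.cuda()
        label = label.cuda()
        engine.zero_grad()
        output, _ = engine(img, label)
        train_loss = engine.criterion(output, label)
        engine.backward(train_loss)
        engine.step()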

The difference between my original PyTorch implementation and ColossalAI is the convert_to_amp API, which uses TorchAMPModel to wrap the original model.
I have tested three different cases:

1. Using torch.cuda.amp.autocast(True) inside the model's forward function:

class model(nn.Module):
    def __init__(self):
        ....
    def forward(self, x, label):
        with torch.cuda.amp.autocast(True):
            .....
            .....
        return x

2. Using the @torch.cuda.amp.autocast() decorator:

class model(nn.Module):
    def __init__(self):
        ....
    @torch.cuda.amp.autocast()
    def forward(self, x, label):
        .....
        .....
        return x

3. Using TorchAMPModel:

class model(nn.Module):
    def __init__(self):
        ....
    def forward(self, x, label):
        .....
        .....
        return x

model = model()
model = TorchAMPModel(model)

The first two behave normally and only need 5769M of GPU memory, but the third one needs 8703M of GPU memory.
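
To double-check the peak memory of each case, a minimal sketch is the following (note that torch.cuda.max_memory_allocated only tracks the PyTorch caching allocator, so it can differ from what nvidia-smi reports):

import torch

torch.cuda.reset_peak_memory_stats()
# ... run a few training iterations of the case under test ...
peak_mb = torch.cuda.max_memory_allocated() / 1024 ** 2
print(f'peak allocated: {peak_mb:.0f} MiB')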

@feifeibear Can you help verify the above problem? Thanks.