BaguaSys / bagua

Bagua Speeds up PyTorch

Home Page: https://tutorials.baguasys.com/

Why does FusedOptimizer have a huge impact on model precision?

ProHuper opened this issue

I wrapped my custom optimizer with FusedOptimizer, and the precision was much worse than without FusedOptimizer. I don't think FusedOptimizer should affect model precision. Or is there something wrong with my custom optimizer?

Here is the optimizer I use:

https://github.com/cybertronai/pytorch-lamb/blob/master/pytorch_lamb/lamb.py

Thanks for opening the issue. FusedOptimizer is expected to give the same result as the original one for any torch optimizer. We'll investigate this case.

I reproduced this problem on a very simple example (with fixed model params and input) and got the same result. With the Lamb optimizer, the param update in each step differs from the update without FusedOptimizer. With the Adam optimizer, the param updates are identical. So I think it's probably related to the Lamb optimizer.

import torch
import torch.nn as nn
from bagua.torch_api.contrib import FusedOptimizer
from utils.LAMB_pt import LAMB  # local LAMB implementation (module not included in the issue)


if __name__ == '__main__':
    # Fixed input, label and model loaded from disk so that runs with and
    # without FusedOptimizer start from identical states
    input = torch.load('input.pt')
    label = torch.load('label.pt')
    model = torch.load('model.pt')

    # model = nn.Sequential(
    #     nn.Linear(10, 5),
    #     nn.Linear(5, 2),
    #     nn.Linear(2, 1),
    # )

    
    # optimizer = torch.optim.Adam(
    #     params=model.parameters(),
    #     lr=0.1,
    #     betas=(0.9, 0.999),
    #     eps=1e-06,
    #     weight_decay=0
    # )

    optimizer = LAMB(
        params=model.parameters(),
        lr=0.1,
        betas=(0.9, 0.999),
        eps=1e-06,
        weight_decay=0
    )

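    model.to(0)  # move the model to GPU 0 (nn.Module.to works in place)
    # Wrap LAMB with Bagua's FusedOptimizer, with flattening enabled
    optimizer = FusedOptimizer(optimizer, do_flatten=True)
    input = input.to(0)
    label = label.to(0)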

    print('original:')
    print(optimizer.param_groups[0]['params'][0])

    for i in range(10):
        print('running new step')
        optimizer.zero_grad()
        output = model(input)
        loss = (output - label).pow(2).sum()
        loss.backward()
        
        optimizer.step()
        print(optimizer.param_groups[0]['params'][0])
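The observation that Adam gives identical updates while Lamb does not can be illustrated without Bagua. The plain-Python sketch below uses lists as stand-ins for tensors, and `adam_direction` is a hypothetical helper, not a PyTorch or Bagua function. It computes a first-step Adam-style update direction; because every operation is elementwise, applying it to each parameter separately or to all parameters concatenated into one flat buffer gives exactly the same numbers, which is consistent with Adam being unaffected by flattening.

```python
import math


def adam_direction(grads, beta1=0.9, beta2=0.999, eps=1e-6):
    """First-step Adam-style update direction for a flat list of gradients.

    Moments start at zero; bias correction is omitted (it is one scalar per
    step and would not change the conclusion). Every operation is elementwise.
    """
    out = []
    for g in grads:
        m = (1 - beta1) * g        # first moment after one step
        v = (1 - beta2) * g * g    # second moment after one step
        out.append(m / (math.sqrt(v) + eps))
    return out


# Gradients of two hypothetical parameters (lists stand in for tensors)
g1, g2 = [0.5, -2.0], [3.0]

per_tensor = adam_direction(g1) + adam_direction(g2)  # update each parameter separately
fused = adam_direction(g1 + g2)                       # update one concatenated buffer

print(per_tensor == fused)  # True
```

Since no value depends on which other elements share the buffer, grouping cannot change the result; any optimizer step that reduces over a whole tensor (such as a norm) does not have this property.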

Thanks! An example is super useful for us to debug. Could you help provide the pt files also?

It seems that the Lamb optimizer uses weight_norm, which depends on the param itself, so if we group the params into one big tensor, weight_norm will change. Any idea how to do param fusion in this case?

  weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
  adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
  if group['weight_decay'] != 0:
      adam_step.add_(p.data, alpha=group['weight_decay'])
  adam_norm = adam_step.pow(2).sum().sqrt()

  if weight_norm == 0 or adam_norm == 0:
      trust_ratio = 1
  else:
      trust_ratio = weight_norm / adam_norm

  state['weight_norm'] = weight_norm
  state['adam_norm'] = adam_norm
  state['trust_ratio'] = trust_ratio
  
  p.data.add_(adam_step, alpha=-step_size * trust_ratio)
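To make the effect concrete, here is a plain-Python sketch with lists standing in for tensors; `l2_norm` and `trust_ratio` are hypothetical helpers that mirror the snippet above, including the `clamp(0, 10)` on the weight norm. Two parameters whose per-tensor trust ratios are 5.0 and 0.5 get a single shared ratio of about 2.28 once they are concatenated, so every element receives a different step size than it would without fusion.

```python
import math


def l2_norm(values):
    """Euclidean norm of a flat list of floats."""
    return math.sqrt(sum(v * v for v in values))


def trust_ratio(weights, update):
    """LAMB-style trust ratio, mirroring the snippet above:
    min(||w||, 10) / ||update||, or 1 if either norm is zero."""
    weight_norm = min(l2_norm(weights), 10.0)  # same as clamp(0, 10) for a norm
    adam_norm = l2_norm(update)
    if weight_norm == 0 or adam_norm == 0:
        return 1.0
    return weight_norm / adam_norm


# Two hypothetical parameters and their Adam-style update directions
w1, u1 = [3.0, 4.0], [1.0, 0.0]  # ||w1|| = 5, ||u1|| = 1
w2, u2 = [0.0, 1.0], [2.0, 0.0]  # ||w2|| = 1, ||u2|| = 2

per_tensor = [trust_ratio(w1, u1), trust_ratio(w2, u2)]
fused = trust_ratio(w1 + w2, u1 + u2)  # list concatenation mimics flattening

print(per_tensor)        # [5.0, 0.5]
print(round(fused, 4))   # 2.2804, i.e. sqrt(26) / sqrt(5)
```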

@wangraying is working on a less intrusive way to implement fused optimizer in #207. Let's see whether that works in this case.

In the worst case we can still go with the easiest solution (that would be disabling fusing for weight_norm operations). But that's the last resort we don't want to do at this stage 😃
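For illustration only, one way to read "disabling fusing for weight_norm operations" is to keep the fused (flat) storage but compute the norm separately over each parameter's slice of it. The plain-Python sketch below assumes the flat buffer is a simple in-order concatenation of the parameters and that each parameter's element count is known; `segment_norms` is a hypothetical helper, not part of Bagua. In PyTorch, basic slicing of a flat tensor returns a view rather than a copy, so the same idea would not require un-flattening the storage.

```python
import math


def segment_norms(flat, sizes):
    """L2 norm of each original parameter inside a flat (fused) buffer.

    flat  -- all parameters concatenated in order (the fused storage)
    sizes -- element count of each original parameter, in the same order
    """
    if sum(sizes) != len(flat):
        raise ValueError("sizes must cover the flat buffer exactly")
    norms, offset = [], 0
    for n in sizes:
        segment = flat[offset:offset + n]  # this parameter's slice
        norms.append(math.sqrt(sum(v * v for v in segment)))
        offset += n
    return norms


flat_weights = [3.0, 4.0, 0.0, 1.0]  # two parameters of 2 elements each
print(segment_norms(flat_weights, [2, 2]))  # [5.0, 1.0]: same as the unfused norms
```

This keeps one reduction per parameter, so it gives up part of what fusing saves on those operations.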

@ProHuper We fixed the fused optimizer in master; would you please check it again? Please let us know if it does not work as expected. BTW, the API changed a bit: FusedOptimizer is replaced by fuse_optimizer, here is the doc.

Thanks a lot.

Thanks, I tried the new API, and it still didn't work right. Also, I got the warning below:
[Screenshot of the warning]

I still think it's related to the implementation of the LAMB optimizer: it needs a weight_norm factor calculated from each param.

import torch
from torch.optim import Optimizer


class Lamb(Optimizer):
    r"""Implements Lamb algorithm.
    It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-6)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        adam (bool, optional): always use trust ratio = 1, which turns this into
            Adam. Useful for comparison purposes.
    .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
                 weight_decay=0, adam=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay)
        self.adam = adam
        super(Lamb, self).__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # Decay the first and second moment running average coefficient
                # m_t
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                # v_t
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # Paper v3 does not use debiasing.
                # bias_correction1 = 1 - beta1 ** state['step']
                # bias_correction2 = 1 - beta2 ** state['step']
                # Apply bias to lr to avoid broadcast.
                step_size = group['lr'] # * math.sqrt(bias_correction2) / bias_correction1

                weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)

                adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
                if group['weight_decay'] != 0:
                    adam_step.add_(p.data, alpha=group['weight_decay'])

                adam_norm = adam_step.pow(2).sum().sqrt()
                if weight_norm == 0 or adam_norm == 0:
                    trust_ratio = 1
                else:
                    trust_ratio = weight_norm / adam_norm
                state['weight_norm'] = weight_norm
                state['adam_norm'] = adam_norm
                state['trust_ratio'] = trust_ratio
                if self.adam:
                    trust_ratio = 1

                p.data.add_(adam_step, alpha=-step_size * trust_ratio)

        return loss

Oh, that's embarrassing. I'll look into this problem soon.

The fused optimizer assumes that a parameter and its state tensors have the same data type and size (which is the case for all official PyTorch optimizers).

The Lamb optimizer in your case has two states, weight_norm and adam_norm, that do not satisfy this assumption: they are 0-dim (scalar) tensors rather than tensors shaped like the parameter.

However, we can easily make it compliant by changing the following two lines in the code you provided:

state['weight_norm'] = weight_norm
state['adam_norm'] = adam_norm

to

state['weight_norm'] = weight_norm.item()
state['adam_norm'] = adam_norm.item()

Note that with this change, weight_norm and adam_norm will be calculated on the "fused tensors", which is not exactly the same as calculating them for the original individual tensors.
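The assumption about state tensors can be expressed as a simple check. The sketch below is hypothetical and is not Bagua's actual code: `TensorSpec` stands in for a tensor's shape and dtype so the example runs without PyTorch, and `state_is_fusable` treats a tensor-valued state entry as fusable only if it matches the parameter, while plain Python numbers (such as `state['step']`) are skipped. A trimmed-down state dict with a 0-dim `weight_norm` fails the check, and storing a plain Python float instead, as `weight_norm.item()` would produce, passes it.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TensorSpec:
    """Hypothetical stand-in for a tensor: only its shape and dtype."""
    shape: tuple
    dtype: str


def state_is_fusable(param, state):
    """True if every tensor-valued state entry has the parameter's shape and
    dtype; non-tensor entries (ints, floats) are skipped."""
    return all(
        value == param
        for value in state.values()
        if isinstance(value, TensorSpec)
    )


param = TensorSpec(shape=(5, 10), dtype="float32")
state = {
    "step": 1,
    "exp_avg": TensorSpec((5, 10), "float32"),
    "exp_avg_sq": TensorSpec((5, 10), "float32"),
    "weight_norm": TensorSpec((), "float32"),  # 0-dim tensor, as in the original code
}
print(state_is_fusable(param, state))  # False

state["weight_norm"] = 3.2  # after .item(): a plain Python float (example value)
print(state_is_fusable(param, state))  # True
```

By the same reasoning, `state['trust_ratio']` in the Lamb code above is also a 0-dim tensor whenever both norms are nonzero (it is computed as `weight_norm / adam_norm` from the tensors), so it may need the same `.item()` treatment.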

Let us know if it works! Thanks

I'll close this if no more problems are raised.