ucbrise / actnn

ActNN: Reducing Training Memory Footprint via 2-Bit Activation Compressed Training

Question about memory consumption

KimmiShi opened this issue

Hi, I am trying to use ActNN on transformer models, and I am testing it on a simple nn.Linear module:

import torch
import torch.nn as nn
import torch.nn.functional as F
import actnn
from actnn import config, QScheme, QModule
class GEGLU(nn.Module):
    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        # self.proj = LoRACompatibleLinear(dim_in, dim_out)
        self.proj = nn.Linear(dim_in, dim_out)

    def gelu(self, gate):
        if gate.device.type != "mps":
            return F.gelu(gate)
        # mps: gelu is not implemented for float16
        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

    def forward(self, hidden_states):

        tmp = self.proj(hidden_states)
        # print((tmp.numel()-hidden_states.numel())*4/1e6)
        hidden_states, gate = tmp.chunk(2, dim=-1)
        # import pdb;pdb.set_trace()

        return hidden_states * self.gelu(gate)


def test_m():
    model = GEGLU(640, 5120)
    model = QModule(model)
    model.cuda()

    inp = torch.rand(128, 2304, 640).cuda()

    # warm-up pass so one-time CUDA allocations do not show up in the measurement below
    _ = model(torch.rand(2, 2304, 640).cuda())
    # out.mean().backward()

    beg = torch.cuda.memory_allocated()/1e6

    out = model(inp)
    print("memory:", torch.cuda.memory_allocated()/1e6-beg)
    # print(model.proj.weight.grad.numel()/1e6)
    # out.mean().backward()

actnn.set_optimization_level("L3")

test_m()
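
(A variant of test_m(), sketched here with the same GEGLU definition as above, that also runs the backward pass and reads peak memory, since the compressed activations only pay off between forward and backward; the test_peak helper and the smaller batch are just for this sketch.)

import torch
import actnn
from actnn import QModule

def test_peak(wrap: bool):
    # Sketch only, assuming the GEGLU class defined above.
    model = GEGLU(640, 5120)
    if wrap:
        model = QModule(model)  # enable ActNN's activation compression
    model.cuda()

    inp = torch.rand(32, 2304, 640).cuda()  # smaller batch than above, to keep the peak modest

    _ = model(torch.rand(2, 2304, 640).cuda())  # warm-up

    torch.cuda.reset_peak_memory_stats()
    base = torch.cuda.memory_allocated()

    out = model(inp)
    out.mean().backward()  # the saved activations are consumed here

    peak = torch.cuda.max_memory_allocated()
    print("peak extra memory (MB):", (peak - base) / 1e6)

actnn.set_optimization_level("L3")
test_peak(wrap=True)
test_peak(wrap=False)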

However, the memory consumption I see in test_m() barely changes whether or not I wrap the model with model = QModule(model). For example:

  • with QModule: 2843.49 MB
  • without QModule: 2829 MB

I added a print in actnn/actnn/ops.py to check how much memory is saved after quantized = quantize_activation(input, scheme). The quantized tensor is much smaller than the input, so about 700 MB should be saved, but I do not see that difference in the numbers above.

It seems that the memory saving does not take effect when there is only a single nn.Linear.
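
A toy illustration of one possible cause (plain PyTorch, not ActNN's real kernels; the uint8 copy below just stands in for the 2-bit buffer): as long as the full-precision tensor is still referenced elsewhere, as inp is inside test_m(), a compressed copy is pure extra memory, and the saving only shows up once the last full-precision reference is released.

import torch

x = torch.rand(128, 2304, 640, device="cuda")  # ~755 MB in FP32, like inp above
base = torch.cuda.memory_allocated()

small = (x * 255).to(torch.uint8)  # stand-in for a compressed copy, about 1/4 the size
print("FP32 still referenced (MB):", (torch.cuda.memory_allocated() - base) / 1e6)  # positive: pure overhead

del x  # drop the last full-precision reference
print("FP32 released (MB):", (torch.cuda.memory_allocated() - base) / 1e6)  # negative: now it saves memory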

I did another experiment in a real module, and it seems that ActNN only takes effect for certain structures.

For example, consider the module defined below:

class FeedForward(nn.Module):

    def __init__(
        self,
        dim: int,
        dim_out = None,
        mult: int = 4,
        dropout: float = 0.0,
        final_dropout: bool = False,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        act_fn = nn.Linear(dim, inner_dim)
        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(nn.Linear(inner_dim, dim_out))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states):
        for module in self.net:
            beg = torch.cuda.memory_allocated()/1e6
            hidden_states = module(hidden_states)
            print("module", type(module), torch.cuda.memory_allocated()/1e6-beg)

        return hidden_states

Result with QModule:

module <class 'actnn.layers.QLinear'> 3079.9303680000003
module <class 'actnn.layers.QDropout'> 94.37183999999979
module <class 'actnn.layers.QLinear'> 3220.439552
memory: 4884.79232

Result of the baseline:

module <class 'torch.nn.modules.linear.Linear'> 3028.41856
module <class 'torch.nn.modules.dropout.Dropout'> 0.0
module <class 'torch.nn.modules.linear.Linear'> 6039.7977599999995
memory: 7558.266879999999

Only the last Linear consumes less memory. Can anyone tell me why not all Linear layers are quantized to 2 bits?

For the first layer, the saved activation is just a reference to the layer's input, which is still alive outside the layer, so quantizing it does not free any memory.
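
A plain-PyTorch sketch of the same point (no ActNN involved): the first Linear only saves a reference to the caller's input, which exists anyway, whereas the intermediate produced inside the block is kept alive by autograd solely for the backward pass, and that is the buffer a quantized layer can actually shrink. This would explain why only the second QLinear above shows a large saving.

import torch
import torch.nn as nn

lin1 = nn.Linear(640, 2560).cuda()
lin2 = nn.Linear(2560, 640).cuda()
x = torch.rand(8, 2304, 640, device="cuda")  # the caller's input, alive either way

base = torch.cuda.memory_allocated()
h = lin1(x)   # autograd saves a reference to x (no copy) plus the new tensor h
y = lin2(h)   # autograd saves a reference to h for the second layer's backward
del h         # the Python name is gone, but autograd keeps h alive until backward
print("pinned by the graph (MB):", (torch.cuda.memory_allocated() - base) / 1e6)

del y         # dropping the output releases the graph and with it the intermediate
print("after freeing the graph (MB):", (torch.cuda.memory_allocated() - base) / 1e6)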