Memory consumption question
KimmiShi opened this issue
Hi, I am trying to use actnn on transformer models, and I am testing it on a simple `nn.Linear` module:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

import actnn
from actnn import config, QScheme, QModule


class GEGLU(nn.Module):
    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        # self.proj = LoRACompatibleLinear(dim_in, dim_out)
        self.proj = nn.Linear(dim_in, dim_out)

    def gelu(self, gate):
        if gate.device.type != "mps":
            return F.gelu(gate)
        # mps: gelu is not implemented for float16
        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

    def forward(self, hidden_states):
        tmp = self.proj(hidden_states)
        # print((tmp.numel() - hidden_states.numel()) * 4 / 1e6)
        hidden_states, gate = tmp.chunk(2, dim=-1)
        # import pdb; pdb.set_trace()
        return hidden_states * self.gelu(gate)


def test_m():
    model = GEGLU(640, 5120)
    model = QModule(model)
    model.cuda()
    inp = torch.rand(128, 2304, 640).cuda()
    _ = model(torch.rand(2, 2304, 640).cuda())
    # out.mean().backward()
    beg = torch.cuda.memory_allocated() / 1e6
    out = model(inp)
    print("memory:", torch.cuda.memory_allocated() / 1e6 - beg)
    # print(model.proj.weight.grad.numel() / 1e6)
    # out.mean().backward()


actnn.set_optimization_level("L3")
test_m()
```
However, the memory consumption I see with the code above barely changes whether I keep `model = QModule(model)` or comment it out. For example:

- with QModule: 2843.49 MB
- without QModule: 2829 MB

I added prints in `actnn/actnn/ops.py` to see how much memory is saved after `quantized = quantize_activation(input, scheme)`. The quantized tensor is much smaller than the input, so about 700 MB should be saved, but I don't see that difference in the numbers above. It seems the memory saving does not show up when there is only a single `nn.Linear`.
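Here is a sketch of where I would expect the saving to become visible (my own reasoning, not verified against actnn internals): the fp32 input can only be freed if nothing else references it, so the local reference has to be dropped after the forward pass before comparing what the autograd graph keeps alive. The snippet reuses `GEGLU`, `QModule`, and the `"L3"` setting from above, and assumes that with actnn only the quantized copy of the layer input is retained.

```python
# Sketch only: reuses GEGLU / QModule / actnn.set_optimization_level("L3") from above.
# Assumption: with QModule, QLinear keeps only the quantized copy of its input.
def measure(use_actnn: bool):
    model = GEGLU(640, 5120)
    if use_actnn:
        model = QModule(model)
    model.cuda()

    inp = torch.rand(128, 2304, 640).cuda()
    out = model(inp)

    # Without actnn, autograd holds the fp32 `inp` for the Linear weight gradient,
    # so deleting the local reference frees nothing. With actnn, the fp32 input
    # should become unreferenced here and its ~755 MB should be released.
    del inp
    torch.cuda.synchronize()
    print("alive after forward:", torch.cuda.memory_allocated() / 1e6, "MB")

    out.mean().backward()


measure(use_actnn=False)
measure(use_actnn=True)
```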
I did another experiment with a real module, and it seems that actnn only helps for certain structures. For example, with the module defined below:
```python
class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        dim_out=None,
        mult: int = 4,
        dropout: float = 0.0,
        final_dropout: bool = False,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim
        act_fn = nn.Linear(dim, inner_dim)

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(nn.Linear(inner_dim, dim_out))
        # FF as used in Vision Transformer, MLP-Mixer, etc. has a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states):
        for module in self.net:
            beg = torch.cuda.memory_allocated() / 1e6
            hidden_states = module(hidden_states)
            print("module", type(module), torch.cuda.memory_allocated() / 1e6 - beg)
        return hidden_states
```
Result with QModule:

```
module <class 'actnn.layers.QLinear'> 3079.9303680000003
module <class 'actnn.layers.QDropout'> 94.37183999999979
module <class 'actnn.layers.QLinear'> 3220.439552
memory: 4884.79232
```

Result of the baseline:

```
module <class 'torch.nn.modules.linear.Linear'> 3028.41856
module <class 'torch.nn.modules.dropout.Dropout'> 0.0
module <class 'torch.nn.modules.linear.Linear'> 6039.7977599999995
memory: 7558.266879999999
```
Only the last Linear consumes less memory. Can anyone tell me why not all Linear layers are quantized to 2 bits?
For the first layer, the saved activation is just a reference to the layer input, which already exists outside the layer, so there is no fp32 intermediate there for actnn to shrink.
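In other words (a minimal, actnn-free PyTorch sketch; shapes are taken from the example above and it needs a GPU with a few GB free):

```python
import torch
import torch.nn as nn

# Plain-PyTorch illustration of why the first layer shows no saving.
lin1 = nn.Linear(640, 2560).cuda()
lin2 = nn.Linear(2560, 640).cuda()

x = torch.rand(128, 2304, 640).cuda()   # input is already alive before the forward pass
base = torch.cuda.memory_allocated() / 1e6

# lin1 saves `x` for its weight gradient, but that is only a reference to an
# existing tensor, so the delta here is essentially just lin1's output (~3020 MB).
h = lin1(x)
print("after lin1:", torch.cuda.memory_allocated() / 1e6 - base)

# lin2 saves `h`. Even after dropping our own reference, autograd keeps `h`
# alive until backward, so those ~3020 MB stay allocated. This retained
# intermediate is the activation memory that quantization can actually shrink.
out = lin2(h)
del h
print("after lin2, h deleted:", torch.cuda.memory_allocated() / 1e6 - base)
```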