ucbrise / actnn

ActNN: Reducing Training Memory Footprint via 2-Bit Activation Compressed Training

Does this work well for Transformers?

prajwalkr opened this issue

Has this new method been tested with Transformers?

The method has been developed primarily for convolutional networks. It can be used for transformers, but it can only substitute the linear layers within them, so there won't be much memory saving (I think less than 2x).

We are working on a general version of ActNN, which automatically quantizes all the layers. Please stay tuned.
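For anyone who wants to try this today, below is a minimal sketch of how the conversion utility could be applied to a transformer. The `actnn.set_optimization_level` / `actnn.QModule` calls follow my reading of the project README, and "L2" is the default strategy mentioned later in this thread; treat the exact names, levels, and layer coverage as assumptions to verify against the repository. Note that ActNN's quantization kernels require a CUDA build.

```python
import torch
import torch.nn as nn
import actnn

# Assumed API (per the README as I recall it): verify set_optimization_level,
# QModule, and the available level names against the current repository.
actnn.set_optimization_level("L2")   # "L2" is the default strategy mentioned in this thread

# A small transformer encoder; only supported layers (e.g. nn.Linear) get converted,
# while the multi-head attention internals stay uncompressed, hence the limited savings.
model = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=256, nhead=4, dim_feedforward=1024),
    num_layers=6,
).cuda()
model = actnn.QModule(model)

x = torch.randn(128, 8, 256, device="cuda")  # (seq_len, batch, d_model)
loss = model(x).pow(2).mean()                # dummy loss just to drive a backward pass
loss.backward()
```

Whether every layer converts cleanly depends on the ActNN version; the 3-D `mm` issue reported further down in this thread is exactly the kind of thing that can surface here.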

> The method has been developed primarily for convolutional networks. It can be used for transformers, but it can only substitute the linear layers within them, so there won't be much memory saving (I think less than 2x).

Is this because activation memory dominates in CNNs, whereas in transformers the weights make up the majority of the memory?

> We are working on a general version of ActNN, which automatically quantizes all the layers. Please stay tuned.

Cool, this sounds interesting! May I know the ideas behind how to automatically quantize all layers? I read the source code and found that actnn implements its own operators in order to discard the input and save the quantized input (activation) in ctx, so it seems hard to automatically quantize all layers by adding hooks before forward and backward, which would be a more elegant way.
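To make that observation concrete, here is a minimal sketch of the pattern being described: a custom `torch.autograd.Function` that saves only a quantized copy of the activation in `ctx` and dequantizes it when computing the weight gradient. The `quantize`/`dequantize` helpers are a simplified per-tensor stand-in rather than ActNN's per-group scheme, and `QLinearFunction` is illustrative, not the actual `ops.py` code. Because the compression happens inside the operator itself, module forward/backward hooks cannot intercept the tensors that autograd saves, which is why each supported layer gets its own implementation.

```python
import torch
import torch.nn.functional as F

def quantize(x, bits=2):
    # Hypothetical per-tensor affine quantizer, for illustration only;
    # ActNN's real scheme is per-group with analytically chosen ranges.
    qmax = 2 ** bits - 1
    lo, hi = x.min(), x.max()
    scale = (hi - lo).clamp(min=1e-8) / qmax
    codes = ((x - lo) / scale).round().clamp(0, qmax).to(torch.uint8)
    return codes, scale, lo

def dequantize(codes, scale, lo):
    return codes.float() * scale + lo

class QLinearFunction(torch.autograd.Function):
    """Linear op that keeps only a 2-bit copy of its input for the backward pass."""

    @staticmethod
    def forward(ctx, input, weight, bias):
        codes, scale, lo = quantize(input)        # the fp32 activation is not saved
        ctx.save_for_backward(codes, scale, lo, weight)
        ctx.has_bias = bias is not None
        return F.linear(input, weight, bias)

    @staticmethod
    def backward(ctx, grad_output):
        codes, scale, lo, weight = ctx.saved_tensors
        input = dequantize(codes, scale, lo)      # approximate activation for grad_weight
        grad_input = grad_output.mm(weight)       # assumes 2-D (N, features) activations,
        grad_weight = grad_output.t().mm(input)   # like the original linear kernel
        grad_bias = grad_output.sum(0) if ctx.has_bias else None
        return grad_input, grad_weight, grad_bias

# Usage: pass all three arguments so backward returns one gradient per input.
x = torch.randn(32, 64, requires_grad=True)
layer = torch.nn.Linear(64, 128)
y = QLinearFunction.apply(x, layer.weight, layer.bias)
y.sum().backward()
```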

In ActNN we only consider compressing the activations. In principle, ActNN can be combined with other frameworks that deal with the weight memory (e.g. ZeRO-Offload, though this has not been tested).

We currently have a dedicated compression strategy for each type of layer (conv, linear, bn, etc.). Each strategy is derived analytically for the best compression ratio. As the multi-head attention layer is somewhat complicated, we didn't derive a compression strategy for it in the current version of the paper. In the next, general version we are using a universal compression strategy across all types of layers.

For implementation, it involves modifying PyTorch itself.

> In ActNN we only consider compressing the activations. In principle, ActNN can be combined with other frameworks that deal with the weight memory (e.g. ZeRO-Offload, though this has not been tested).

Sure. Maybe we can also compress the parameters in transformers? I will try to find out whether we can combine ZeRO with ActNN.

> We currently have a dedicated compression strategy for each type of layer (conv, linear, bn, etc.). Each strategy is derived analytically for the best compression ratio. As the multi-head attention layer is somewhat complicated, we didn't derive a compression strategy for it in the current version of the paper. In the next, general version we are using a universal compression strategy across all types of layers.
>
> For implementation, it involves modifying PyTorch itself.

Got it. The quantization currently used in ActNN is a linear transform, but self-attention is quadratic, so a new way is needed to minimize the variance introduced by compressing the activations in the multi-head attention layer.

Compressing the parameters with the same algorithm used in ActNN is possible. However, I am not sure about the optimizer state and the gradient buffer, which are as large as the weights. So non-compression methods like ZeRO look more directly applicable in this scenario.

Actually, a simple strategy such as the default "L2" strategy in ActNN should work for the self-attention layer. If you are looking for a solution for transformers right now, implementing an autograd.Function for self-attention may work as a simple remedy; a sketch follows below.
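As a starting point, here is a hedged sketch of what such an autograd.Function could look like. It saves 2-bit-quantized copies of the query/key/value tensors, then recomputes the attention forward from the dequantized copies during backward and lets autograd produce the gradients, checkpoint-style. This trades extra compute for simplicity and is not how ActNN's real kernels work (they compute gradients directly with a per-group quantizer); the `quantize`/`dequantize` helpers are the same simplified per-tensor stand-ins as in the earlier sketch, and masking/dropout are omitted.

```python
import torch

def quantize(x, bits=2):
    # Simplified per-tensor quantizer, illustration only (see the earlier sketch).
    qmax = 2 ** bits - 1
    lo, hi = x.min(), x.max()
    scale = (hi - lo).clamp(min=1e-8) / qmax
    return ((x - lo) / scale).round().clamp(0, qmax).to(torch.uint8), scale, lo

def dequantize(codes, scale, lo):
    return codes.float() * scale + lo

class QSelfAttention(torch.autograd.Function):
    """Scaled dot-product attention that keeps only quantized Q/K/V for backward."""

    @staticmethod
    def forward(ctx, q, k, v):
        packed = []
        for t in (q, k, v):
            codes, scale, lo = quantize(t)
            packed += [codes, scale, lo]
        ctx.save_for_backward(*packed)          # only 2-bit codes (+ scales) are kept
        attn = torch.softmax(q @ k.transpose(-2, -1) * q.shape[-1] ** -0.5, dim=-1)
        return attn @ v

    @staticmethod
    def backward(ctx, grad_out):
        saved = ctx.saved_tensors
        q, k, v = (dequantize(*saved[3 * i: 3 * i + 3]).requires_grad_(True)
                   for i in range(3))
        with torch.enable_grad():
            # Recompute the forward pass from the approximate (dequantized) inputs
            # and let autograd derive the Q/K/V gradients.
            attn = torch.softmax(q @ k.transpose(-2, -1) * q.shape[-1] ** -0.5, dim=-1)
            out = attn @ v
        return torch.autograd.grad(out, (q, k, v), grad_out)

# Usage: q, k, v of shape (batch, heads, seq_len, head_dim).
q, k, v = (torch.randn(2, 4, 16, 32, requires_grad=True) for _ in range(3))
out = QSelfAttention.apply(q, k, v)
out.sum().backward()
```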

Hi @cjf00000, when I try to train Swin-Transformer, I meet this problem:

File "actnn/actnn/actnn/ops.py", line 413, in backward
    grad_input = grad_output.mm(weight)
RuntimeError: tensors must be 2-D

> Hi @cjf00000, when I try to train Swin-Transformer, I meet this problem:
>
> ```
> File "actnn/actnn/actnn/ops.py", line 413, in backward
>     grad_input = grad_output.mm(weight)
> RuntimeError: tensors must be 2-D
> ```

@zimenglan-sysu-512
Thanks for reporting this issue; it is now fixed.
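For context, the error occurs because transformer activations are 3-D, (batch, sequence, features), while `Tensor.mm` only accepts 2-D operands. Below is a hedged sketch of the kind of reshape that resolves it; the actual commit in the repository may differ, and `linear_backward` is just an illustrative helper, not an ActNN function.

```python
import torch

def linear_backward(grad_output, input, weight):
    # grad_output: (..., out_features), input: (..., in_features), weight: (out_features, in_features)
    # Collapse all leading dimensions so mm() sees 2-D tensors, then restore the original shape.
    out_features, in_features = weight.shape
    grad_output_2d = grad_output.reshape(-1, out_features)
    input_2d = input.reshape(-1, in_features)
    grad_input = grad_output_2d.mm(weight).view_as(input)  # (..., in_features)
    grad_weight = grad_output_2d.t().mm(input_2d)           # (out_features, in_features)
    grad_bias = grad_output_2d.sum(0)                       # (out_features,)
    return grad_input, grad_weight, grad_bias

# Quick check with transformer-shaped tensors (batch=2, seq=5):
go, x, w = torch.randn(2, 5, 8), torch.randn(2, 5, 16), torch.randn(8, 16)
gi, gw, gb = linear_backward(go, x, w)
assert gi.shape == x.shape and gw.shape == w.shape and gb.shape == (8,)
```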