Questions about your LoRA codes
miznchimaki opened this issue
I read your LoRA code in sat/model/finetune/lora2.py carefully, but I still have some questions about the LoRA code when using model parallelism to train/test.
- For example, when the base class of a linear layer is `ColumnParallelLinear`, the original weight matrix W of the linear layer is partitioned in the following manner: W = [W_1, ..., W_p]. My thought is: although the original weight matrix W is partitioned across the model-parallel process group, there should be only one LoRA matrix A for the original weight. If my thought is right, there is a conflict between it and your LoRA implementation: in the code at line 101, you use the partitioned weight of matrix W to create the LoRA matrix A. If the original weight matrix W is partitioned into n parts, there are also n different LoRA matrices A, each located in one model-parallel process. What's more, the n LoRA matrices A in different model-parallel processes may have completely different values. The same applies to the LoRA matrix B.
- If my first point is right, the forward code of the LoRA linear layer raises another question: in the code at line 131, after you multiply the input x by LoRA matrix A, you apply the `copy_to_model_parallel_region` function to the result. This function performs an `all_reduce` collective operation on the gradient during the backward pass. Since the LoRA matrix A in every model-parallel process differs from the others, the product of input x and LoRA matrix A differs as well; can we directly use `all_reduce` during the backward pass?
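For concreteness, the setting in question can be simulated in a single process. This is an illustrative sketch, not the sat implementation: `world` stands in for the model-parallel group size, and all shapes are made up. It shows that a column-parallel shard of W and of B, combined with one shared copy of A, reproduces the unpartitioned output exactly:

```python
import numpy as np

# Single-process simulation of LoRA on a column-parallel linear.
# Illustrative shapes; "world" plays the role of the model-parallel group size.
rng = np.random.default_rng(0)
in_ch, out_ch, r, world = 8, 12, 2, 4

W = rng.standard_normal((out_ch, in_ch))   # full base weight
A = rng.standard_normal((r, in_ch))        # ONE shared LoRA A
B = rng.standard_normal((out_ch, r))       # full LoRA B
x = rng.standard_normal((3, in_ch))        # a batch of inputs

# Unpartitioned reference output: y = x W^T + (x A^T) B^T
y_full = x @ W.T + (x @ A.T) @ B.T

# Column-parallel: each rank holds a slice of W and of B along the
# output dimension, but the SAME copy of A.
shards = []
for rank in range(world):
    rows = slice(rank * out_ch // world, (rank + 1) * out_ch // world)
    shards.append(x @ W[rows].T + (x @ A.T) @ B[rows].T)
y_parallel = np.concatenate(shards, axis=1)  # gather along the output dim

assert np.allclose(y_full, y_parallel)
```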
Looking forward to your reply!
Your thought is absolutely right. But the multiple A's don't mean multiple A's across devices; they are just "A for query", "A for key" and "A for value", since our implementation of query_key_value is a single linear. I think the A's of q, k, v should be independent, although most implementations don't take this into account (not even peft).
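The "one A per partition" idea can be sketched like this. A single-process illustration, not the sat code: a fused query_key_value linear of width 3*d gets an independent (A, B) pair for each of q, k, v instead of one pair for the whole fused matrix. All names and shapes here are made up:

```python
import numpy as np

# Hypothetical fused q/k/v linear with an independent LoRA pair per partition.
rng = np.random.default_rng(1)
d, in_ch, r = 4, 8, 2
W_qkv = rng.standard_normal((3 * d, in_ch))  # fused query_key_value weight

# One (A, B) pair per partition: A_q/A_k/A_v and B_q/B_k/B_v.
As = [rng.standard_normal((r, in_ch)) for _ in range(3)]
Bs = [rng.standard_normal((d, r)) for _ in range(3)]

x = rng.standard_normal((5, in_ch))
base = x @ W_qkv.T
# Each partition gets its own low-rank update, concatenated back together.
delta = np.concatenate([(x @ A.T) @ B.T for A, B in zip(As, Bs)], axis=1)
y = base + delta
assert y.shape == (5, 3 * d)
```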
Yeah, you're right! So you used a `for` loop (`for _ in range(partition)`) in the code at line 101, where the variable `partition` means q, or q,k,v, or k,v:
If I want to use LoRA with model-parallel training in your sat lib, maybe I need to modify the `LoraLinear` class (or even more), in order to ensure that all model-parallel processes (which divide a whole linear layer into several `ColumnParallelLinear` or `RowParallelLinear` layers) use only one LoRA A & B matrix for each of q/k/v. Otherwise, I could also divide the LoRA A & B matrices among all model-parallel processes, but this scheme has a potential problem: the value of LoRA's hyperparameter `r` may not be evenly divisible by, or may even be smaller than, the number of processes in a model-parallel group.
No. You don't need to modify anything. This LoraMixin supports the model-parallel setting. Just build your model with model parallelism, and then add the mixin:
model = YourModule() # This should be an SAT model
model.add_mixin("lora", LoraMixin(xxx))
But after reading your code carefully again, I still have a question.
My question mainly concerns the following point. I understand that when the number of model-parallel processes is greater than 1 (say, 4), each model-parallel process holds its own `ColumnParallelLinear` instance of the q/k/v layer, or `RowParallelLinear` instance of the dense layer, and of course each `ColumnParallelLinear`/`RowParallelLinear` has its own parameters, which differ from the others'. But following your code, I think each model-parallel process also holds its own LoRA matrices A & B corresponding to the q/k/v/dense layer, and the params of A & B on each model-parallel process also differ from the others'. This means that when I view the `ColumnParallelLinear`/`RowParallelLinear` instances of all model-parallel processes as one whole linear layer in my mind (in fact they originally were one whole linear, divided across the parallel processes by our code), this whole linear has n LoRA matrices A & B (n equals the number of model-parallel processes). All n LoRA matrices A/B have the same shape but differ in parameter values. This conflicts with my original thought: although the original weight matrix W is partitioned across the model-parallel process group, there should be only one LoRA matrix A for the original weight.
Below is my chain of thought, with reference to your code: `model.add_mixin("lora", LoraMixin(xxx))` will execute `replace_linear_with_lora` in your lib; `replace_linear_with_lora` will instantiate the class `LoraLinear`, and the `__init__` func of `LoraLinear` will create a ParameterList of LoRA matrices A or B (of course, the `partition` in the `for` loop corresponds to q/kv/qkv/dense):
- In the above line of code, `original_obj` may be an instance of `ColumnParallelLinear`/`RowParallelLinear`. Taking an instance of `ColumnParallelLinear` as an example, the shape of its weight is `(out_channel / num_model_parallel_processes, in_channel)`, so the shape of the LoRA matrix A of q/k/v in each model-parallel process is `(lora_r, in_channel)`, and the shape of the LoRA matrix B of q/k/v in each model-parallel process is `(out_channel / num_model_parallel_processes, lora_r)`. This implies that each model-parallel process has its own LoRA matrix A of q/k/v with shape `(lora_r, in_channel)`, and its own B of q/k/v with shape `(out_channel / num_model_parallel_processes, lora_r)`.
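The shape bookkeeping above can be checked quickly with illustrative numbers (none of these values come from the sat code):

```python
# Shape check for a ColumnParallelLinear shard plus its LoRA factors.
out_ch, in_ch, lora_r, mp = 1024, 512, 8, 4   # illustrative sizes

W_shard = (out_ch // mp, in_ch)   # per-rank ColumnParallelLinear weight
A_shape = (lora_r, in_ch)         # per-rank LoRA A
B_shape = (out_ch // mp, lora_r)  # per-rank LoRA B (a true shard)

# (x @ A.T) @ B.T has shape (batch, out_ch // mp), matching x @ W_shard.T,
# so the dimensions must line up as follows:
assert A_shape[1] == W_shard[1]   # A consumes the full input dim
assert B_shape[0] == W_shard[0]   # B produces the sharded output dim
assert B_shape[1] == A_shape[0]   # rank r connects A and B
```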
Above are all my questions and related thoughts; I still don't know whether what I said is right or wrong. If wrong, where is the bug in my reasoning?
In fact, I like your source code of SwissArmyTransformer very much. I learned a lot about the model parallelism of self-attention through your sat source code, but I just cannot understand this question.
I would greatly appreciate receiving your reply~~~
Good question. This is why this line of code contains a `copy_to_model_parallel_region`:
For each process, B differs because it should differ. You can see it as split across devices, as indicated by `B.model_parallel = True` (which is handled by deepspeed automatically).
For A, it is the same across devices, thanks to `copy_to_model_parallel_region`, in which the gradient is all-reduced among devices before being passed back to A. Given that the initialization of A is the same (which is handled in SAT as shown below) and the gradient is the same all the time, A stays the same during training.
SwissArmyTransformer/sat/training/deepspeed_training.py
Lines 183 to 198 in aa1277e
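The role of the backward all-reduce can be seen in a single-process sketch (illustrative shapes, not the sat code): each rank only sees its slice of B, so it can only compute a partial gradient for the shared A, and the true gradient is the sum of the per-rank partials, which is exactly what the all-reduce computes:

```python
import numpy as np

# Why the gradient of the shared A must be summed across ranks.
rng = np.random.default_rng(2)
in_ch, out_ch, r, world, b = 6, 8, 2, 2, 3
A = rng.standard_normal((r, in_ch))
B = rng.standard_normal((out_ch, r))
x = rng.standard_normal((b, in_ch))
g_y = rng.standard_normal((b, out_ch))  # upstream gradient dL/dy

# Reference gradient on the unpartitioned model, with y = (x A^T) B^T:
g_h = g_y @ B          # dL/dh, where h = x A^T
g_A_full = g_h.T @ x   # dL/dA

# Per-rank partial gradients, with B split along the output dimension:
g_A_parts = []
for rank in range(world):
    rows = slice(rank * out_ch // world, (rank + 1) * out_ch // world)
    g_A_parts.append((g_y[:, rows] @ B[rows]).T @ x)

# The all_reduce in the backward of copy_to_model_parallel_region is a sum:
assert np.allclose(g_A_full, sum(g_A_parts))
```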
Thanks a lot!!! I've understood it thoroughly.
In fact, I had also wondered whether the LoRA matrix A has the same initialized weights in different model-parallel processes, but I hadn't found the code that makes all weights whose `model_parallel` attribute is False equal among model-parallel processes. In your source code, you use the `torch.distributed.broadcast` function to realize it, across all processes, i.e. both data- and model-parallel processes (the `group` param of the broadcast func is None, meaning the default process group, i.e. the group of all processes).
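That broadcast-based initialization can be mimicked in a single process (a sketch only; `world` and the shapes are made up). Every "process" draws its own random A, then rank 0's copy overwrites the rest, which is what a broadcast over the default group achieves:

```python
import numpy as np

# Simulate broadcasting rank 0's parameter to every other rank.
world = 4
# Different seeds stand in for each rank's independent random init.
params = [np.random.default_rng(seed).standard_normal((2, 8))
          for seed in range(world)]

src = params[0].copy()
for rank in range(world):
    params[rank][...] = src  # "broadcast" from rank 0 in place

# After the broadcast, every rank holds an identical copy.
assert all(np.array_equal(p, params[0]) for p in params)
```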
Taking LoRA matrix A as an example: it takes the same input x within a model-parallel process group but different inputs among data-parallel processes, and the value of LoRA matrix A is exactly equal across all processes (both model and data parallelism). According to the chain rule, the params of LoRA matrix A should receive the gradients from all parallel processes (both data and model parallelism), so you used `copy_to_model_parallel_region` to reduce (sum up) the gradients of the processes within a model-parallel group. As for the sum-then-average operation on gradients across data-parallel processes, it is done automatically by deepspeed's internal code.
After discussing with you, I've learned much more. Thanks very much, hhh
Cool.