THUDM / SwissArmyTransformer

SwissArmyTransformer is a flexible and powerful library to develop your own Transformer variants.

Home Page: https://THUDM.github.io/SwissArmyTransformer

Questions about your LoRA code

miznchimaki opened this issue

I read your LoRA code in sat/model/finetune/lora2.py carefully, but I still have some questions about how it behaves when using model parallelism for training/testing.

  1. For example, when the base class of a linear layer is ColumnParallelLinear, the original weight matrix of the layer (matrix W) is partitioned in the following manner:
    W = [W_1, ..., W_p]
    My understanding is: although the original weight matrix W is partitioned across the model parallel process group, there should be only one LoRA matrix A for that weight. If that is right, it conflicts with your LoRA implementation: at line 101 you use the partitioned weight of W to create the LoRA matrix A, so if W is split into n parts, there are also n different LoRA matrices A, one per model parallel process. Moreover, the n copies of A on different model parallel processes may hold completely different values. The same applies to the LoRA matrix B (see the shape sketch after this list).
    self.matrix_A = HackParameterList([nn.Parameter(torch.empty((r, original_obj.weight.shape[1]))) for _ in range(partition)])
  2. If my first point is right, the forward code of the LoRA linear layer raises another question: at line 131, after multiplying the input x by LoRA matrix A, you apply the copy_to_model_parallel_region function to the result. That function performs an all_reduce collective on the gradient during the backward pass. But if the LoRA matrix A on every model parallel process is different, i.e. the product of x and A differs across processes, can we still safely use all_reduce in the backward pass?
    lora_outputs.append((copy_to_model_parallel_region(x @ mA.T) @ mB.T) * self.scaling)
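
A minimal sketch of the partitioning the question describes (illustrative shapes only, not the sat implementation; the dimensions and rank count are made up): a column-parallel linear splits W along the output dimension, and building LoRA parameters from each shard, as the quoted line does, sizes A from the shared input dimension and B from the shard's output slice.

import torch
import torch.nn as nn

# Illustrative only: a full weight W of shape (out_features, in_features),
# split column-parallel style along the output dimension over mp_size ranks.
in_features, out_features, mp_size, r = 1024, 3072, 4, 8
W = torch.empty(out_features, in_features)
W_shards = torch.chunk(W, mp_size, dim=0)                   # one shard per model parallel rank

for rank, W_i in enumerate(W_shards):
    # Mirroring the quoted construction: A is sized from the input dimension
    # (shared by every shard), B from the shard's own output dimension.
    matrix_A = nn.Parameter(torch.empty(r, W_i.shape[1]))   # (r, in_features)
    matrix_B = nn.Parameter(torch.empty(W_i.shape[0], r))   # (out_features // mp_size, r)
    print(rank, tuple(W_i.shape), tuple(matrix_A.shape), tuple(matrix_B.shape))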

    Looking forward to your reply!

Your thought is absolutely right. But the multiple A doesn't mean multiple copies of A across devices. It's just "A for query", "A for key" and "A for value", since our implementation of query_key_value is a single linear. I think the A of q, k, v should be independent, although most implementations don't take this into account (even peft).
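
A minimal sketch of what "multiple A" means here (illustrative only, with made-up dimensions; this is not the sat code): the fused query_key_value projection is a single linear, and each of its q/k/v partitions gets its own independent low-rank pair.

import torch
import torch.nn as nn

# Illustrative only: a fused qkv projection for hidden size 1024.
hidden, r, partition = 1024, 8, 3
qkv = nn.Linear(hidden, 3 * hidden, bias=False)

# One independent (A, B) pair per partition of the fused output (q, k, v),
# which is what the loop `for _ in range(partition)` corresponds to.
matrix_A = nn.ParameterList([nn.Parameter(torch.zeros(r, hidden)) for _ in range(partition)])
matrix_B = nn.ParameterList([nn.Parameter(torch.zeros(3 * hidden // partition, r)) for _ in range(partition)])

x = torch.randn(2, hidden)
base = qkv(x)                                                          # (2, 3 * hidden)
lora = torch.cat([(x @ A.T) @ B.T for A, B in zip(matrix_A, matrix_B)], dim=-1)
out = base + lora                                                      # each of q, k, v gets its own adapter
print(out.shape)                                                       # torch.Size([2, 3072])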

Yeah, you're right! So in the for loop (for _ in range(partition)) at line 101, the variable partition means q, or q,k,v, or k,v:

self.matrix_A = HackParameterList([nn.Parameter(torch.empty((r, original_obj.weight.shape[1]))) for _ in range(partition)])

If I want to use LoRA with model parallel training in your sat lib, maybe I need to modify the LoraLinear class (or even more), in order to ensure that all the model parallel processes (which split one whole linear layer into several ColumnParallelLinear or RowParallelLinear layers) share only one LoRA A & B matrix per q/k/v.
Otherwise, I could partition the LoRA A & B matrices across the model parallel processes instead, but that scheme has a potential problem: the LoRA hyperparameter r may not be evenly divisible by, or may even be smaller than, the number of processes in a model parallel group.

No, you don't need to modify anything. This LoraMixin supports the model parallel setting. Just build your model with model parallelism and then add the mixin:

model = YourModule() # This should be an SAT model
model.add_mixin("lora", LoraMixin(xxx))

But after reading your code carefully again, I still have a question.
It mainly concerns the following point. I understand that when the number of model parallel processes is greater than 1 (say 4), each model parallel process holds its own ColumnParallelLinear instance of the q/k/v layer, or RowParallelLinear instance of the dense layer, and each of those instances holds its own slice of the parameters. But following your code, each model parallel process also holds its own LoRA matrices A & B for the q/k/v/dense layer, and the parameters of those A & B matrices also differ from process to process. That means that if I view the ColumnParallelLinear/RowParallelLinear instances on all model parallel processes as one whole linear layer (which they originally are; the code simply splits it across processes), this whole linear layer ends up with n LoRA matrices A & B, where n is the number of model parallel processes; all n copies have the same shape but different parameter values. This conflicts with my original thought: although the original weight matrix W is partitioned across the model parallel process group, there should be only one LoRA matrix A for it. Below is my chain of thought, with references to your code:

  1. model.add_mixin("lora", LoraMixin(xxx)) will execute replace_linear_with_lora in your lib:
    parent_model.transformer.layers[i].cross_attention.dense = replace_linear_with_lora(parent_model.transformer.layers[i].cross_attention.dense, 1, self.r, self.lora_alpha, self.lora_dropout, qlora=self.qlora, in_size=parent_model.transformer.layers[i].cross_attention.inner_hidden_size, out_size=parent_model.transformer.hidden_size)
  2. replace_linear_with_lora will instantiate the LoraLinear class, whose __init__ creates a ParameterList of LoRA matrices A or B (where partition in the for loop corresponds to q/kv/qkv/dense):
    self.matrix_A = HackParameterList([nn.Parameter(torch.empty((r, original_obj.weight.shape[1]))) for _ in range(partition)])
  3. In the line above, original_obj may be an instance of ColumnParallelLinear/RowParallelLinear. Taking ColumnParallelLinear as the example, the shape of its weight is
    (out_channel / num_model_parallel_processes, in_channel),
    so the shape of the LoRA matrix A of q/k/v on each model parallel process is
    (lora_r, in_channel),
    and the shape of the LoRA matrix B of q/k/v on each model parallel process is
    (out_channel / num_model_parallel_processes, lora_r).
    This implies that each model parallel process has its own LoRA matrix A of q/k/v with shape (lora_r, in_channel) and its own matrix B of q/k/v with shape (out_channel / num_model_parallel_processes, lora_r); a concrete shape check follows this list.
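
As a quick sanity check of the shapes derived above (with purely illustrative numbers: in_channel = 4096, out_channel = 4096, 4 model parallel processes, lora_r = 8):

import torch

# Illustrative numbers only, not taken from any real config.
in_channel, out_channel, mp_size, lora_r = 4096, 4096, 4, 8

# Weight shard held by one ColumnParallelLinear instance on one process:
weight_shard = torch.empty(out_channel // mp_size, in_channel)    # (1024, 4096)

# LoRA matrices built from that shard, following the derivation above:
matrix_A = torch.empty(lora_r, weight_shard.shape[1])             # (8, 4096) on every process
matrix_B = torch.empty(weight_shard.shape[0], lora_r)             # (1024, 8), one slice per process

print(weight_shard.shape, matrix_A.shape, matrix_B.shape)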

Those are all my questions and related thoughts. I still don't know whether what I said is right or wrong, and if it is wrong, where the flaw in my reasoning is...
In fact I like the SwissArmyTransformer source code very much; I learned a lot about the model parallelism of self-attention from it, but I just cannot figure this question out.
I would greatly appreciate a reply~~~

Good question. This is why this line of code contains a copy_to_model_parallel_region:

lora_outputs.append((copy_to_model_parallel_region(x @ mA.T) @ mB.T) * self.scaling)

For each process, B differs because it should differ: you can see it as split across devices, as indicated by B.model_parallel = True (which deepspeed handles automatically):

self.matrix_B[i].model_parallel = True

For A, it is the same across devices because of copy_to_model_parallel_region: the gradient is all_reduced across the model parallel devices before being passed back to A.
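
For reference, a minimal sketch of the usual semantics of such an op (following the Megatron-style copy-to-model-parallel-region pattern; this is a simplified sketch, not the sat source, and it assumes an already-initialized model parallel process group):

import torch
import torch.distributed as dist

class CopyToModelParallelRegionSketch(torch.autograd.Function):
    # Identity in the forward pass, gradient all_reduce (sum) in the backward pass.

    @staticmethod
    def forward(ctx, x):
        # Every model parallel rank already computes the same x @ A.T,
        # so nothing needs to be communicated in the forward direction.
        return x

    @staticmethod
    def backward(ctx, grad_output):
        # Sum the gradients coming from each rank's own B shard, so the
        # replicated A receives one identical, aggregated gradient everywhere.
        # (The group argument for the model parallel group is omitted here.)
        dist.all_reduce(grad_output)
        return grad_output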

Since the initialization of A is the same on every device (which SAT handles as shown below) and the gradient is the same at every step, A stays identical during training.

print_rank0('Syncing initialized parameters...')
for param_group in param_groups:
    for param in param_group['params']:
        if not param.model_parallel:
            # We already keep the same random seed for different ranks.
            # However, it is not reliable: non-model-parallel parameters could differ at initialization.
            dist.broadcast(
                param.data,
                src=0,  # group is the default group
            )
        else:
            dist.broadcast(
                param.data,
                src=mpu.get_model_parallel_rank(),   # 0 -- mp_size-1
                group=mpu.get_data_parallel_group()  # 1, mp_size + 1, ...
            )
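
Put together, keeping A identical on every rank and splitting B row-wise reproduces exactly the LoRA update of the unpartitioned layer. A quick single-process check with made-up shapes (illustrative only):

import torch

torch.manual_seed(0)
in_dim, out_dim, mp_size, r = 64, 128, 4, 8
x = torch.randn(2, in_dim)
A = torch.randn(r, in_dim)          # replicated: identical on every rank
B_full = torch.randn(out_dim, r)    # conceptually the full B of the unsplit layer

# Each "rank" holds one row slice of B (matching the column-parallel output slicing).
B_shards = torch.chunk(B_full, mp_size, dim=0)
per_rank = [(x @ A.T) @ B_i.T for B_i in B_shards]   # what each rank computes locally
assert torch.allclose(torch.cat(per_rank, dim=-1), (x @ A.T) @ B_full.T)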

Thanks a lot!!! I've understood it thoroughly.
In fact I had also wondered whether the LoRA matrix A on different model parallel processes has the same initialized weights, but I couldn't find the code that makes all weights whose model_parallel attribute is False equal across model parallel processes. In your source code you use torch.distributed.broadcast to do this across all processes, i.e. both data and model parallel processes (the group argument of broadcast is None, meaning the default process group, i.e. all processes).

Taking LoRA matrix A as an example: it sees the same input x within a model parallel process group but different inputs across data parallel processes, while the value of A itself is identical on all processes (both model and data parallel). By the chain rule, the parameters of A should accumulate gradients from all parallel processes (both data and model parallel), so copy_to_model_parallel_region is used to reduce (sum up) the gradients of the processes in a model parallel group, while the summing and averaging of gradients across data parallel processes is handled automatically inside deepspeed.
After discussing this with you, I've learned a lot more. Thanks very much, haha.

Cool.