[Question] Async Tensor Parallel
woshiyyya opened this issue · comments
I am a bit curious about the example below from 3d_parallelism.md:
Q: Can you give a concrete example illustrating how asynchronous tensor parallelism works? (6 steps)
A:
Step 1: Let's look at an example with 4 GPU ranks:
Input X sharded across ranks as [X0, X1, X2, X3]
Weight matrix W sharded as [W0, W1, W2, W3]
Step 2: Rank 2 kicks off async all-gather to get [X0, X1, X2, X3]
Step 3: While gathering, rank 2 computes: local_output = X2 * W2
Step 4: All-gather completes, rank 2 has [X0, X1, X2, X3]
Step 5: Rank 2 computes: before_local_output = X0 * W0 + X1 * W1, after_local_output = X3 * W3
Step 6: Rank 2's output = before_local_output + local_output + after_local_output
So each rank computes the full output using the locally gathered X and its shard of W.
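Taken at face value, the partial products in the steps above sum to the full output when X is sharded along the feature dimension and W along its rows. A minimal single-process NumPy sketch of that reading (the all-gather is simulated by plain array access, which is my simplification):

```python
import numpy as np

# Full problem: X is [B, M], W is [M, N]; TP world size 4.
B, M, N, tp = 2, 8, 4, 4
rng = np.random.default_rng(0)
X = rng.standard_normal((B, M))
W = rng.standard_normal((M, N))

# Feature-dim sharding: rank i holds X_i = [B, M/4] and W_i = [M/4, N],
# so each partial product X_i @ W_i is already [B, N].
X_shards = np.split(X, tp, axis=1)
W_shards = np.split(W, tp, axis=0)

rank = 2
# Step 3: compute the local partial product while the (simulated)
# all-gather of the other X shards is in flight.
local_output = X_shards[rank] @ W_shards[rank]
# Steps 4-5: after the gather completes, compute the remaining terms.
# NOTE: in this reading rank 2 would need every W shard locally.
before_local_output = sum(X_shards[i] @ W_shards[i] for i in range(rank))
after_local_output = sum(X_shards[i] @ W_shards[i] for i in range(rank + 1, tp))
# Step 6: the partial products sum to the full X @ W.
output = before_local_output + local_output + after_local_output
assert np.allclose(output, X @ W)
```

This is only a numerical sanity check of the arithmetic in the six steps, not the actual implementation.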
Since W is also sharded across different ranks (rank i holds Wi), in which step do we gather the sharded weights?
Just want to confirm if my understanding is correct. Suppose TP=4, and you have an input of shape [B x M] and a weight of shape [M x N]. Then would the shape of the sharded input Xi be [B, M/4] and Wi be [M/4, N]? Is it row-based tensor parallel?
@woshiyyya Hi. Sorry for the delayed response :)
> Is it row-based tensor parallel?

Yes. Currently we only use async communication in row-based [link]
> Then the shape of sharded input Xi be [B, M/4] and Wi be [M/4, N]?

Nope. The sharded input Xi is [B/4, M], and Wi is [M, N/4].
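For concreteness, a single-process NumPy sketch of the shapes stated here (the all-gather is simulated by concatenation, and treating the B/4 dimension as a batch/sequence shard is my assumption): each rank then produces a [B, N/4] column slice of the full product.

```python
import numpy as np

B, M, N, tp = 8, 6, 4, 4
rng = np.random.default_rng(1)
X = rng.standard_normal((B, M))
W = rng.standard_normal((M, N))

# Stated shapes: X_i = [B/4, M] (batch/sequence sharded),
# W_i = [M, N/4] (output-feature sharded).
X_shards = np.split(X, tp, axis=0)
W_shards = np.split(W, tp, axis=1)

rank = 2
# Simulated all-gather of the input shards, then a single matmul
# against this rank's weight shard:
X_full = np.concatenate(X_shards, axis=0)  # [B, M]
out_rank = X_full @ W_shards[rank]         # [B, N/4]

# Each rank holds one column slice of the full X @ W.
cols = slice(rank * (N // tp), (rank + 1) * (N // tp))
assert np.allclose(out_rank, (X @ W)[:, cols])
```

This is a sketch of one reading of the stated shapes, not the library's actual code path.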
Thanks for the reply~ I am still confused, since from the code I saw the row-based linear layer should have weights of shape [M/4, N]:
```python
class TensorParallelRowLinear(nn.Linear):
    ...
    self.in_features = in_features // self.world_size
    self.out_features = out_features
```
And if rank 2's output is X0 * W0 + X1 * W1 + X2 * W2 + X3 * W3, that will be of shape [B/4, N/4], which is only 1/16 of the output?
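The shape arithmetic in this comment can be checked directly. A quick NumPy verification using the shapes from the earlier reply (Xi = [B/4, M], Wi = [M, N/4]):

```python
import numpy as np

B, M, N, tp = 8, 6, 4, 4
X = np.ones((B, M))
W = np.ones((M, N))

# One rank's shards under the earlier reply's shapes:
Xi = np.split(X, tp, axis=0)[2]  # [B/4, M]
Wi = np.split(W, tp, axis=1)[2]  # [M, N/4]

# A single X_i @ W_i term is [B/4, N/4]: 1/16 of the [B, N] output,
# which is exactly the mismatch this comment points out.
partial = Xi @ Wi
assert partial.shape == (B // tp, N // tp)
```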