epfml / powersgd

Practical low-rank gradient compression for distributed optimization: https://arxiv.org/abs/1905.13727

PowerSGD is similarly efficient to torch.svd_lowrank

HaoKang-Timmy opened this issue

Here is a function implementing the power iteration.



import torch
import time

def poweriter(input, p_buffer, q_buffer, iter):
    # Low-rank approximation of a batch of matrices by power iteration;
    # the rank is the last dimension of the buffers. QR orthogonalization
    # is applied only in the final iteration.
    for i in range(iter):
        if i == iter - 1:
            p_buffer[0] = torch.linalg.qr(p_buffer[0]).Q
        q_buffer[0] = input @ p_buffer[0]
        if i == iter - 1:
            q_buffer[0] = torch.linalg.qr(q_buffer[0]).Q
        p_buffer[0] = input.permute((0, 1, 3, 2)) @ q_buffer[0]
    return q_buffer[0] @ p_buffer[0].permute((0, 1, 3, 2))


input = torch.rand([64, 32, 112, 112]).requires_grad_()
p_buffer = torch.rand([64, 32, 112, 3])
q_buffer = torch.rand([64, 32, 112, 3])
for i in range(10):  # warm-up runs
    output = poweriter(input, [p_buffer], [q_buffer], 1)
start = time.time()
output = poweriter(input, [p_buffer], [q_buffer], 2)
end = time.time()
print("powersvd_time:", end - start, "error:", torch.abs(output - input).mean())

input = input.view(64, 32, -1)  # flatten to a batch of 32x12544 matrices
start = time.time()
U, S, V = torch.svd_lowrank(input, q=3)
output = U @ torch.diag_embed(S) @ V.transpose(-1, -2)  # rank-3 reconstruction
end = time.time()
print("svdlow_time:", end - start, "error:", torch.abs(output - input).mean())

I only orthogonalize in the final iteration and run only 2 iterations, yet power iteration is not much faster than svd_lowrank, and it also incurs slightly more error (without error feedback).
The result is:

powersvd_time: 0.1424109935760498 error: tensor(0.2390)
svdlow_time: 0.16568613052368164 error: tensor(0.2343)
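(As an aside, a single time.time() interval is fairly noisy; a more stable sketch using the standard timeit module, reusing the 4D input and the buffers from before the reshape, would be:)

import timeit

# Take the best of 5 runs instead of trusting one interval.
best = min(timeit.repeat(
    lambda: poweriter(input, [p_buffer], [q_buffer], 2),
    repeat=5, number=1))
print("powersvd_time (best of 5):", best)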

May I suppose that I could actually use svd_lowrank during training, instead of power iteration, and get results similar to the paper? I have tried using svd_lowrank to compress gradients, and in the same setting it achieves better accuracy at a somewhat higher (but relatively small) time cost.
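For reference, the error feedback I mean is the mechanism from the PowerSGD paper: each worker keeps the compression residual locally and adds it back before the next compression. A minimal sketch, where compress/decompress and the other helpers are placeholders rather than real functions:

memory = 0                                    # per-worker residual, kept across steps
for step in range(num_steps):                 # num_steps: placeholder
    grad = compute_gradient()                 # placeholder for the backward pass
    corrected = grad + memory                 # re-inject what compression lost last step
    approx = decompress(compress(corrected))  # placeholder rank-r round trip
    memory = corrected - approx               # store the new residual locally
    apply_update(approx)                      # placeholder for the optimizer step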

And here is the 3D version of the power iteration:



import torch
import time

def poweriter(input, p_buffer, q_buffer, iter):
    # Same 4D power iteration as above: QR only in the final iteration.
    for i in range(iter):
        if i == iter - 1:
            p_buffer[0] = torch.linalg.qr(p_buffer[0]).Q
        q_buffer[0] = input @ p_buffer[0]
        if i == iter - 1:
            q_buffer[0] = torch.linalg.qr(q_buffer[0]).Q
        p_buffer[0] = input.permute((0, 1, 3, 2)) @ q_buffer[0]
    return q_buffer[0] @ p_buffer[0].permute((0, 1, 3, 2))

def poweriter3d(input, p_buffer, q_buffer, iter):
    # Flatten [B, C, H, W] to [B, C, H*W] and run the same iteration
    # on a batch of C x (H*W) matrices; restore the shape at the end.
    shape = input.shape
    input = input.view(int(input.shape[0]), int(input.shape[1]), -1)
    for i in range(iter):
        if i == iter - 1:
            p_buffer[0] = torch.linalg.qr(p_buffer[0]).Q
        q_buffer[0] = input @ p_buffer[0]
        if i == iter - 1:
            q_buffer[0] = torch.linalg.qr(q_buffer[0]).Q
        p_buffer[0] = input.permute((0, 2, 1)) @ q_buffer[0]
    return (q_buffer[0] @ p_buffer[0].permute((0, 2, 1))).view(shape)

input = torch.rand([64, 32, 112, 112])
p_buffer = torch.rand([64, 32, 112, 3])
q_buffer = torch.rand([64, 32, 112, 3])
for i in range(10):  # warm-up runs
    output = poweriter(input, [p_buffer], [q_buffer], 1)
start = time.time()
output = poweriter(input, [p_buffer], [q_buffer], 2)
end = time.time()
print("powersvd_time:", end - start, "error:", torch.abs(output - input).mean())

input = input.view(64, 32, -1)  # flatten to a batch of 32x12544 matrices
start = time.time()
U, S, V = torch.svd_lowrank(input, q=3)
output = U @ torch.diag_embed(S) @ V.transpose(-1, -2)  # rank-3 reconstruction
end = time.time()
print("svdlow_time:", end - start, "error:", torch.abs(output - input).mean())

input = torch.rand([64, 32, 112, 112])
p_buffer = torch.rand([64, 12544, 3])  # 12544 = 112 * 112
q_buffer = torch.rand([64, 32, 3])
for i in range(10):  # warm-up runs
    output = poweriter3d(input, [p_buffer], [q_buffer], 1)
start = time.time()
output = poweriter3d(input, [p_buffer], [q_buffer], 2)
end = time.time()
print("powersvd_time3d:", end - start, "error:", torch.abs(output - input).mean())

And the results are:

powersvd_time: 0.1419539451599121 error: tensor(0.2389)
svdlow_time: 0.1628878116607666 error: tensor(0.2342)
powersvd_time3d: 0.1390700340270996 error: tensor(0.2343)
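One thing I notice about the 3D variant (my own arithmetic from the buffer shapes above): flattening [64, 32, 112, 112] into 32x12544 matrices puts the rank-3 factor on the long 12544 side, so at the same rank it actually stores and sends more numbers than the 4D version.

# Factor sizes in floats, rank 3, with the buffer shapes used above.
p4 = q4 = 64 * 32 * 112 * 3   # 4D variant: 2048 matrices of 112x112
p3 = 64 * 12544 * 3           # 3D variant: 64 matrices of 32x12544
q3 = 64 * 32 * 3
print(p4 + q4, p3 + q3)       # 1376256 vs 2414592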

Power iteration does not seem much more efficient than svd_lowrank.

Thank you for sharing these results. You cannot use SVD as a drop-in replacement for power iteration in PowerSGD.
In PowerSGD, we run power iteration on the average gradient. If you use SVD, you can only run it on the local gradients, and then average those among workers using all-to-all communication.
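The reason this works is that the power-iteration updates are linear in the gradient: mean_i(M_i) @ P equals mean_i(M_i @ P), so each matrix multiplication can be computed on the local gradient and the results averaged with a cheap all-reduce. SVD has no such linearity. A single-process sketch of the identity:

import torch

# Averaging local products equals multiplying the averaged gradient:
# the matmul commutes with the all-reduce mean; SVD does not.
grads = [torch.randn(112, 112) for _ in range(4)]  # per-worker gradients
P = torch.randn(112, 3)                            # shared rank-3 basis
avg_of_local = torch.stack([g @ P for g in grads]).mean(0)  # "all-reduce"
on_the_average = torch.stack(grads).mean(0) @ P
print(torch.allclose(avg_of_local, on_the_average, atol=1e-5))  # True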

Yes, that's right. Thanks!