PowerSGD is similarly efficient to torch.svd_lowrank
HaoKang-Timmy opened this issue · comments
Here is a function I wrote for power iteration:
import torch
import time

def poweriter(input, p_buffer, q_buffer, iter):
    for i in range(iter):
        # Orthogonalize only on the final iteration.
        if i == iter - 1:
            p_buffer[0] = torch.linalg.qr(p_buffer[0]).Q
        q_buffer[0] = input @ p_buffer[0]
        if i == iter - 1:
            q_buffer[0] = torch.linalg.qr(q_buffer[0]).Q
        p_buffer[0] = input.permute((0, 1, 3, 2)) @ q_buffer[0]
    return q_buffer[0] @ p_buffer[0].permute((0, 1, 3, 2))

input = torch.rand([64, 32, 112, 112]).requires_grad_()
p_buffer = torch.rand([64, 32, 112, 3])
q_buffer = torch.rand([64, 32, 112, 3])

# Warm-up runs before timing.
for i in range(10):
    output = poweriter(input, [p_buffer], [q_buffer], 1)

start = time.time()
output = poweriter(input, [p_buffer], [q_buffer], 2)
end = time.time()
print("powersvd_time:", end - start, "error:", torch.abs(output - input).mean())

input = input.view(64, 32, -1)
start = time.time()
U, S, V = torch.svd_lowrank(input, q=3)
S = torch.diag_embed(S)
V = V.transpose(-1, -2)
output = torch.matmul(U, S)
output = torch.matmul(output, V)
end = time.time()
print("svdlow_time:", end - start, "error:", torch.abs(output - input).mean())
I only orthogonalize on the last iteration and run just 2 iterations, yet power iteration is not much faster than svd_lowrank, and it also has a larger error (without error feedback).
The result is
powersvd_time: 0.1424109935760498 error: tensor(0.2390)
svdlow_time: 0.16568613052368164 error: tensor(0.2343)
May I assume that I could actually use svd_lowrank during training instead of power iteration and still get results similar to the paper?
I have tried using svd_lowrank to compress gradients, and in the same setting it gives better accuracy at a somewhat higher (but relatively small) cost.
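For context, the "feedback" mentioned above refers to PowerSGD-style error feedback: the residual left over by each compression step is added back into the next gradient before compressing. A minimal sketch of that idea, using svd_lowrank as the compressor (shapes and function names here are illustrative, not from the paper):

```python
import torch

def compress_lowrank(m, q=3):
    # Rank-q approximation via randomized SVD.
    U, S, V = torch.svd_lowrank(m, q=q)
    return U @ torch.diag_embed(S) @ V.transpose(-1, -2)

def step_with_feedback(grad, error):
    # Error feedback: fold the previous step's residual into the
    # gradient before compressing, then remember the new residual.
    corrected = grad + error
    compressed = compress_lowrank(corrected)
    return compressed, corrected - compressed

grad = torch.rand(4, 32, 128)   # small batch of matrices for illustration
error = torch.zeros_like(grad)
compressed, error = step_with_feedback(grad, error)
```

Because the residual is carried forward, the compression error does not accumulate across steps, which is why PowerSGD tolerates a fairly crude compressor.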
And here is the 3D version of the power iteration:
import torch
import time

def poweriter(input, p_buffer, q_buffer, iter):
    for i in range(iter):
        # Orthogonalize only on the final iteration.
        if i == iter - 1:
            p_buffer[0] = torch.linalg.qr(p_buffer[0]).Q
        q_buffer[0] = input @ p_buffer[0]
        if i == iter - 1:
            q_buffer[0] = torch.linalg.qr(q_buffer[0]).Q
        p_buffer[0] = input.permute((0, 1, 3, 2)) @ q_buffer[0]
    return q_buffer[0] @ p_buffer[0].permute((0, 1, 3, 2))

def poweriter3d(input, p_buffer, q_buffer, iter):
    # Flatten the trailing dimensions so each sample is a single matrix.
    shape = input.shape
    input = input.view(int(input.shape[0]), int(input.shape[1]), -1)
    for i in range(iter):
        if i == iter - 1:
            p_buffer[0] = torch.linalg.qr(p_buffer[0]).Q
        q_buffer[0] = input @ p_buffer[0]
        if i == iter - 1:
            q_buffer[0] = torch.linalg.qr(q_buffer[0]).Q
        p_buffer[0] = input.permute((0, 2, 1)) @ q_buffer[0]
    return (q_buffer[0] @ p_buffer[0].permute((0, 2, 1))).view(shape)

input = torch.rand([64, 32, 112, 112])
p_buffer = torch.rand([64, 32, 112, 3])
q_buffer = torch.rand([64, 32, 112, 3])

# Warm-up runs before timing.
for i in range(10):
    output = poweriter(input, [p_buffer], [q_buffer], 1)

start = time.time()
output = poweriter(input, [p_buffer], [q_buffer], 2)
end = time.time()
print("powersvd_time:", end - start, "error:", torch.abs(output - input).mean())

input = input.view(64, 32, -1)
start = time.time()
U, S, V = torch.svd_lowrank(input, q=3)
S = torch.diag_embed(S)
V = V.transpose(-1, -2)
output = torch.matmul(U, S)
output = torch.matmul(output, V)
end = time.time()
print("svdlow_time:", end - start, "error:", torch.abs(output - input).mean())

input = torch.rand([64, 32, 112, 112])
p_buffer = torch.rand([64, 12544, 3])
q_buffer = torch.rand([64, 32, 3])

for i in range(10):
    output = poweriter3d(input, [p_buffer], [q_buffer], 1)

start = time.time()
output = poweriter3d(input, [p_buffer], [q_buffer], 2)
end = time.time()
print("powersvd_time3d:", end - start, "error:", torch.abs(output - input).mean())
And the result is
powersvd_time: 0.1419539451599121 error: tensor(0.2389)
svdlow_time: 0.1628878116607666 error: tensor(0.2342)
powersvd_time3d: 0.1390700340270996 error: tensor(0.2343)
So it seems power iteration is not much more efficient than svd_lowrank.
Thank you for sharing these results. You cannot use SVD as a drop-in replacement for power iteration in PowerSGD.
In PowerSGD, we run power iteration on the average gradient. If you use SVD, you can only run it on the local gradients, and then average those among workers using all-to-all communication.
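The reason the power-iteration step composes with an all-reduce is linearity: `M @ P` is linear in `M`, so averaging each worker's `Q_i = M_i @ P` gives the same result as running the step on the averaged gradient. A small sketch with two simulated workers (no real distributed setup; names are illustrative):

```python
import torch

torch.manual_seed(0)

# Two simulated workers with different local gradients.
grads = [torch.rand(32, 64) for _ in range(2)]
p = torch.rand(64, 3)  # shared low-rank factor, identical on every worker

# Each worker computes its local Q_i = M_i @ P; an all-reduce averages them.
q_avg = sum(g @ p for g in grads) / len(grads)

# By linearity, this equals one power-iteration step applied to the
# averaged gradient, which is exactly what PowerSGD needs.
q_ref = (sum(grads) / len(grads)) @ p

# No analogous identity holds for per-worker SVD factors: the SVD of the
# average is not the average of the SVDs, hence the extra communication.
```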
Yes, that's right. Thanks!