segment_csr is slower than segment_coo on GPU in some cases
xiqi98 opened this issue
Zhengxi Zhang commented
A minimal example:
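import torch
from torch_scatter import segment_coo, segment_csr
# CUDA events used to time the GPU work
tic = torch.cuda.Event(enable_timing=True)
toc = torch.cuda.Event(enable_timing=True)
size = 1000000  # number of input elements
reduced_size = 2  # exclusive upper bound for segment ids, i.e. number of output segments
src = torch.randn(size, dtype=torch.float64).cuda()
# Sorted segment ids (COO format), one per input element
index = torch.randint(low=0, high=reduced_size, size=(size,), dtype=torch.int64).cuda()
index, _ = index.sort()
# Segment boundaries (CSR format): cumulative segment counts with a leading 0
indptr = torch.bincount(index + 1)
indptr = indptr.cumsum(-1)
print("index:", index)
print("indptr:", indptr)
# Time segment_coo (single run, no warm-up iteration)
tic.record()
res1 = segment_coo(src, index)
toc.record()
torch.cuda.synchronize()
# elapsed_time returns milliseconds; convert to seconds
runtime = 1e-3 * tic.elapsed_time(toc)
print(f"segment_coo runtime: {runtime} seconds.")
# Time segment_csr the same way
tic.record()
res2 = segment_csr(src, indptr)
toc.record()
torch.cuda.synchronize()
runtime = 1e-3 * tic.elapsed_time(toc)
print(f"segment_csr runtime: {runtime} seconds.")
# Both reductions should agree up to floating-point rounding
print("diff:", (res1 - res2).abs().max())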
'''
Output:
index: tensor([0, 0, 0, ..., 1, 1, 1], device='cuda:0')
indptr: tensor([ 0, 499091, 1000000], device='cuda:0')
segment_coo runtime: 0.01930431938171387 seconds.
segment_csr runtime: 0.027863040924072265 seconds.
diff: tensor(3.4049e-11, device='cuda:0', dtype=torch.float64)
'''
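For reference, the way the example derives indptr from the sorted index (a bincount of index + 1 followed by a cumulative sum) can be checked on a tiny input without a GPU. This is only a pure-Python sketch of that step; the helper name index_to_indptr is made up for illustration and is not part of torch_scatter:

```python
from itertools import accumulate


def index_to_indptr(index):
    """Convert sorted segment ids (COO) into segment boundaries (CSR)."""
    # Equivalent of bincount(index + 1): counts[k + 1] is the number of
    # entries whose segment id is k, and counts[0] stays 0.
    counts = [0] * (max(index) + 2)
    for seg in index:
        counts[seg + 1] += 1
    # Equivalent of cumsum(-1): running total of the counts.
    return list(accumulate(counts))


example_index = [0, 0, 0, 1, 1, 2]
example_indptr = index_to_indptr(example_index)
print(example_indptr)
```

Segment k then covers the input positions from indptr[k] up to (but not including) indptr[k + 1].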
The problem is more severe if reduced_size is set to 1, i.e. with an output size of 1:
'''
Output with reduced_size=1:
index: tensor([0, 0, 0, ..., 0, 0, 0], device='cuda:0')
indptr: tensor([ 0, 1000000], device='cuda:0')
segment_coo runtime: 0.01807222366333008 seconds.
segment_csr runtime: 0.054765567779541016 seconds.
diff: tensor(4.9113e-11, device='cuda:0', dtype=torch.float64)
'''
However, the problem is less obvious if src.dtype=torch.float32.
Is there any specific reason for this behavior?
Matthias Fey commented
segment_csr currently parallelizes over the output dimension, while segment_coo parallelizes over the input dimension. I think it is expected that segment_csr is slower when you have fewer but larger groups.
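To make the distinction concrete, here is a toy pure-Python model of the two strategies. It is a simplified sketch, not the actual torch_scatter CUDA kernels, and the names csr_style_sum and coo_style_sum are hypothetical. Each "task" stands for a unit of work that could run independently in parallel:

```python
def csr_style_sum(src, indptr):
    """One task per output segment: task i sums src[indptr[i]:indptr[i + 1]]."""
    out = []
    task_sizes = []
    for i in range(len(indptr) - 1):
        start, end = indptr[i], indptr[i + 1]
        out.append(sum(src[start:end]))
        task_sizes.append(end - start)  # sequential work done by task i
    return out, task_sizes


def coo_style_sum(src, index, num_segments):
    """One task per input element: task j adds src[j] into out[index[j]]."""
    out = [0.0] * num_segments
    for j, seg in enumerate(index):
        # Run in parallel, this accumulation would have to be coordinated
        # between tasks (an assumption of this toy model).
        out[seg] += src[j]
    return out, [1] * len(src)


src = [1.0] * 8
index = [0, 0, 0, 0, 0, 1, 1, 1]
indptr = [0, 5, 8]

csr_out, csr_tasks = csr_style_sum(src, indptr)
coo_out, coo_tasks = coo_style_sum(src, index, num_segments=2)
print("csr:", csr_out, "tasks:", csr_tasks)
print("coo:", coo_out, "tasks:", coo_tasks)
```

In this model the CSR-style version has only as many tasks as there are output segments, and each task's work grows with its segment length, whereas the COO-style version always has one small task per input element. With reduced_size set to 1 or 2 and a million inputs, the first approach leaves very little to run in parallel.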
Zhengxi Zhang commented
OK, thanks for the clarification!