rusty1s / pytorch_scatter

PyTorch Extension Library of Optimized Scatter Operations

Home Page: https://pytorch-scatter.readthedocs.io


segment_csr is slower than segment_coo on GPU in some cases

xiqi98 opened this issue · comments

A minimal example:

import torch
from torch_scatter import segment_coo, segment_csr

tic = torch.cuda.Event(enable_timing=True)
toc = torch.cuda.Event(enable_timing=True)

size = 1000000
reduced_size = 2

src = torch.randn(size, dtype=torch.float64).cuda()
index = torch.randint(low=0, high=reduced_size, size=(size,), dtype=torch.int64).cuda()
index, _ = index.sort()  # segment_coo expects a sorted index
# Build the CSR pointer from the sorted index: bincount(index + 1) prepends a
# zero bin, so the cumulative sum yields boundaries [0, ..., size].
indptr = torch.bincount(index + 1)
indptr = indptr.cumsum(-1)

print("index:", index)
print("indptr:", indptr)

tic.record()
res1 = segment_coo(src, index)
toc.record()
torch.cuda.synchronize()
runtime = 1e-3 * tic.elapsed_time(toc)
print(f"segment_coo runtime: {runtime} seconds.")

tic.record()
res2 = segment_csr(src, indptr)
toc.record()
torch.cuda.synchronize()
runtime = 1e-3 * tic.elapsed_time(toc)
print(f"segment_csr runtime: {runtime} seconds.")

print("diff:", (res1 - res2).abs().max())

'''
Output:
index: tensor([0, 0, 0,  ..., 1, 1, 1], device='cuda:0')
indptr: tensor([      0,  499091, 1000000], device='cuda:0')
segment_coo runtime: 0.01930431938171387 seconds.
segment_csr runtime: 0.027863040924072265 seconds.
diff: tensor(3.4049e-11, device='cuda:0', dtype=torch.float64)
'''
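One caveat with measurements like the above: the first call to each kernel also pays one-time launch/initialization overhead. A small warm-up-then-average helper is a common pattern; the sketch below is illustrative (the `time_gpu` name and its parameters are my own, not part of torch_scatter), and it simply returns `None` when no GPU is present.

```python
import torch

def time_gpu(fn, warmup=3, iters=10):
    """Sketch: average runtime of a zero-argument callable `fn` using CUDA
    events, after a few warm-up calls to exclude one-time launch overhead."""
    if not torch.cuda.is_available():
        return None  # CUDA-event timing requires a GPU
    for _ in range(warmup):
        fn()  # warm-up runs are not timed
    tic = torch.cuda.Event(enable_timing=True)
    toc = torch.cuda.Event(enable_timing=True)
    tic.record()
    for _ in range(iters):
        fn()
    toc.record()
    torch.cuda.synchronize()  # wait for all queued work before reading events
    return 1e-3 * tic.elapsed_time(toc) / iters  # seconds per call
```

With this helper, the two kernels could be compared via e.g. `time_gpu(lambda: segment_csr(src, indptr))`, which removes the first-call overhead from the comparison.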

The problem is more severe if reduced_size is set to 1, i.e. with output size 1:

'''
Output with reduced_size=1:
index: tensor([0, 0, 0,  ..., 0, 0, 0], device='cuda:0')
indptr: tensor([      0, 1000000], device='cuda:0')
segment_coo runtime: 0.01807222366333008 seconds.
segment_csr runtime: 0.054765567779541016 seconds.
diff: tensor(4.9113e-11, device='cuda:0', dtype=torch.float64)
'''

However, the gap is much less pronounced when src.dtype is torch.float32.
Is there any specific reason for this behavior?

segment_csr currently parallelizes over the output dimension, while segment_coo parallelizes over the input dimension. I think it is expected that segment_csr is slower in case you have fewer but larger groups.
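Given that explanation, one workaround for the few-large-groups case is to convert the CSR pointer back to a COO index and use the input-parallel path. The sketch below is my own illustration, not library code: it derives the per-segment lengths from `indptr`, expands them into a sorted index with `torch.repeat_interleave`, and uses `index_add_` as a CPU stand-in with the same sum semantics as `segment_coo(src, index)`.

```python
import torch

# Illustrative CSR -> COO conversion (segment boundaries taken from the
# example output above; on GPU one would then call segment_coo instead).
indptr = torch.tensor([0, 499091, 1000000])   # CSR segment boundaries
counts = indptr[1:] - indptr[:-1]             # per-segment lengths
index = torch.repeat_interleave(
    torch.arange(counts.numel()), counts)     # sorted COO index: 0,...,0,1,...,1

src = torch.randn(int(indptr[-1]), dtype=torch.float64)
# Stand-in for segment_coo(src, index): scatter-add into one slot per segment.
out = torch.zeros(counts.numel(), dtype=torch.float64).index_add_(0, index, src)
```

Whether the conversion pays off depends on the shape of the workload; the point is only that the COO kernel parallelizes over the one million input elements rather than over the two output slots.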

Ok, thanks for the clarification! :smile: