Segment operations for matrix multiplication instead of reduction
sidnb13 opened this issue
Sidharth Baskaran commented
I would like to implement an optimized operation that does what segment_coo or segment_csr does, but applies a matrix multiplication instead of one of the available reductions. Here is an example of how I'm currently doing it. It is the fastest approach I have found so far, but I would like to parallelize across the weights instead of looping over them.
```python
import torch

input_dim, output_dim = 4, 8
weights = [torch.randn(input_dim, output_dim) for _ in range(3)]
# assigns each feature row to a weight; -1 marks a padding row whose output
# does not matter, so it is simply left as zero here
indptr = torch.tensor([0, 0, 1, 1, 2, 2, -1, -1])
features = torch.randn(indptr.shape[0], input_dim)
out = torch.zeros(features.shape[0], output_dim)
for i, weight in enumerate(weights):
    mask = indptr == i
    # only multiply the rows assigned to weight i
    out[mask] = features[mask] @ weight
```
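A loop-free variant is possible in plain PyTorch. The sketch below assumes the weights can be stacked into a single (num_weights, input_dim, output_dim) tensor `W` (a name introduced here for illustration). Each row gathers its own weight matrix, and one `torch.bmm` call computes all the products at once:

```python
import torch

input_dim, output_dim, num_weights = 4, 8, 3
# Stack the per-segment weights into one (num_weights, in, out) tensor.
W = torch.randn(num_weights, input_dim, output_dim)
indptr = torch.tensor([0, 0, 1, 1, 2, 2, -1, -1])  # -1 marks padding rows
features = torch.randn(indptr.shape[0], input_dim)

# Padding rows (-1) temporarily borrow weight 0 and are zeroed afterwards.
valid = indptr >= 0
row_weights = W[indptr.clamp(min=0)]  # (N, in, out): one weight copy per row
# (N, 1, in) @ (N, in, out) -> (N, 1, out) -> (N, out)
out = torch.bmm(features.unsqueeze(1), row_weights).squeeze(1)
out = out.masked_fill(~valid.unsqueeze(1), 0.0)  # zero the padding rows
```

The trade-off is memory: `W[...]` materializes one input_dim x output_dim matrix per row, so this only pays off while N * input_dim * output_dim fits comfortably in memory.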
I would appreciate any guidance on how to implement a custom operation/kernel to do the above.
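Short of writing a custom CUDA kernel, one sketch that avoids both the Python loop and per-row weight copies is a padded grouped matmul. It sorts the rows by weight id, scatters each group into a zero-padded (num_weights, max_group_size, input_dim) buffer, runs one `torch.bmm` against the stacked weights, and gathers the results back. The function name `grouped_matmul_padded` is hypothetical. The exclusive prefix sum `starts` plays the same role as the CSR pointer used by `segment_csr`:

```python
import torch


def grouped_matmul_padded(features, index, weights):
    """Multiply each row of `features` by weights[index[row]] with a single bmm.

    index:   (N,) long tensor, values in [0, num_weights) or -1 for padding.
    weights: (num_weights, in, out).
    Padding rows (index == -1) are left as zeros in the output.
    """
    num_weights, in_dim, out_dim = weights.shape
    out = features.new_zeros(features.shape[0], out_dim)

    valid = (index >= 0).nonzero(as_tuple=True)[0]  # rows that have a weight
    if valid.numel() == 0:
        return out
    idx = index[valid]

    # Sort the valid rows by weight id so each group is contiguous.
    idx_sorted, order = torch.sort(idx)
    rows = valid[order]

    # CSR-style offsets: starts[g] is where group g begins in sorted order.
    counts = torch.bincount(idx_sorted, minlength=num_weights)
    starts = torch.cumsum(counts, 0) - counts
    # Position of each row inside its own group: 0, 1, ..., count - 1.
    pos = torch.arange(idx_sorted.numel(), device=index.device) - starts[idx_sorted]

    # Scatter rows into a zero-padded (num_weights, max_count, in) buffer.
    max_count = int(counts.max())
    padded = features.new_zeros(num_weights, max_count, in_dim)
    padded[idx_sorted, pos] = features[rows]

    # One batched matmul: (G, M, in) @ (G, in, out) -> (G, M, out).
    result = torch.bmm(padded, weights)

    # Gather the non-padding results back to their original rows.
    out[rows] = result[idx_sorted, pos]
    return out


weights = torch.randn(3, 4, 8)
index = torch.tensor([0, 0, 1, 1, 2, 2, -1, -1])
features = torch.randn(index.shape[0], 4)
out = grouped_matmul_padded(features, index, weights)
```

Padding wastes work when group sizes are very uneven. A dedicated grouped-GEMM kernel that iterates over the CSR segments directly would avoid that waste, and that is where a custom CUDA extension would come in.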
Matthias Fey commented
Sidharth Baskaran commented
Great, this is exactly what I needed!