rusty1s / pytorch_scatter

PyTorch Extension Library of Optimized Scatter Operations

Home Page: https://pytorch-scatter.readthedocs.io

Segment operations for matrix multiplication instead of reduction

sidnb13 opened this issue

I would like to implement an optimized operation that does what segment_coo or segment_csr does, but applies a matrix multiplication per segment instead of one of the available reductions. Below is an example of how I'm currently doing it. It is the fastest approach I have found so far, but I would like to parallelize across the available weights instead of looping over them.
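
Stated precisely (the notation is introduced here for illustration and is not taken from the library): let $x_n$ be row $n$ of features, $i_n \in \{0, \dots, S-1\}$ the weight assigned to it, and $W_s \in \mathbb{R}^{d_\text{in} \times d_\text{out}}$ the $s$-th weight. The desired output row is

$$
y_n = x_n W_{i_n}, \qquad n = 0, \dots, N-1.
$$

If the rows are sorted by $i_n$ and described by a CSR pointer $\mathrm{ptr}$ (the indptr that segment_csr takes), segment_csr computes one reduced row per segment, $\mathrm{out}_s = \operatorname{reduce}_{n=\mathrm{ptr}_s}^{\mathrm{ptr}_{s+1}-1} x_n$, whereas the requested operation keeps every row and replaces the reduction with one matrix product per segment:

$$
Y_{[\mathrm{ptr}_s,\,\mathrm{ptr}_{s+1})} = X_{[\mathrm{ptr}_s,\,\mathrm{ptr}_{s+1})}\, W_s, \qquad s = 0, \dots, S-1.
$$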

import torch

input_dim, output_dim = 4, 8
weights = [torch.randn(input_dim, output_dim) for _ in range(3)]
# Assigns each feature row to a weight (a per-row index, not a CSR pointer);
# -1 marks a padding row, which may use any weight (here it is left as zeros).
indptr = torch.tensor([0, 0, 1, 1, 2, 2, -1, -1])
features = torch.randn(indptr.shape[0], input_dim)

out = torch.zeros(features.shape[0], output_dim)
for i, weight in enumerate(weights):
    mask = indptr == i  # rows assigned to weight i
    out[mask] = features[mask] @ weight
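
A loop-free variant of the loop above, as a minimal sketch in plain PyTorch. It assumes the weights are stacked into one [S, in, out] tensor and that padding rows (index -1) should come out as zeros. It gathers each row's weight matrix and lets a single torch.bmm call do every product. gathered_matmul is a hypothetical helper name, not part of torch_scatter.

```python
import torch


def gathered_matmul(features, index, weight):
    """Hypothetical helper: out[n] = features[n] @ weight[index[n]].

    features: [N, in]; index: [N] with values in [0, S), or -1 for padding;
    weight: [S, in, out]. Padding rows are returned as zeros (an assumption).
    """
    valid = index >= 0
    # Map padding rows to weight 0 so the gather stays in bounds.
    safe_index = index.clamp(min=0)
    # Per-row weight matrices: [N, in, out]. This is the memory-heavy step.
    per_row_weight = weight[safe_index]
    # Batched [1, in] @ [in, out] products, one per row: [N, out].
    out = torch.bmm(features.unsqueeze(1), per_row_weight).squeeze(1)
    # Zero the padding rows.
    return out.masked_fill(~valid.unsqueeze(1), 0.0)


input_dim, output_dim = 4, 8
weights = torch.stack([torch.randn(input_dim, output_dim) for _ in range(3)])
indptr = torch.tensor([0, 0, 1, 1, 2, 2, -1, -1])
features = torch.randn(indptr.shape[0], input_dim)
out = gathered_matmul(features, indptr, weights)
```

This parallelizes across the weights, but the gather materializes an N × in × out tensor, so it only suits small feature dimensions.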

Would appreciate any guidance on how to implement a custom operation/kernel for this segment-wise matrix multiplication.
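
One possible direction, offered as a sketch rather than an established recipe: sort the rows by weight index so every segment is contiguous (the CSR layout segment_csr expects, with ptr playing the role of its indptr), zero-pad the segments to a common length, and run a single torch.bmm over all segments. padded_segment_matmul is a hypothetical name, and padding rows (-1) are assumed to be filtered out before the call.

```python
import torch


def padded_segment_matmul(features, index, weight):
    """Hypothetical grouped matmul: out[n] = features[n] @ weight[index[n]].

    features: [N, in]; index: [N] with values in [0, S) (padding rows removed);
    weight: [S, in, out]. Returns [N, out] in the original row order.
    """
    num_rows, in_dim = features.shape
    num_segments = weight.size(0)
    # Sort rows by segment so each segment is one contiguous block.
    index_sorted, perm = torch.sort(index)
    counts = torch.bincount(index_sorted, minlength=num_segments)
    # CSR-style pointer: ptr[s] is the first sorted row of segment s.
    ptr = torch.zeros(num_segments + 1, dtype=torch.long, device=index.device)
    ptr[1:] = torch.cumsum(counts, dim=0)
    # Position of each sorted row within its own segment.
    pos = torch.arange(num_rows, device=index.device) - ptr[index_sorted]
    # Scatter rows into a zero-padded [S, max_count, in] batch.
    max_count = int(counts.max())
    padded = features.new_zeros(num_segments, max_count, in_dim)
    padded[index_sorted, pos] = features[perm]
    # A single batched matmul covers every segment: [S, max_count, out].
    padded_out = torch.bmm(padded, weight)
    # Read the real rows back out and undo the sort.
    out = features.new_empty(num_rows, weight.size(-1))
    out[perm] = padded_out[index_sorted, pos]
    return out


weights = torch.stack([torch.randn(4, 8) for _ in range(3)])
indptr = torch.tensor([0, 0, 1, 1, 2, 2, -1, -1])
features = torch.randn(indptr.shape[0], 4)
valid = indptr >= 0  # drop padding rows; they stay zero in out
out = torch.zeros(features.shape[0], 8)
out[valid] = padded_segment_matmul(features[valid], indptr[valid], weights)
```

Memory here scales with S × max_segment_length × in instead of N × in × out, but unbalanced segments waste work on padding. A dedicated grouped-GEMM kernel would mainly remove that padding, and a pure-PyTorch version like this can still serve as a correctness reference for such a kernel.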

Great, this is exactly what I needed!