[Question] `scatter_logsumexp` using only `torch.scatter_reduce` (& `torch.scatter_add`?)
jeanmonet opened this issue · comments
jeanmonet commented
This is more of a question than an issue. Not (yet) being hyper-familiar with the API, I was looking for guidance on how to reproduce `torch_scatter.scatter_logsumexp` using only `torch.scatter_reduce` (& `torch.scatter_add`?) instead of `torch_scatter.scatter_sum` and `torch_scatter.scatter_max`?
pytorch_scatter/torch_scatter/composite/logsumexp.py
Lines 9 to 40 in c38e20a
jeanmonet commented
I guess the resulting adaptation would look like this:
The API change involves mostly:
- adding code to create the `out` tensor if it is not provided
- `max_value_per_index.scatter_reduce_(dim=dim, index=index, src=src, reduce="amax", include_self=False)`
- `sum_per_index = out.scatter_add_(src=recentered_score.exp_(), index=index, dim=dim)`
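For anyone landing here later, a minimal sketch of how the pieces above might fit together, using only `Tensor.scatter_reduce_` and `Tensor.scatter_add_`. This is not the `torch_scatter` implementation: the function name `scatter_logsumexp_sketch` and the `dim_size` argument are made up for illustration, and it assumes `index` already has the same shape as `src` (as far as I can tell, the native scatter ops do not broadcast `index` the way `torch_scatter` does).

```python
import torch


def scatter_logsumexp_sketch(src, index, dim=-1, dim_size=None):
    """Hypothetical helper: group-wise logsumexp of `src` along `dim`.

    Assumes `index` has the same shape as `src`.
    """
    if dim_size is None:
        dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
    size = list(src.shape)
    size[dim] = dim_size

    # Per-group maximum. Groups that receive no element keep -inf.
    # The max is only a numerical shift, so it is computed on detached values.
    max_per_index = src.new_full(size, float("-inf"))
    max_per_index.scatter_reduce_(
        dim, index, src.detach(), reduce="amax", include_self=False
    )
    # Replace non-finite maxima by 0 so that `src - shift` never yields NaN.
    max_per_index = max_per_index.masked_fill(~torch.isfinite(max_per_index), 0.0)

    # Recenter every element by the max of its group, exponentiate, and sum.
    shift = max_per_index.gather(dim, index)
    sum_per_index = src.new_zeros(size).scatter_add_(dim, index, (src - shift).exp())

    # log(sum(exp(x - m))) + m; empty groups end up as log(0) = -inf.
    return sum_per_index.log() + max_per_index


src = torch.tensor([1.0, 2.0, 3.0, 4.0])
index = torch.tensor([0, 0, 1, 1])
out = scatter_logsumexp_sketch(src, index, dim=0)
# Compare against torch.logsumexp applied to each group separately.
ref = torch.stack([torch.logsumexp(src[:2], 0), torch.logsumexp(src[2:], 0)])
print(torch.allclose(out, ref))
```

Filling `max_per_index` with `-inf` and passing `include_self=False` means the initial values never take part in the `amax`, and untouched slots are easy to recognise afterwards. Whether empty groups should return `-inf` (as in this sketch) or something else is a design choice worth checking against what you expect from `torch_scatter`.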