rusty1s / pytorch_scatter

PyTorch Extension Library of Optimized Scatter Operations

Home Page: https://pytorch-scatter.readthedocs.io

[Question] `scatter_logsumexp` using only `torch.scatter_reduce` (& `torch.scatter_add`?)

jeanmonet opened this issue

This is more of a question than an issue. I am not yet very familiar with the API, so I am looking for guidance on how to reproduce `torch_scatter.scatter_logsumexp` using only `torch.scatter_reduce` (and possibly `torch.scatter_add`) in place of `torch_scatter.scatter_max` and `torch_scatter.scatter_sum`. The current implementation is:

def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                      out: Optional[torch.Tensor] = None,
                      dim_size: Optional[int] = None,
                      eps: float = 1e-12) -> torch.Tensor:
    if not torch.is_floating_point(src):
        raise ValueError('`scatter_logsumexp` can only be computed over '
                         'tensors with floating point data types.')

    index = broadcast(index, src, dim)

    if out is not None:
        dim_size = out.size(dim)
    else:
        if dim_size is None:
            dim_size = int(index.max()) + 1

    size = list(src.size())
    size[dim] = dim_size
    max_value_per_index = torch.full(size, float('-inf'), dtype=src.dtype,
                                     device=src.device)
    scatter_max(src, index, dim, max_value_per_index, dim_size=dim_size)[0]
    max_per_src_element = max_value_per_index.gather(dim, index)
    recentered_score = src - max_per_src_element
    recentered_score.masked_fill_(torch.isnan(recentered_score), float('-inf'))

    if out is not None:
        out = out.sub_(max_value_per_index).exp_()

    sum_per_index = scatter_sum(recentered_score.exp_(), index, dim, out,
                                dim_size)

    return sum_per_index.add_(eps).log_().add_(max_value_per_index)

I guess the resulting adaptation would look like this:

https://github.com/jeanmonet/NSNet/blob/5ebaedcb439e9db613d6c49aac50d8013de92127/src/utils/scatter.py#L73-L128

The changes mostly involve:

  • adding code to create the `out` tensor if it is not provided
  • replacing `scatter_max` with `max_value_per_index.scatter_reduce_(dim=dim, index=index, src=src, reduce="amax", include_self=False)`
  • replacing `scatter_sum` with `sum_per_index = out.scatter_add_(src=recentered_score.exp_(), index=index, dim=dim)`
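
To make the changes listed above concrete, here is a minimal sketch of what such an adaptation might look like using only built-in PyTorch ops. It relies on the usual identity log Σ exp(x_i) = m + log Σ exp(x_i − m), with m the per-group maximum. The names `scatter_logsumexp_native` and `_broadcast_index` are hypothetical (they are not part of `torch_scatter` or PyTorch), the `out` argument is dropped for brevity, and the sketch assumes PyTorch 1.12 or newer (for `Tensor.scatter_reduce_`) and an `int64` index that is either 1-D along `dim` or already has the same shape as `src`.

```python
from typing import Optional

import torch


def _broadcast_index(index: torch.Tensor, src: torch.Tensor, dim: int) -> torch.Tensor:
    # Hypothetical helper: a simplified stand-in for torch_scatter's `broadcast`.
    # A 1-D index is placed along `dim`, then expanded to the shape of `src`.
    if index.dim() == 1:
        shape = [1] * src.dim()
        shape[dim] = -1
        index = index.view(shape)
    return index.expand_as(src)


def scatter_logsumexp_native(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                             dim_size: Optional[int] = None,
                             eps: float = 1e-12) -> torch.Tensor:
    if not torch.is_floating_point(src):
        raise ValueError('`scatter_logsumexp` can only be computed over '
                         'tensors with floating point data types.')

    index = _broadcast_index(index, src, dim)
    if dim_size is None:
        dim_size = int(index.max()) + 1

    size = list(src.size())
    size[dim] = dim_size

    # Per-group maximum; groups that receive no element stay at -inf.
    max_value_per_index = torch.full(size, float('-inf'), dtype=src.dtype,
                                     device=src.device)
    max_value_per_index.scatter_reduce_(dim, index, src, reduce='amax',
                                        include_self=False)

    # Recenter each element by the maximum of its group.
    max_per_src_element = max_value_per_index.gather(dim, index)
    recentered_score = src - max_per_src_element
    # (-inf) - (-inf) yields NaN; map it to -inf so that exp() gives 0.
    recentered_score.masked_fill_(torch.isnan(recentered_score), float('-inf'))

    # Sum of exponentials per group, accumulated into a zero-initialised tensor.
    sum_per_index = torch.zeros(size, dtype=src.dtype, device=src.device)
    sum_per_index.scatter_add_(dim, index, recentered_score.exp_())

    return sum_per_index.add_(eps).log_().add_(max_value_per_index)
```

For example, with `src = torch.tensor([1.0, 2.0, 3.0, 4.0])` and `index = torch.tensor([0, 0, 1, 1])`, the result should match `torch.logsumexp` applied separately to `[1.0, 2.0]` and `[3.0, 4.0]`, up to the small `eps` term.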