rusty1s / pytorch_unique

PyTorch Extension Library of Optimized Unique Operation

[Minor] Usage of torch.unique workaround gives different results; workaround provided

sairaamVenkatraman opened this issue

The code you've given for torch.unique to recover the indices returns, for each unique value, the index of its last occurrence in the original array, whereas torch_unique.unique returns the index of its first occurrence. An example:

import torch

x = torch.tensor([100, 10, 100, 1, 1000, 1, 1000, 1])
out, inverse = torch.unique(x, sorted=True, return_inverse=True)
perm = torch.arange(inverse.size(0), dtype=inverse.dtype, device=inverse.device)
# scatter_ writes positions in order, so for duplicates the *last* write wins,
# i.e. perm ends up holding the index of the last occurrence of each unique value.
perm = inverse.new_empty(out.size(0)).scatter_(0, inverse, perm)

Here,
out = tensor([ 1, 10, 100, 1000])
perm = tensor([7, 1, 2, 6])  # index of the last occurrence of each value

Running the code with torch_unique gives different results:

out1, perm1 = torch_unique.unique(x)
Here,
out1 = tensor([ 1, 10, 100, 1000])
perm1 = tensor([3, 1, 0, 4])  # index of the first occurrence of each value
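
Note that both index sets recover the unique values, i.e. x[perm] == out and x[perm1] == out1; they only differ in which duplicate they point at. A quick sanity check (a minimal sketch, reusing the tensors defined above):

print(x[perm])   # tensor([1, 10, 100, 1000]) -- via last occurrences
print(x[perm1])  # tensor([1, 10, 100, 1000]) -- via first occurrences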

A workaround that makes the torch.unique recipe return the indices of the first occurrences is:

x = torch.tensor([100, 10, 100, 1, 1000, 1, 1000, 1])
out, inverse = torch.unique(x, sorted=True, return_inverse=True)
perm = torch.arange(inverse.size(0), dtype=inverse.dtype, device=inverse.device)
# Flip both tensors so that scatter_'s "last write wins" behaviour now records
# the first occurrence of each value in the original ordering.
inverse, perm = inverse.flip([0]), perm.flip([0])
perm = inverse.new_empty(out.size(0)).scatter_(0, inverse, perm)
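
With this change the result matches torch_unique. A quick check (continuing from the snippet above):

print(perm)                       # tensor([3, 1, 0, 4]) -- same as perm1 above
assert torch.equal(x[perm], out)  # the gathered values equal the unique values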

Thank you. I will add it to the readme.