teddykoker / torchsort

Fast, differentiable sorting and ranking in PyTorch

Home Page: https://pypi.org/project/torchsort/


Accuracy sensitive to input size and regularization strength

Forrest-110 opened this issue · comments

Dear authors, you have done a great job! But I ran into some difficulties when dealing with large input sizes (1k-5k).

I have a list of numbers in [0, 1], with a list size of about 1k to 5k, and I'm trying to find a regularization strength that yields accurate ranking results.

My experiment code is listed below:

import torch
import torchsort
from matplotlib import pyplot as plt

def test_torchsort():
    x = torch.rand(1, 1000).cuda()
    order = torchsort.soft_rank(x, regularization_strength=1e-7).type(torch.int32)
    # Ground-truth 1-based ranks via double argsort
    gt_order = torch.argsort(torch.argsort(x, dim=1), dim=1) + 1
    error = torch.abs(order - gt_order).sum().item()
    return error

def test_stability():
    errors = []
    error_cnt = 0
    for i in range(100):
        error = test_torchsort()
        errors.append(error)
        if error > 0:
            error_cnt += 1
    accuracy = 1 - error_cnt / len(errors)
    print(accuracy)
    plt.hist(errors)
    plt.show()

if __name__ == '__main__':
    test_stability()

The thing is:

  • when the input size is around 1k, a regularization strength of 1e-7 works well, with accuracy around 0.93;
  • but when the input size increases to 2k, 1e-7 is no longer suitable, and in fact it's hard to find any good regularization strength.

That's exactly the issue: if I have to find a new regularization strength for every input size, it becomes quite impractical. Could you help me?
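A plausible explanation (my reading, not stated in the thread): with n values drawn uniformly from [0, 1], the expected gap between adjacent sorted values is about 1/(n+1), so a fixed regularization strength blurs proportionally more neighboring values as n grows. A quick stdlib-only sketch of the shrinking gaps:

```python
import random

def mean_adjacent_gap(n, trials=20, seed=0):
    """Average gap between adjacent sorted values of n uniform(0,1) draws."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(trials):
        xs = sorted(rng.random() for _ in range(n))
        gaps = [b - a for a, b in zip(xs, xs[1:])]
        total += sum(gaps) / len(gaps)
    return total / trials

# Gaps shrink roughly like 1/(n+1), so a regularization strength tuned
# for n=1000 is ~5x too coarse at n=5000.
print(mean_adjacent_gap(1000), mean_adjacent_gap(5000))
```

This suggests scaling the regularization strength inversely with input size rather than searching for one value per size.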

Hey @Forrest-110, apologies for the very late reply; I must have missed this. If you are still working on this, do you notice a similar trend in your downstream performance at inference time? Typically you would only want to use soft rank/sort during training, when gradients are needed. Once trained, you can just use torch.argsort as you do above. I would recommend tuning the regularization strength on a cross-validation set.
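To illustrate the inference-time recommendation: exact ranks need no regularization at all. The double-argsort trick used for `gt_order` above (in PyTorch, `torch.argsort(torch.argsort(x, dim=1), dim=1) + 1`) can be sketched in plain Python like this:

```python
def exact_ranks(xs):
    """1-based ranks via double argsort: ranks[i] is the position of xs[i] in sorted order."""
    order = sorted(range(len(xs)), key=lambda i: xs[i])  # first argsort: sorted order of indices
    ranks = [0] * len(xs)
    for pos, i in enumerate(order):                      # second argsort: invert the permutation
        ranks[i] = pos + 1
    return ranks

print(exact_ranks([0.3, 0.9, 0.1]))  # [2, 3, 1]
```

Since this is exact and non-differentiable, it sidesteps the regularization-tuning problem entirely once training is done.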