Accuracy sensitive to input size and regularization strength
Forrest-110 opened this issue
Dear authors, you have done a great job! But I ran into some difficulties with large inputs (1k-5k elements).
I have a list of numbers in [0, 1], with a list size of about 1k to 5k, and I'm trying to find a regularization strength that gives accurate ranking results.
My experiment code is listed below:
```python
import torchsort
import torch
from matplotlib import pyplot as plt

def test_torchsort():
    x = torch.rand(1, 1000).cuda()
    # soft ranks, cast to integers for comparison against the exact ranks
    order = torchsort.soft_rank(x, regularization_strength=1e-7).type(torch.int32)
    # exact 1-based ranks via double argsort
    gt_order = torch.argsort(torch.argsort(x, dim=1), dim=1) + 1
    error = torch.abs(order - gt_order).sum().item()
    return error

def test_stability():
    errors = []
    error_cnt = 0
    for i in range(100):
        error = test_torchsort()
        errors.append(error)
        if error > 0:
            error_cnt += 1
    accuracy = 1 - error_cnt / len(errors)
    print(accuracy)
    plt.hist(errors)
    plt.show()

if __name__ == '__main__':
    test_stability()
```
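For reference, the `gt_order` line above computes exact ("hard") ranks with a double argsort: the inner argsort gives the sorting permutation, and the outer one inverts it into per-element positions. A minimal pure-Python equivalent (illustration only, no torch):

```python
def hard_rank(xs):
    """Exact 1-based ranks, equivalent to argsort(argsort(xs)) + 1."""
    # indices of xs in sorted order (the sorting permutation)
    order = sorted(range(len(xs)), key=lambda i: xs[i])
    ranks = [0] * len(xs)
    # invert the permutation: position in sorted order -> rank of that element
    for position, index in enumerate(order):
        ranks[index] = position + 1
    return ranks

print(hard_rank([0.3, 0.1, 0.2]))  # -> [3, 1, 2]
```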
The thing is:
- when the input size is around 1k, a regularization strength of 1e-7 works well, with accuracy of about 0.93
- but when the input size increases to 2k, 1e-7 is no longer suitable; in fact, it's hard to find a good regularization strength at all
That's exactly the issue: if I have to find a new regularization strength for every input size, it becomes quite impractical. Could you help me?
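One workaround that might be worth trying (a heuristic guess, not something from the torchsort docs) is to scale the regularization strength inversely with the input size, so the amount of smoothing per element stays roughly constant as the list grows:

```python
def scaled_strength(n, base=1e-7, base_n=1000):
    # Hypothetical heuristic: shrink the base strength (which worked at
    # base_n elements) proportionally as the input size n grows.
    return base * base_n / n
```

With the 1e-7 that worked at n=1000, this would suggest about 5e-8 at n=2000 and 2e-8 at n=5000; whether that actually keeps the accuracy up would still need to be checked empirically.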
Hey @Forrest-110, apologies for the very late reply; I must have missed this. If you are still working on this, do you notice a similar trend in your downstream performance when doing inference? Typically you would only want to use soft rank/sort while training, when the gradients are needed. Once trained, you can just use torch.argsort, as you do above. I would recommend tuning the regularization on a cross-validation set.
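To make the cross-validation suggestion concrete, here is a minimal sketch of the tuning loop, where `evaluate` stands for a hypothetical callback that runs your ranking-error metric on a held-out set for a given strength:

```python
def pick_strength(candidates, evaluate):
    """Return the candidate regularization strength with the lowest
    validation error, where evaluate(strength) -> mean ranking error."""
    return min(candidates, key=evaluate)

# Usage sketch: sweep a log-spaced grid of strengths, e.g.
#   best = pick_strength([1e-5, 1e-6, 1e-7, 1e-8], my_validation_error)
# where my_validation_error would run soft_rank on held-out data and
# compare against exact ranks, as in the experiment code above.
```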