rusty1s / pytorch_cluster

PyTorch Extension Library of Optimized Graph Cluster Algorithms

Can k be greater than 100 in knn?

heyfavour opened this issue

When I use knn to compute atom neighbors, the neighbor count easily exceeds 100. Can I set k > 100? In your CUDA code, k is limited to at most 100:

https://github.com/rusty1s/pytorch_cluster/blob/master/csrc/cuda/knn_cuda.cu, line 98:
AT_ASSERTM(k <= 100, "k needs to smaller than or equal to 100");

Yes, the maximum number of neighbors is currently fixed at 100. One way to raise this limit would be templating, so that we can instantiate specific kernels depending on the value of k, e.g.:

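template <typename scalar_t, int max_k>
__global__ void knn_kernel(...) {
  // max_k is a compile-time constant, so it can size per-thread arrays.
  ...
}

torch::Tensor knn_cuda(...) {
  // Dispatch to the smallest instantiation that can hold k neighbors.
  if (k <= 32) {
      knn_kernel<scalar_t, 32><<<...>>>(...);
  } else if (k <= 64) {
      knn_kernel<scalar_t, 64><<<...>>>(...);
  } ...
}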
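As a rough, self-contained illustration of the dispatch idea above, here is a CPU-side C++ sketch, not the library's actual kernel. It assumes the limit exists because each query keeps fixed-size candidate buffers whose size must be known at compile time. `knn_query` and `knn` are hypothetical names, the instantiation sizes (32/64/128/256) are arbitrary choices, and a brute-force squared-Euclidean search with insertion into a sorted buffer stands in for the real CUDA kernel.

```cpp
#include <array>
#include <cassert>
#include <limits>
#include <stdexcept>
#include <vector>

// Hypothetical CPU analogue of a per-query kNN kernel. max_k is a compile-time
// constant, so the candidate buffers are fixed-size arrays (on the GPU these
// would be per-thread local arrays).
template <typename scalar_t, int max_k>
void knn_query(const std::vector<scalar_t>& x, const scalar_t* q, int dim,
               int k, std::vector<int>& out_idx) {
  assert(k >= 1 && k <= max_k);  // dispatcher must pick a large enough max_k
  std::array<scalar_t, max_k> best_dist;
  std::array<int, max_k> best_idx;
  best_dist.fill(std::numeric_limits<scalar_t>::max());
  best_idx.fill(-1);

  const int n = static_cast<int>(x.size()) / dim;
  for (int i = 0; i < n; ++i) {
    scalar_t d = 0;
    for (int c = 0; c < dim; ++c) {
      const scalar_t diff = x[i * dim + c] - q[c];
      d += diff * diff;  // squared Euclidean distance
    }
    // Skip points no closer than the current k-th best candidate.
    if (d >= best_dist[k - 1]) continue;
    // Insertion into the sorted top-k buffer; only slots [0, k) are used.
    int j = k - 1;
    while (j > 0 && best_dist[j - 1] > d) {
      best_dist[j] = best_dist[j - 1];
      best_idx[j] = best_idx[j - 1];
      --j;
    }
    best_dist[j] = d;
    best_idx[j] = i;
  }
  // Emit neighbors nearest first; unfilled slots (fewer than k points) are skipped.
  for (int j = 0; j < k; ++j)
    if (best_idx[j] >= 0) out_idx.push_back(best_idx[j]);
}

// Hypothetical dispatcher: choose the smallest instantiation that fits k.
template <typename scalar_t>
std::vector<int> knn(const std::vector<scalar_t>& x, const scalar_t* q,
                     int dim, int k) {
  std::vector<int> out;
  if (k <= 0) throw std::invalid_argument("k must be positive");
  if (k <= 32) knn_query<scalar_t, 32>(x, q, dim, k, out);
  else if (k <= 64) knn_query<scalar_t, 64>(x, q, dim, k, out);
  else if (k <= 128) knn_query<scalar_t, 128>(x, q, dim, k, out);
  else if (k <= 256) knn_query<scalar_t, 256>(x, q, dim, k, out);
  else throw std::invalid_argument("no instantiation for k > 256");
  return out;
}
```

For example, with 200 one-dimensional points at positions 0 to 199 and a query at 0.0, `knn<double>(pts, &q, 1, 150)` returns indices 0 through 149, a k that the fixed 100-slot buffer could not hold. The trade-off is that every instantiation adds compiled code, and on the GPU a larger max_k means larger per-thread arrays, which can spill to local memory and lower occupancy.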

I don't have time to implement this right now. Would you be interested in contributing this solution?