rapidsai / raft

RAFT contains fundamental, widely used algorithms and primitives for machine learning and information retrieval. The algorithms are CUDA-accelerated and serve as building blocks for writing high-performance applications more easily.

Home Page: https://docs.rapids.ai/api/raft/stable/



[BUG] knn_merge_parts not implemented for k>1024

tfeher opened this issue

Describe the bug

knn_merge_parts is only implemented for k<=1024, and it silently returns without doing any work for k>1024:

{
  if (k == 1)
    knn_merge_parts_impl<value_idx, value_t, 1, 1>(
      inK, inV, outK, outV, n_samples, n_parts, k, stream, translations);
  else if (k <= 32)
    knn_merge_parts_impl<value_idx, value_t, 32, 2>(
      inK, inV, outK, outV, n_samples, n_parts, k, stream, translations);
  else if (k <= 64)
    knn_merge_parts_impl<value_idx, value_t, 64, 3>(
      inK, inV, outK, outV, n_samples, n_parts, k, stream, translations);
  else if (k <= 128)
    knn_merge_parts_impl<value_idx, value_t, 128, 3>(
      inK, inV, outK, outV, n_samples, n_parts, k, stream, translations);
  else if (k <= 256)
    knn_merge_parts_impl<value_idx, value_t, 256, 4>(
      inK, inV, outK, outV, n_samples, n_parts, k, stream, translations);
  else if (k <= 512)
    knn_merge_parts_impl<value_idx, value_t, 512, 8>(
      inK, inV, outK, outV, n_samples, n_parts, k, stream, translations);
  else if (k <= 1024)
    knn_merge_parts_impl<value_idx, value_t, 1024, 8>(
      inK, inV, outK, outV, n_samples, n_parts, k, stream, translations);
  // No else branch: for k > 1024 no kernel is launched and the call returns silently.
}

knn_merge_parts is used during brute-force search in two cases:

  • When an index offset needs to be added to the neighbor indices. This feature is used by the ground_truth generator in raft_ann_bench, and there is an easy workaround for this case.
  • When merging results from multi-GPU search (#1993, tagging @viclafargue for visibility).
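The offset in the first case corresponds to the translations argument in the dispatch code. A small host-side sketch of where such offsets come from, assuming each index part covers a contiguous, consecutive slice of the dataset: the offset of part p is the number of rows in all earlier parts, i.e. an exclusive prefix sum of the part sizes. The helper names part_offsets and to_global are hypothetical and are not RAFT API.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: given the number of rows in each index part, return the
// offset (translation) of each part, i.e. the global row id of its first row.
// Assumes the parts are contiguous, consecutive slices of one dataset.
std::vector<std::int64_t> part_offsets(const std::vector<std::int64_t>& part_sizes)
{
  std::vector<std::int64_t> offsets(part_sizes.size());
  std::int64_t running = 0;
  for (std::size_t p = 0; p < part_sizes.size(); ++p) {
    offsets[p] = running;     // exclusive prefix sum: rows in parts 0..p-1
    running += part_sizes[p];
  }
  return offsets;
}

// A neighbor id local to part p maps to the global id local_id + offsets[p].
inline std::int64_t to_global(std::int64_t local_id, std::int64_t offset)
{
  return local_id + offset;
}
```

With part sizes {100, 250, 50}, the offsets are {0, 100, 350}, so local id 7 in the third part becomes global id 357.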

Expected behavior

  • Throw an error if input parameters are out of range
  • Ideally, improve knn_merge_parts so that it handles larger k.
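For the first expectation, a minimal standalone sketch of a dispatch that rejects out-of-range k instead of returning silently. It mirrors the (capacity, thread queue) template pairs of the dispatch quoted above, but BlockSelectConfig and choose_block_select_config are hypothetical names; inside RAFT the check would more likely be written with the library's own RAFT_EXPECTS macro rather than a bare throw.

```cpp
#include <stdexcept>
#include <string>

// Hypothetical: the template parameters used by the block-select based merge,
// copied from the k thresholds in the original dispatch.
struct BlockSelectConfig {
  int capacity;
  int thread_q;
};

// Sketch: pick the configuration for k, and throw when k is outside the
// range the block-select kernel supports instead of doing nothing.
inline BlockSelectConfig choose_block_select_config(int k)
{
  if (k < 1 || k > 1024) {
    throw std::invalid_argument("knn_merge_parts: k must be in [1, 1024], got " +
                                std::to_string(k));
  }
  if (k == 1) return {1, 1};
  if (k <= 32) return {32, 2};
  if (k <= 64) return {64, 3};
  if (k <= 128) return {128, 3};
  if (k <= 256) return {256, 4};
  if (k <= 512) return {512, 8};
  return {1024, 8};
}
```

Failing loudly keeps callers such as the multi-GPU merge from consuming uninitialized outK/outV buffers.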

Additional context
The limitation comes from the fact that we use faiss::blockselect. We have a radix top-k implementation that has no limit on k, so we only need to map the input to a format that radix top-k can consume.
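A CPU reference sketch of what "mapping the input" could mean, under stated assumptions. It assumes inK/inV use a part-major layout (element col of row row in part part sits at part * n_samples * k + row * k + col), so the n_parts * k candidates of one query are strided across memory. Step 1 gathers them into one contiguous row and adds each part's translation; step 2 is a row-wise top-k. Here std::partial_sort only stands in for the GPU radix top-k, and smaller distance is assumed to be better. On the GPU, step 1 would be a gather kernel and step 2 a batched radix top-k over n_samples rows of length n_parts * k. The function name merge_parts_reference is hypothetical.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical reference merge. Assumed layout (part-major) for inK and inV:
// index = part * n_samples * k + row * k + col. Smaller distance = better.
void merge_parts_reference(const std::vector<float>& inK,
                           const std::vector<std::int64_t>& inV,
                           std::vector<float>& outK,
                           std::vector<std::int64_t>& outV,
                           std::size_t n_samples, int n_parts, int k,
                           const std::vector<std::int64_t>& translations)
{
  const std::size_t total_k = static_cast<std::size_t>(n_parts) * k;
  outK.resize(n_samples * k);
  outV.resize(n_samples * k);
  std::vector<std::pair<float, std::int64_t>> row_buf(total_k);

  for (std::size_t row = 0; row < n_samples; ++row) {
    // Step 1: gather this row's n_parts * k candidates into one contiguous
    // row, adding each part's translation (offset) to its neighbor ids.
    for (int part = 0; part < n_parts; ++part) {
      for (int col = 0; col < k; ++col) {
        const std::size_t src = part * n_samples * k + row * k + col;
        row_buf[part * k + col] = {inK[src], inV[src] + translations[part]};
      }
    }
    // Step 2: row-wise top-k. std::partial_sort stands in for radix top-k,
    // so nothing here limits k to 1024.
    std::partial_sort(row_buf.begin(), row_buf.begin() + k, row_buf.end());
    for (int j = 0; j < k; ++j) {
      outK[row * k + j] = row_buf[j].first;
      outV[row * k + j] = row_buf[j].second;
    }
  }
}
```

For one query, two parts, k = 2, distances {0.5, 0.9} and {0.1, 0.7} with ids {3, 1} and {0, 2}, and translations {0, 10}, the merged result is distances {0.1, 0.5} with ids {10, 3}.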