alexklibisz / elastiknn

Elasticsearch plugin for nearest neighbor search. Store vectors and run similarity search using exact and approximate algorithms.

Home Page: https://alexklibisz.github.io/elastiknn

Try quick select algorithm for KthGreatest implementation

alexklibisz opened this issue

Background

I think there could be an opportunity to speed up approximate queries by re-implementing the kthGreatest method using the quick select algorithm.

At a high level, the kthGreatest method is used to find the kth greatest document frequency. We give it an array of counts, each one representing the number of times a distinct document was matched against a set of query terms. It returns the kth greatest count. Then we perform exact similarity scoring on each of the documents that match or exceed this kth greatest count.
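Below is a minimal sketch of the flow described above. It assumes a sort-based kthGreatest (the approach the `sortBaseline` benchmark's name suggests), a 0-based `k` to match the quickselect code later in this thread, and that each array index is a document id. The class and method names are hypothetical, not elastiknn's API.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical illustration, not elastiknn code.
public class KthGreatestSortSketch {

    // Returns the k-th greatest count (0-based: k = 0 is the maximum).
    // Sorts a copy so the caller's counts array is left untouched.
    public static short kthGreatestBySorting(short[] counts, int k) {
        short[] copy = Arrays.copyOf(counts, counts.length);
        Arrays.sort(copy); // ascending order
        return copy[copy.length - 1 - k];
    }

    // Collects the ids (array indices) of documents whose count meets or
    // exceeds the threshold; these are the ones that get exact scoring.
    public static List<Integer> candidates(short[] counts, short threshold) {
        List<Integer> ids = new ArrayList<>();
        for (int docId = 0; docId < counts.length; docId++) {
            if (counts[docId] >= threshold) {
                ids.add(docId);
            }
        }
        return ids;
    }

    public static void main(String[] args) {
        short[] counts = {3, 0, 5, 1, 5, 2};
        short kth = kthGreatestBySorting(counts, 2);
        System.out.println(kth);                     // 3
        System.out.println(candidates(counts, kth)); // [0, 2, 4]
    }
}
```

Because ties at the threshold are kept, the candidate set can be larger than `k + 1` documents.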

There are some good example implementations of quickselect on LeetCode:

Deliverables

  • Implement and benchmark kthGreatest method using quick select
  • Report the results on this ticket, or in a PR if it's good enough to merge

Related Issues

Blocked by #525

I've partially implemented this in #603. I based much of the quickselect implementation on this excellent gist: https://gist.github.com/unnikked/14c19ba13f6a4bfd00a3

My latest iteration at the time of writing is here:

package com.klibisz.elastiknn.search;

public class QuickSelect {

    public static short selectRecursive(short[] array, int n) {
        return recursive(array, 0, array.length - 1, n);
    }

    private static short recursive(short[] array, int left, int right, int k) {
        if (left == right) { // If the list contains only one element,
            return array[left]; // return that element
        }
        // select a pivotIndex between left and right
        int pivotIndex = left + (right - left) / 2;
        pivotIndex = partition(array, left, right, pivotIndex);
        // The pivot is in its final sorted position
        if (k == pivotIndex) {
            return array[k];
        } else if (k < pivotIndex) {
            return recursive(array, left, pivotIndex - 1, k);
        } else {
            return recursive(array, pivotIndex + 1, right, k);
        }
    }

    private static int partition(short[] array, int left, int right, int pivotIndex) {
        int pivotValue = array[pivotIndex];
        swap(array, pivotIndex, right); // move pivot to end
        int storeIndex = left;
        for (int i = left; i < right; i++) {
            if (array[i] > pivotValue) {
                swap(array, storeIndex, i);
                storeIndex++;
            }
        }
        swap(array, right, storeIndex); // Move pivot to its final place
        return storeIndex;
    }

    private static void swap(short[] array, int a, int b) {
        short tmp = array[a];
        array[a] = array[b];
        array[b] = tmp;
    }
}
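For comparison, here is a sketch of the same algorithm written as a loop instead of recursion. It uses the same fixed middle pivot and the same "greater values to the front" partition, so index `k` still means the k-th position in descending order. The class name `IterativeQuickSelect` is hypothetical, and this variant has not been benchmarked here.

```java
// Hypothetical loop-based variant of the recursive QuickSelect above.
public class IterativeQuickSelect {

    // Returns the value that would sit at index k if the array were sorted in
    // descending order (k = 0 is the maximum). Assumes a non-empty array and
    // 0 <= k < array.length. Reorders the array in place, like the recursive
    // version, so pass a copy if the original order matters.
    public static short select(short[] array, int k) {
        int left = 0;
        int right = array.length - 1;
        // Invariant: index k always lies within [left, right].
        while (left < right) {
            int pivotIndex = left + (right - left) / 2; // same fixed middle pivot
            pivotIndex = partition(array, left, right, pivotIndex);
            if (k == pivotIndex) {
                return array[k];
            } else if (k < pivotIndex) {
                right = pivotIndex - 1; // answer is among values > pivot
            } else {
                left = pivotIndex + 1;  // answer is among values <= pivot
            }
        }
        return array[left]; // a single element remains: it is the answer
    }

    // Lomuto-style partition: values greater than the pivot move to the front,
    // then the pivot is swapped into the slot just after them.
    private static int partition(short[] array, int left, int right, int pivotIndex) {
        short pivotValue = array[pivotIndex];
        swap(array, pivotIndex, right); // park the pivot at the end
        int storeIndex = left;
        for (int i = left; i < right; i++) {
            if (array[i] > pivotValue) {
                swap(array, storeIndex, i);
                storeIndex++;
            }
        }
        swap(array, right, storeIndex); // move the pivot to its final place
        return storeIndex;
    }

    private static void swap(short[] array, int a, int b) {
        short tmp = array[a];
        array[a] = array[b];
        array[b] = tmp;
    }

    public static void main(String[] args) {
        short[] counts = {3, 0, 5, 1, 5, 2};
        System.out.println(select(counts.clone(), 2)); // 3
    }
}
```

The JVM does not eliminate tail calls, so the loop avoids one stack frame per partitioning step. Whether that matters for throughput at this array size is untested.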

The benchmark is here:

@Benchmark
@BenchmarkMode(Array(Mode.Throughput))
@Fork(value = 1)
@Warmup(time = 5, iterations = 5)
@Measurement(time = 5, iterations = 5)
def unnikedRecursive(f: KthGreatestBenchmarkFixtures): Unit = {
  System.arraycopy(f.shortCounts, 0, f.copy, 0, f.copy.length)
  QuickSelect.selectRecursive(f.copy, f.k)
  ()
}

Unfortunately, this particular implementation of quickselect is actually slower than just sorting. I suspect much of the cost comes from having to make a full copy of the array on every iteration. The copy is necessary because quickselect modifies the array (swapping values around) to compute its result, whereas the ArrayHitCounter expects those values to be immutable.

[info] Benchmark                                Mode  Cnt      Score    Error  Units
[info] KthGreatestBenchmarks.kthGreatest       thrpt    5  10796.563 ±  0.931  ops/s
[info] KthGreatestBenchmarks.sortBaseline      thrpt    5   2965.035 ± 85.854  ops/s
[info] KthGreatestBenchmarks.unnikedRecursive  thrpt    5   2171.902 ± 40.547  ops/s
[success] Total time: 153 s (02:33), completed Nov 26, 2023, 11:35:28 PM
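The copy is needed only because quickselect reorders its input. For comparison (this approach was not tried in this thread), a counting approach can find the k-th greatest count without modifying or copying the array. It relies on hit counts being small non-negative integers: build a histogram of count values, then walk it from the top. This sketch is hypothetical. It is not necessarily how elastiknn's kthGreatest works, and it has not been benchmarked here.

```java
// Hypothetical non-mutating alternative; assumes all counts are >= 0.
public class HistogramKthGreatest {

    // Returns the k-th greatest value in counts (0-based: k = 0 is the max)
    // without modifying counts. Runs in O(n + maxCount) time and space.
    public static short kthGreatest(short[] counts, int k) {
        if (k < 0 || k >= counts.length) {
            throw new IllegalArgumentException("k out of range: " + k);
        }
        // First pass: find the largest count so the histogram is no bigger than needed.
        int max = 0;
        for (short c : counts) {
            if (c > max) max = c;
        }
        // Second pass: histogram[v] = number of documents whose count equals v.
        int[] histogram = new int[max + 1];
        for (short c : counts) {
            histogram[c]++;
        }
        // Walk down from the largest value until more than k documents have
        // been seen; that value is the k-th greatest.
        int seen = 0;
        for (int v = max; v >= 0; v--) {
            seen += histogram[v];
            if (seen > k) return (short) v;
        }
        throw new IllegalStateException("unreachable for valid k");
    }

    public static void main(String[] args) {
        short[] counts = {3, 0, 5, 1, 5, 2};
        System.out.println(kthGreatest(counts, 0)); // 5
        System.out.println(kthGreatest(counts, 2)); // 3
    }
}
```

The trade-off is an extra allocation proportional to the largest count instead of a copy proportional to the number of documents. That is cheap only while counts stay small, which is an assumption about the workload.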

Quickselect is about 30% faster (roughly 2,788 vs. 2,172 ops/s) when I switch from a fixed pivot to a random pivot (the pivotIndex assignment in recursive):

package com.klibisz.elastiknn.search;

import java.util.Random;

public class QuickSelect {

    private static final Random rng = new Random(0);

    public static short selectRecursive(short[] array, int n) {
        return recursive(array, 0, array.length - 1, n);
    }

    private static short recursive(short[] array, int left, int right, int k) {
        if (left == right) { // If the list contains only one element,
            return array[left]; // return that element
        }
        // select a pivotIndex between left and right
        // (nextInt(right - left) returns a value in [0, right - left),
        // so the pivot is drawn from [left, right - 1])
        int pivotIndex = left + rng.nextInt(right - left);
        pivotIndex = partition(array, left, right, pivotIndex);
        // The pivot is in its final sorted position
        if (k == pivotIndex) {
            return array[k];
        } else if (k < pivotIndex) {
            return recursive(array, left, pivotIndex - 1, k);
        } else {
            return recursive(array, pivotIndex + 1, right, k);
        }
    }

    private static int partition(short[] array, int left, int right, int pivotIndex) {
        int pivotValue = array[pivotIndex];
        swap(array, pivotIndex, right); // move pivot to end
        int storeIndex = left;
        for (int i = left; i < right; i++) {
            if (array[i] > pivotValue) {
                swap(array, storeIndex, i);
                storeIndex++;
            }
        }
        swap(array, right, storeIndex); // Move pivot to its final place
        return storeIndex;
    }

    private static void swap(short[] array, int a, int b) {
        short tmp = array[a];
        array[a] = array[b];
        array[b] = tmp;
    }
}

[info] Benchmark                                Mode  Cnt     Score    Error  Units
[info] KthGreatestBenchmarks.unnikedRecursive  thrpt    5  2788.150 ± 16.839  ops/s
[success] Total time: 51 s, completed Nov 26, 2023, 11:48:26 PM

But it's still nowhere near kthGreatest (about 2,788 vs. 10,797 ops/s).

This feels similar to using hashmaps instead of arrays to count hits, summarized in this comment: #160 (comment)

Right now I'm benchmarking with a dataset of 60k vectors (Fashion-MNIST). Optimizations like quickselect and primitive hashmaps might have a positive impact when working with far more vectors, but Fashion-MNIST is the benchmark I'm optimizing for now.

Closing this for now. Might re-open if/when I'm benchmarking on a larger dataset.