Skip to content

Commit

Permalink
Add support for incremental models
Browse files Browse the repository at this point in the history
  • Loading branch information
Jelmer Kuperus committed Dec 30, 2022
1 parent ffea776 commit 1db0119
Show file tree
Hide file tree
Showing 39 changed files with 289 additions and 189 deletions.
2 changes: 1 addition & 1 deletion hnswlib-core-jdk17/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
<parent>
<groupId>com.github.jelmerk</groupId>
<artifactId>hnswlib-parent-pom</artifactId>
<version>1.0.1</version>
<version>1.1.0</version>
<relativePath>..</relativePath>
</parent>

Expand Down
2 changes: 1 addition & 1 deletion hnswlib-core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
<parent>
<groupId>com.github.jelmerk</groupId>
<artifactId>hnswlib-parent-pom</artifactId>
<version>1.0.1</version>
<version>1.1.0</version>
<relativePath>..</relativePath>
</parent>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import com.github.jelmerk.knn.*;
import com.github.jelmerk.knn.util.*;
import com.github.jelmerk.knn.util.BitSet;
import org.eclipse.collections.api.list.primitive.MutableIntList;
import org.eclipse.collections.api.map.primitive.MutableObjectIntMap;
import org.eclipse.collections.api.map.primitive.MutableObjectLongMap;
Expand Down Expand Up @@ -67,9 +66,9 @@ public class HnswIndex<TId, TVector, TItem extends Item<TId, TVector>, TDistance

private ReentrantLock globalLock;

private GenericObjectPool<BitSet> visitedBitSetPool;
private GenericObjectPool<ArrayBitSet> visitedBitSetPool;

private BitSet excludedCandidates;
private ArrayBitSet excludedCandidates;

private ExactView exactView;

Expand Down Expand Up @@ -103,7 +102,7 @@ private HnswIndex(RefinedBuilder<TId, TVector, TItem, TDistance> builder) {
this.visitedBitSetPool = new GenericObjectPool<>(() -> new ArrayBitSet(this.maxItemCount),
Runtime.getRuntime().availableProcessors());

this.excludedCandidates = new SynchronizedBitSet(new ArrayBitSet(this.maxItemCount));
this.excludedCandidates = new ArrayBitSet(this.maxItemCount);

this.exactView = new ExactView();
}
Expand Down Expand Up @@ -250,7 +249,9 @@ public boolean add(TItem item) {

int newNodeId = nodeCount++;

excludedCandidates.add(newNodeId);
synchronized (excludedCandidates) {
excludedCandidates.add(newNodeId);
}

Node<TItem> newNode = new Node<>(newNodeId, connections, item, false);

Expand Down Expand Up @@ -339,7 +340,9 @@ public boolean add(TItem item) {
}
}
} finally {
excludedCandidates.remove(newNodeId);
synchronized (excludedCandidates) {
excludedCandidates.remove(newNodeId);
}
}
} finally {
if (globalLock.isHeldByCurrentThread()) {
Expand All @@ -363,8 +366,10 @@ private void mutuallyConnectNewElement(Node<TItem> newNode,
while (!topCandidates.isEmpty()) {
int selectedNeighbourId = topCandidates.poll().nodeId;

if (excludedCandidates.contains(selectedNeighbourId)) {
continue;
synchronized (excludedCandidates) {
if (excludedCandidates.contains(selectedNeighbourId)) {
continue;
}
}

newItemConnections.add(selectedNeighbourId);
Expand Down Expand Up @@ -519,10 +524,34 @@ public List<SearchResult<TItem, TDistance>> findNearest(TVector destination, int
return results;
}

/**
* Changes the maximum capacity of the index.
* @param newSize new size of the index
*/
public void resize(int newSize) {
globalLock.lock();
try {
this.maxItemCount = newSize;

this.visitedBitSetPool = new GenericObjectPool<>(() -> new ArrayBitSet(this.maxItemCount),
Runtime.getRuntime().availableProcessors());

AtomicReferenceArray<Node<TItem>> newNodes = new AtomicReferenceArray<>(newSize);
for(int i = 0; i < this.nodes.length(); i++) {
newNodes.set(i, this.nodes.get(i));
}
this.nodes = newNodes;

this.excludedCandidates = new ArrayBitSet(this.excludedCandidates, newSize);
} finally {
globalLock.unlock();
}
}

private PriorityQueue<NodeIdAndDistance<TDistance>> searchBaseLayer(
Node<TItem> entryPointNode, TVector destination, int k, int layer) {

BitSet visitedBitSet = visitedBitSetPool.borrowObject();
ArrayBitSet visitedBitSet = visitedBitSetPool.borrowObject();

try {
PriorityQueue<NodeIdAndDistance<TDistance>> topCandidates =
Expand Down Expand Up @@ -778,7 +807,7 @@ private void readObject(ObjectInputStream ois) throws IOException, ClassNotFound
this.globalLock = new ReentrantLock();
this.visitedBitSetPool = new GenericObjectPool<>(() -> new ArrayBitSet(this.maxItemCount),
Runtime.getRuntime().availableProcessors());
this.excludedCandidates = new SynchronizedBitSet(new ArrayBitSet(this.maxItemCount));
this.excludedCandidates = new ArrayBitSet(this.maxItemCount);
this.locks = new HashMap<>();
this.exactView = new ExactView();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
/**
* Bitset.
*/
public class ArrayBitSet implements BitSet, Serializable {
public class ArrayBitSet implements Serializable {

private static final long serialVersionUID = 1L;

Expand All @@ -21,10 +21,18 @@ public ArrayBitSet(int count) {
this.buffer = new int[(count >> 5) + 1];
}

/**
* Initializes a new instance of the {@link ArrayBitSet} class. and copies the values
* of another bitset
* @param count The number of items in the set.
*/
public ArrayBitSet(ArrayBitSet other, int count) {
this.buffer = Arrays.copyOf(other.buffer, (count >> 5) + 1);
}

/**
* {@inheritDoc}
*/
@Override
public boolean contains(int id) {
int carrier = this.buffer[id >> 5];
return ((1 << (id & 31)) & carrier) != 0;
Expand All @@ -33,7 +41,6 @@ public boolean contains(int id) {
/**
* {@inheritDoc}
*/
@Override
public void add(int id) {
int mask = 1 << (id & 31);
this.buffer[id >> 5] |= mask;
Expand All @@ -42,7 +49,6 @@ public void add(int id) {
/**
* {@inheritDoc}
*/
@Override
public void remove(int id) {
int mask = 1 << (id & 31);
this.buffer[id >> 5] &= ~mask;
Expand All @@ -51,7 +57,6 @@ public void remove(int id) {
/**
* {@inheritDoc}
*/
@Override
public void clear() {
Arrays.fill(this.buffer, 0);
}
Expand Down
31 changes: 0 additions & 31 deletions hnswlib-core/src/main/java/com/github/jelmerk/knn/util/BitSet.java

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package com.github.jelmerk.knn.util;

import org.junit.jupiter.api.Test;

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;

public class ArrayBitSetTest {

@Test
void copyConstructor() {
ArrayBitSet bitset = new ArrayBitSet(100);
bitset.add(50);
ArrayBitSet other = new ArrayBitSet(bitset, 200);
other.add(101);
assertThat(other.contains(50), is(true));
assertThat(other.contains(101), is(true));
}
}
2 changes: 1 addition & 1 deletion hnswlib-examples/hnswlib-examples-java/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
<parent>
<groupId>com.github.jelmerk</groupId>
<artifactId>hnswlib-examples-parent-pom</artifactId>
<version>1.0.1</version>
<version>1.1.0</version>
<relativePath>..</relativePath>
</parent>

Expand Down
10 changes: 5 additions & 5 deletions hnswlib-examples/hnswlib-examples-pyspark-luigi/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class Convert(SparkSubmitTask):

app = 'convert.py'

packages = ['com.github.jelmerk:hnswlib-spark_2.4_2.11:1.0.1']
packages = ['com.github.jelmerk:hnswlib-spark_2.4_2.11:1.1.0']

def requires(self):
return Unzip()
Expand Down Expand Up @@ -109,7 +109,7 @@ class HnswIndex(SparkSubmitTask):

app = 'hnsw_index.py'

packages = ['com.github.jelmerk:hnswlib-spark_2.4_2.11:1.0.1']
packages = ['com.github.jelmerk:hnswlib-spark_2.4_2.11:1.1.0']

m = IntParameter(default=16)

Expand Down Expand Up @@ -164,7 +164,7 @@ class Query(SparkSubmitTask):

executor_cores = IntParameter(default=2)

packages = ['com.github.jelmerk:hnswlib-spark_2.4_2.11:1.0.1']
packages = ['com.github.jelmerk:hnswlib-spark_2.4_2.11:1.1.0']

name = 'Query index'

Expand Down Expand Up @@ -230,7 +230,7 @@ class BruteForceIndex(SparkSubmitTask):

app = 'bruteforce_index.py'

packages = ['com.github.jelmerk:hnswlib-spark_2.4_2.11:1.0.1']
packages = ['com.github.jelmerk:hnswlib-spark_2.4_2.11:1.1.0']

@property
def conf(self):
Expand Down Expand Up @@ -291,7 +291,7 @@ class Evaluate(SparkSubmitTask):

app = 'evaluate_performance.py'

packages = ['com.github.jelmerk:hnswlib-spark_2.4_2.11:1.0.1']
packages = ['com.github.jelmerk:hnswlib-spark_2.4_2.11:1.1.0']

@property
def conf(self):
Expand Down
2 changes: 1 addition & 1 deletion hnswlib-examples/hnswlib-examples-scala/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
<parent>
<groupId>com.github.jelmerk</groupId>
<artifactId>hnswlib-examples-parent-pom</artifactId>
<version>1.0.1</version>
<version>1.1.0</version>
<relativePath>..</relativePath>
</parent>

Expand Down
2 changes: 1 addition & 1 deletion hnswlib-examples/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
<parent>
<groupId>com.github.jelmerk</groupId>
<artifactId>hnswlib-parent-pom</artifactId>
<version>1.0.1</version>
<version>1.1.0</version>
<relativePath>..</relativePath>
</parent>

Expand Down
2 changes: 1 addition & 1 deletion hnswlib-metrics-dropwizard/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
<parent>
<groupId>com.github.jelmerk</groupId>
<artifactId>hnswlib-parent-pom</artifactId>
<version>1.0.1</version>
<version>1.1.0</version>
<relativePath>..</relativePath>
</parent>

Expand Down
14 changes: 7 additions & 7 deletions hnswlib-pyspark/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,17 @@ Find the package appropriate for your spark setup

| | Scala 2.11 | Scala 2.12 |
|-------------|-------------------------------------------------|-------------------------------------------------|
| Spark 2.3.x | com.github.jelmerk:hnswlib-spark_2.3_2.11:1.0.1 | |
| Spark 2.4.x | com.github.jelmerk:hnswlib-spark_2.4_2.11:1.0.1 | com.github.jelmerk:hnswlib-spark_2.4_2.12:1.0.1 |
| Spark 3.0.x | | com.github.jelmerk:hnswlib-spark_3.0_2.12:1.0.1 |
| Spark 3.1.x | | com.github.jelmerk:hnswlib-spark_3.1_2.12:1.0.1 |
| Spark 3.2.x | | com.github.jelmerk:hnswlib-spark_3.2_2.12:1.0.1 |
| Spark 3.3.x | | com.github.jelmerk:hnswlib-spark_3.3_2.12:1.0.1 |
| Spark 2.3.x | com.github.jelmerk:hnswlib-spark_2.3_2.11:1.1.0 | |
| Spark 2.4.x | com.github.jelmerk:hnswlib-spark_2.4_2.11:1.1.0 | com.github.jelmerk:hnswlib-spark_2.4_2.12:1.1.0 |
| Spark 3.0.x | | com.github.jelmerk:hnswlib-spark_3.0_2.12:1.1.0 |
| Spark 3.1.x | | com.github.jelmerk:hnswlib-spark_3.1_2.12:1.1.0 |
| Spark 3.2.x | | com.github.jelmerk:hnswlib-spark_3.2_2.12:1.1.0 |
| Spark 3.3.x | | com.github.jelmerk:hnswlib-spark_3.3_2.12:1.1.0 |


Pass this as an argument to spark

--packages 'com.github.jelmerk:hnswlib-spark_2.3_2.11:1.0.1'
--packages 'com.github.jelmerk:hnswlib-spark_2.3_2.11:1.1.0'

Then install the python module with

Expand Down
Loading

0 comments on commit 1db0119

Please sign in to comment.