From 793bae014661f8ea3c47380377bd6652f7f203c4 Mon Sep 17 00:00:00 2001 From: Jelmer Kuperus Date: Thu, 28 Nov 2024 07:26:06 +0100 Subject: [PATCH] Redesign spark integration to take advantage of resource profiles. --- .github/workflows/ci.yml | 6 - .github/workflows/publish.yml | 6 - .github/workflows/release.yml | 27 - build.sbt | 91 +- .../quick_start_google_colab.ipynb | 854 ++++++----- .../similarity.ipynb | 8 +- .../hnswlib-examples-pyspark-luigi/.gitignore | 1 + .../bruteforce_index.py | 6 +- .../hnswlib-examples-pyspark-luigi/convert.py | 2 - .../evaluate_performance.py | 6 - .../hnswlib-examples-pyspark-luigi/flow.py | 102 +- .../hnsw_index.py | 7 +- .../hnswlib-examples-pyspark-luigi/query.py | 6 - hnswlib-spark/src/main/protobuf/index.proto | 65 + .../src/main/protobuf/registration.proto | 18 + .../src/main/python/pyspark_hnsw/knn.py | 324 +---- .../client/RegistrationClient.scala | 42 + .../server/DefaultRegistrationService.scala | 29 + .../server/RegistrationServerFactory.scala | 47 + .../serving/client/IndexClientFactory.scala | 197 +++ .../serving/server/DefaultIndexService.scala | 65 + .../serving/server/IndexServerFactory.scala | 84 ++ .../jelmerk/spark/knn/KnnAlgorithm.scala | 1255 +++++++---------- .../jelmerk/spark/knn/QueryIterator.scala | 52 + .../knn/bruteforce/BruteForceSimilarity.scala | 145 +- .../spark/knn/hnsw/HnswSimilarity.scala | 196 +-- .../com/github/jelmerk/spark/knn/knn.scala | 314 ++++- .../src/test/python/test_bruteforce.py | 5 +- hnswlib-spark/src/test/python/test_hnsw.py | 2 +- .../src/test/python/test_integration.py | 4 +- .../spark/knn/hnsw/HnswSimilaritySpec.scala | 52 +- project/plugins.sbt | 3 + 32 files changed, 2263 insertions(+), 1758 deletions(-) delete mode 100644 .github/workflows/release.yml create mode 100644 hnswlib-spark-examples/hnswlib-examples-pyspark-luigi/.gitignore create mode 100644 hnswlib-spark/src/main/protobuf/index.proto create mode 100644 hnswlib-spark/src/main/protobuf/registration.proto create mode 100644 hnswlib-spark/src/main/scala/com/github/jelmerk/registration/client/RegistrationClient.scala create mode 100644 hnswlib-spark/src/main/scala/com/github/jelmerk/registration/server/DefaultRegistrationService.scala create mode 100644 hnswlib-spark/src/main/scala/com/github/jelmerk/registration/server/RegistrationServerFactory.scala create mode 100644 hnswlib-spark/src/main/scala/com/github/jelmerk/serving/client/IndexClientFactory.scala create mode 100644 hnswlib-spark/src/main/scala/com/github/jelmerk/serving/server/DefaultIndexService.scala create mode 100644 hnswlib-spark/src/main/scala/com/github/jelmerk/serving/server/IndexServerFactory.scala create mode 100644 hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/QueryIterator.scala diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0196e7ec..485b8984 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -20,11 +20,6 @@ jobs: fail-fast: false matrix: spark: - - 2.4.8 - - 3.0.2 - - 3.1.3 - - 3.2.4 - - 3.3.2 - 3.4.1 - 3.5.0 env: @@ -39,7 +34,6 @@ jobs: - uses: actions/setup-python@v5 with: python-version: | - 3.7 3.9 - name: Build and test run: | diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 7c8d74e8..38ae08ff 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -17,11 +17,6 @@ jobs: fail-fast: false matrix: spark: - - 2.4.8 - - 3.0.2 - - 3.1.3 - - 3.2.4 - - 3.3.2 - 3.4.1 - 3.5.0 @@ -39,7 +34,6 @@ jobs: - uses: actions/setup-python@v5 with: python-version: | - 3.7 
3.9 - name: Import GPG Key uses: crazy-max/ghaction-import-gpg@v6 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml deleted file mode 100644 index 408a4909..00000000 --- a/.github/workflows/release.yml +++ /dev/null @@ -1,27 +0,0 @@ -name: Release pipeline -run-name: Release of ${{ inputs.version }} by ${{ github.actor }} - -permissions: - contents: write - -on: - workflow_dispatch: - inputs: - version: - description: Semantic version. For example 1.0.0 - required: true - -jobs: - ci-pipeline: - runs-on: ubuntu-22.04 - steps: - - name: Checkout main branch - uses: actions/checkout@v3 - with: - token: ${{ secrets.RELEASE_TOKEN }} - - name: Release - run: | - git config --global user.email "action@github.com" - git config --global user.name "GitHub Action" - git tag -a v${{ github.event.inputs.version }} -m "next release" - git push --tags diff --git a/build.sbt b/build.sbt index c65413a2..96415a1b 100644 --- a/build.sbt +++ b/build.sbt @@ -1,5 +1,6 @@ -import Path.relativeTo import sys.process.* +import scalapb.compiler.Version.scalapbVersion +import scalapb.compiler.Version.grpcJavaVersion ThisBuild / organization := "com.github.jelmerk" ThisBuild / scalaVersion := "2.12.18" @@ -10,6 +11,8 @@ ThisBuild / dynverSonatypeSnapshots := true ThisBuild / versionScheme := Some("early-semver") +ThisBuild / resolvers += "Local Maven Repository" at "file://" + Path.userHome.absolutePath + "/.m2/repository" + lazy val publishSettings = Seq( pomIncludeRepository := { _ => false }, @@ -40,7 +43,7 @@ lazy val publishSettings = Seq( lazy val noPublishSettings = publish / skip := true -val hnswLibVersion = "1.1.2" +val hnswLibVersion = "1.1.3" val sparkVersion = settingKey[String]("Spark version") val venvFolder = settingKey[String]("Venv folder") val pythonVersion = settingKey[String]("Python version") @@ -52,30 +55,26 @@ lazy val blackCheck = taskKey[Unit]("Run the black code formatter in check lazy val flake8 = taskKey[Unit]("Run the flake8 style enforcer") lazy val root = (project in file(".")) - .aggregate(hnswlibSpark) + .aggregate(uberJar, cosmetic) .settings(noPublishSettings) -lazy val hnswlibSpark = (project in file("hnswlib-spark")) +lazy val uberJar = (project in file("hnswlib-spark")) .settings( - name := s"hnswlib-spark_${sparkVersion.value.split('.').take(2).mkString("_")}", - publishSettings, - crossScalaVersions := { - if (sparkVersion.value >= "3.2.0") { - Seq("2.12.18", "2.13.10") - } else if (sparkVersion.value >= "3.0.0") { - Seq("2.12.18") - } else { - Seq("2.12.18", "2.11.12") + name := s"hnswlib-spark-uberjar_${sparkVersion.value.split('.').take(2).mkString("_")}", + noPublishSettings, + crossScalaVersions := Seq("2.12.18", "2.13.10"), + autoScalaLibrary := false, + Compile / unmanagedResourceDirectories += baseDirectory.value / "src" / "main" / "python", + Compile / unmanagedResources / includeFilter := { + val pythonSrcDir = baseDirectory.value / "src" / "main" / "python" + (file: File) => { + if (file.getAbsolutePath.startsWith(pythonSrcDir.getAbsolutePath)) file.getName.endsWith(".py") + else true } }, - autoScalaLibrary := false, Compile / unmanagedSourceDirectories += baseDirectory.value / "src" / "main" / "python", Test / unmanagedSourceDirectories += baseDirectory.value / "src" / "test" / "python", - Compile / packageBin / mappings ++= { - val base = baseDirectory.value / "src" / "main" / "python" - val srcs = base ** "*.py" - srcs pair relativeTo(base) - }, + Test / envVars += "SPARK_TESTING" -> "1", Compile / doc / javacOptions ++= { 
Seq("-Xdoclint:none") }, @@ -83,9 +82,24 @@ lazy val hnswlibSpark = (project in file("hnswlib-spark")) assembly / assemblyOption ~= { _.withIncludeScala(false) }, - sparkVersion := sys.props.getOrElse("sparkVersion", "3.3.2"), + assembly / assemblyMergeStrategy := { + case PathList("META-INF", "io.netty.versions.properties") => MergeStrategy.first + case x => + val oldStrategy = (ThisBuild / assemblyMergeStrategy).value + oldStrategy(x) + }, + assembly / assemblyShadeRules := Seq( + ShadeRule.rename("google.**" -> "shaded.google.@1").inAll, + ShadeRule.rename("com.google.**" -> "shaded.com.google.@1").inAll, + ShadeRule.rename("io.grpc.**" -> "shaded.io.grpc.@1").inAll, + ShadeRule.rename("io.netty.**" -> "shaded.io.netty.@1").inAll + ), + Compile / PB.targets := Seq( + scalapb.gen() -> (Compile / sourceManaged).value / "scalapb" + ), + sparkVersion := sys.props.getOrElse("sparkVersion", "3.4.1"), venvFolder := s"${baseDirectory.value}/.venv", - pythonVersion := (if (scalaVersion.value == "2.11.12") "python3.7" else "python3.9"), + pythonVersion := "python3.9", createVirtualEnv := { val ret = ( s"${pythonVersion.value} -m venv ${venvFolder.value}" #&& @@ -103,7 +117,7 @@ lazy val hnswlibSpark = (project in file("hnswlib-spark")) val ret = Process( Seq(s"$venv/bin/pytest", "--junitxml=target/test-reports/TEST-python.xml", "src/test/python"), cwd = baseDirectory.value, - extraEnv = "ARTIFACT_PATH" -> artifactPath, "PYTHONPATH" -> s"${baseDirectory.value}/src/main/python" + extraEnv = "ARTIFACT_PATH" -> artifactPath, "PYTHONPATH" -> s"${baseDirectory.value}/src/main/python", "SPARK_TESTING" -> "1" ).! require(ret == 0, "Python tests failed") } else { @@ -128,12 +142,31 @@ lazy val hnswlibSpark = (project in file("hnswlib-spark")) }, flake8 := flake8.dependsOn(createVirtualEnv).value, libraryDependencies ++= Seq( - "com.github.jelmerk" % "hnswlib-utils" % hnswLibVersion, - "com.github.jelmerk" % "hnswlib-core-jdk17" % hnswLibVersion, - "com.github.jelmerk" %% "hnswlib-scala" % hnswLibVersion, - "org.apache.spark" %% "spark-hive" % sparkVersion.value % Provided, - "org.apache.spark" %% "spark-mllib" % sparkVersion.value % Provided, - "com.holdenkarau" %% "spark-testing-base" % s"${sparkVersion.value}_1.4.7" % Test, - "org.scalatest" %% "scalatest" % "3.2.17" % Test + "com.github.jelmerk" % "hnswlib-utils" % hnswLibVersion, + "com.github.jelmerk" % "hnswlib-core-jdk17" % hnswLibVersion, + "com.github.jelmerk" %% "hnswlib-scala" % hnswLibVersion, + "com.thesamet.scalapb" %% "scalapb-runtime-grpc" % scalapbVersion, + "com.thesamet.scalapb" %% "scalapb-runtime" % scalapbVersion % "protobuf", + "io.grpc" % "grpc-netty" % grpcJavaVersion, + "org.apache.spark" %% "spark-hive" % sparkVersion.value % Provided, + "org.apache.spark" %% "spark-mllib" % sparkVersion.value % Provided, + "com.holdenkarau" %% "spark-testing-base" % s"${sparkVersion.value}_1.4.7" % Test, + "org.scalatest" %% "scalatest" % "3.2.17" % Test ) + ) + +// spark cannot resolve artifacts with classifiers so we replace the main artifact +// +// See: https://issues.apache.org/jira/browse/SPARK-20075 +// See: https://github.com/sbt/sbt-assembly/blob/develop/README.md#q-despite-the-concerned-friends-i-still-want-publish-%C3%BCber-jars-what-advice-do-you-have +lazy val cosmetic = project + .settings( + name := s"hnswlib-spark_${sparkVersion.value.split('.').take(2).mkString("_")}", + Compile / packageBin := (uberJar / assembly).value, + Compile / packageDoc / publishArtifact := false, + Compile / packageSrc / publishArtifact := false, + 
autoScalaLibrary := false, + crossScalaVersions := Seq("2.12.18", "2.13.10"), + sparkVersion := sys.props.getOrElse("sparkVersion", "3.4.1"), + publishSettings ) \ No newline at end of file diff --git a/hnswlib-spark-examples/hnswlib-examples-pyspark-google-colab/quick_start_google_colab.ipynb b/hnswlib-spark-examples/hnswlib-examples-pyspark-google-colab/quick_start_google_colab.ipynb index 31df0359..60849c22 100644 --- a/hnswlib-spark-examples/hnswlib-examples-pyspark-google-colab/quick_start_google_colab.ipynb +++ b/hnswlib-spark-examples/hnswlib-examples-pyspark-google-colab/quick_start_google_colab.ipynb @@ -1,447 +1,445 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "hnswlib.ipynb", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# HnswLib Quick Start\n", + "\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jelmerk/hnswlib/blob/master/hnswlib-examples/hnswlib-examples-pyspark-google-colab/quick_start_google_colab.ipynb)\n", + "\n", + "We will first set up the runtime environment and give it a quick test" + ], + "metadata": { + "id": "NtnuPdiDyN8_" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { "colab": { - "name": "hnswlib.ipynb", - "provenance": [], - "collapsed_sections": [] + "base_uri": "https://localhost:8080/" }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" + "id": "F0u73ufErwpG", + "outputId": "15bde5ea-bdb7-4e23-d74d-75f4ff851fab" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "--2022-01-08 02:32:40-- https://raw.githubusercontent.com/jelmerk/hnswlib/master/scripts/colab_setup.sh\n", + "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.111.133, 185.199.110.133, ...\n", + "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 1269 (1.2K) [text/plain]\n", + "Saving to: ‘STDOUT’\n", + "\n", + "- 100%[===================>] 1.24K --.-KB/s in 0s \n", + "\n", + "2022-01-08 02:32:41 (73.4 MB/s) - written to stdout [1269/1269]\n", + "\n", + "setup Colab for PySpark 3.0.3 and Hnswlib 1.0.0\n", + "Installing PySpark 3.0.3 and Hnswlib 1.0.0\n", + "\u001B[K |████████████████████████████████| 209.1 MB 73 kB/s \n", + "\u001B[K |████████████████████████████████| 198 kB 80.2 MB/s \n", + "\u001B[?25h Building wheel for pyspark (setup.py) ... 
\u001B[?25l\u001B[?25hdone\n" + ] } + ], + "source": [ + "!wget https://raw.githubusercontent.com/jelmerk/hnswlib/master/scripts/colab_setup.sh -O - | bash" + ] }, - "cells": [ - { - "cell_type": "markdown", - "source": [ - "# HnswLib Quick Start\n", - "\n", - "\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jelmerk/hnswlib/blob/master/hnswlib-examples/hnswlib-examples-pyspark-google-colab/quick_start_google_colab.ipynb)\n", - "\n", - "We will first set up the runtime environment and give it a quick test" - ], - "metadata": { - "id": "NtnuPdiDyN8_" - } - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "F0u73ufErwpG", - "outputId": "15bde5ea-bdb7-4e23-d74d-75f4ff851fab" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "--2022-01-08 02:32:40-- https://raw.githubusercontent.com/jelmerk/hnswlib/master/scripts/colab_setup.sh\n", - "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.111.133, 185.199.110.133, ...\n", - "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.\n", - "HTTP request sent, awaiting response... 200 OK\n", - "Length: 1269 (1.2K) [text/plain]\n", - "Saving to: ‘STDOUT’\n", - "\n", - "- 100%[===================>] 1.24K --.-KB/s in 0s \n", - "\n", - "2022-01-08 02:32:41 (73.4 MB/s) - written to stdout [1269/1269]\n", - "\n", - "setup Colab for PySpark 3.0.3 and Hnswlib 1.0.0\n", - "Installing PySpark 3.0.3 and Hnswlib 1.0.0\n", - "\u001b[K |████████████████████████████████| 209.1 MB 73 kB/s \n", - "\u001b[K |████████████████████████████████| 198 kB 80.2 MB/s \n", - "\u001b[?25h Building wheel for pyspark (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n" - ] - } - ], - "source": [ - "!wget https://raw.githubusercontent.com/jelmerk/hnswlib/master/scripts/colab_setup.sh -O - | bash" - ] - }, - { - "cell_type": "code", - "source": [ - "import pyspark_hnsw\n", - "\n", - "from pyspark.ml import Pipeline\n", - "from pyspark_hnsw.knn import *\n", - "from pyspark.ml.feature import HashingTF, IDF, Tokenizer\n", - "from pyspark.sql.functions import col, posexplode" - ], - "metadata": { - "id": "nO6TiznusZ2y" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "spark = pyspark_hnsw.start()" - ], - "metadata": { - "id": "Y9KKKcZlscZF" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "print(\"Hnswlib version: {}\".format(pyspark_hnsw.version()))\n", - "print(\"Apache Spark version: {}\".format(spark.version))" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "CJ2xbiCosydF", - "outputId": "baa771e6-5761-4a4d-fc26-22044aa6aeb5" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Hnswlib version: 1.0.0\n", - "Apache Spark version: 3.0.3\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "Load the product data from the [instacart market basket analysis kaggle competition ](https://www.kaggle.com/c/instacart-market-basket-analysis/data?select=products.csv.zip)" - ], - "metadata": { - "id": "nIYBMlF9i6cR" - } - }, - { - "cell_type": "code", - "source": [ - "!wget -O /tmp/products.csv \"https://drive.google.com/uc?export=download&id=1iIF7QpTyuUGYG7lquP7NkplYC9n8Gxuz\"" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "hOBkUPYO1Zpa", - "outputId": "f003f2ee-bb8c-4b56-a475-a980c992d9da" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "--2022-01-08 03:58:45-- https://drive.google.com/uc?export=download&id=1iIF7QpTyuUGYG7lquP7NkplYC9n8Gxuz\n", - "Resolving drive.google.com (drive.google.com)... 173.194.79.100, 173.194.79.102, 173.194.79.101, ...\n", - "Connecting to drive.google.com (drive.google.com)|173.194.79.100|:443... connected.\n", - "HTTP request sent, awaiting response... 302 Moved Temporarily\n", - "Location: https://doc-10-b4-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/4nf11kob2m4ai6bvlueodufo0oocm0t2/1641614325000/16131524327083715076/*/1iIF7QpTyuUGYG7lquP7NkplYC9n8Gxuz?e=download [following]\n", - "Warning: wildcards not supported in HTTP.\n", - "--2022-01-08 03:58:45-- https://doc-10-b4-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/4nf11kob2m4ai6bvlueodufo0oocm0t2/1641614325000/16131524327083715076/*/1iIF7QpTyuUGYG7lquP7NkplYC9n8Gxuz?e=download\n", - "Resolving doc-10-b4-docs.googleusercontent.com (doc-10-b4-docs.googleusercontent.com)... 108.177.127.132, 2a00:1450:4013:c07::84\n", - "Connecting to doc-10-b4-docs.googleusercontent.com (doc-10-b4-docs.googleusercontent.com)|108.177.127.132|:443... connected.\n", - "HTTP request sent, awaiting response... 
200 OK\n", - "Length: 2166953 (2.1M) [text/csv]\n", - "Saving to: ‘/tmp/products.csv’\n", - "\n", - "/tmp/products.csv 100%[===================>] 2.07M --.-KB/s in 0.01s \n", - "\n", - "2022-01-08 03:58:45 (159 MB/s) - ‘/tmp/products.csv’ saved [2166953/2166953]\n", - "\n" - ] - } - ] - }, - { - "cell_type": "code", - "source": [ - "productData = spark.read.option(\"header\", \"true\").csv(\"/tmp/products.csv\")" - ], - "metadata": { - "id": "oKodvLC6xwO6" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "productData.count()" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "q4C7HS1LQDcE", - "outputId": "f0b73205-ae29-4218-bd0e-81eb89fc3c4e" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "49688" - ] - }, - "metadata": {}, - "execution_count": 22 - } - ] - }, - { - "cell_type": "code", - "source": [ - "tokenizer = Tokenizer(inputCol=\"product_name\", outputCol=\"words\")\n", - "hashingTF = HashingTF(inputCol=\"words\", outputCol=\"rawFeatures\")\n", - "idf = IDF(inputCol=\"rawFeatures\", outputCol=\"features\")" - ], - "metadata": { - "id": "Zq2yRJevnRGS" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "Create a simple TF / IDF model that turns product names into sparse word vectors and adds them to an exact knn index. \n", - "\n", - "An exact or brute force index will give 100% correct, will be quick to index but really slow to query and is only appropriate during development or for doing comparissons against an approximate index" - ], - "metadata": { - "id": "S3OkoohFo2IA" - } - }, - { - "cell_type": "code", - "source": [ - "bruteforce = BruteForceSimilarity(identifierCol='product_id', queryIdentifierCol='product_id', k = 5, featuresCol='features', distanceFunction='cosine', excludeSelf=True, numPartitions=10)" - ], - "metadata": { - "id": "ReyTZSM1uT2q" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "exact_pipeline = Pipeline(stages=[tokenizer, hashingTF, idf, bruteforce])" - ], - "metadata": { - "id": "20wtg6ZhHpwx" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "exact_model = exact_pipeline.fit(productData)" - ], - "metadata": { - "id": "Ln1aIrdyJRoL" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "Next create the same model but add the TF / IDF vectors to a HNSW index" - ], - "metadata": { - "id": "cot3ByIOpwwZ" - } - }, - { - "cell_type": "code", - "source": [ - "hnsw = HnswSimilarity(identifierCol='product_id', queryIdentifierCol='product_id', featuresCol='features',\n", - " distanceFunction='cosine', numPartitions=10, excludeSelf=True, k = 5)" - ], - "metadata": { - "id": "7zLQLVreqWRM" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "hnsw_pipeline = Pipeline(stages=[tokenizer, hashingTF, idf, hnsw])" - ], - "metadata": { - "id": "mUlvwo89qEJm" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "hnsw_model = hnsw_pipeline.fit(productData)" - ], - "metadata": { - "id": "dwOkEFmxqeR2" - }, - "execution_count": null, - "outputs": [] + { + "cell_type": "code", + "source": [ + "import pyspark_hnsw\n", + "\n", + "from pyspark.ml import Pipeline\n", + "from pyspark_hnsw.knn import *\n", + "from pyspark.ml.feature import HashingTF, IDF, 
Tokenizer\n", + "from pyspark.sql.functions import col, posexplode" + ], + "metadata": { + "id": "nO6TiznusZ2y" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "spark = pyspark_hnsw.start()" + ], + "metadata": { + "id": "Y9KKKcZlscZF" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "print(\"Hnswlib version: {}\".format(pyspark_hnsw.version()))\n", + "print(\"Apache Spark version: {}\".format(spark.version))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "CJ2xbiCosydF", + "outputId": "baa771e6-5761-4a4d-fc26-22044aa6aeb5" + }, + "execution_count": null, + "outputs": [ { - "cell_type": "markdown", - "source": [ - "Select a record to query" - ], - "metadata": { - "id": "MQSYgEgHlg65" - } + "output_type": "stream", + "name": "stdout", + "text": [ + "Hnswlib version: 1.0.0\n", + "Apache Spark version: 3.0.3\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "Load the product data from the [instacart market basket analysis kaggle competition ](https://www.kaggle.com/c/instacart-market-basket-analysis/data?select=products.csv.zip)" + ], + "metadata": { + "id": "nIYBMlF9i6cR" + } + }, + { + "cell_type": "code", + "source": [ + "!wget -O /tmp/products.csv \"https://drive.google.com/uc?export=download&id=1iIF7QpTyuUGYG7lquP7NkplYC9n8Gxuz\"" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "hOBkUPYO1Zpa", + "outputId": "f003f2ee-bb8c-4b56-a475-a980c992d9da" + }, + "execution_count": null, + "outputs": [ { - "cell_type": "code", - "source": [ - "queries = productData.filter(col(\"product_id\") == 43572)" - ], - "metadata": { - "id": "vCag3tH-NUf-" - }, - "execution_count": null, - "outputs": [] + "output_type": "stream", + "name": "stdout", + "text": [ + "--2022-01-08 03:58:45-- https://drive.google.com/uc?export=download&id=1iIF7QpTyuUGYG7lquP7NkplYC9n8Gxuz\n", + "Resolving drive.google.com (drive.google.com)... 173.194.79.100, 173.194.79.102, 173.194.79.101, ...\n", + "Connecting to drive.google.com (drive.google.com)|173.194.79.100|:443... connected.\n", + "HTTP request sent, awaiting response... 302 Moved Temporarily\n", + "Location: https://doc-10-b4-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/4nf11kob2m4ai6bvlueodufo0oocm0t2/1641614325000/16131524327083715076/*/1iIF7QpTyuUGYG7lquP7NkplYC9n8Gxuz?e=download [following]\n", + "Warning: wildcards not supported in HTTP.\n", + "--2022-01-08 03:58:45-- https://doc-10-b4-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/4nf11kob2m4ai6bvlueodufo0oocm0t2/1641614325000/16131524327083715076/*/1iIF7QpTyuUGYG7lquP7NkplYC9n8Gxuz?e=download\n", + "Resolving doc-10-b4-docs.googleusercontent.com (doc-10-b4-docs.googleusercontent.com)... 108.177.127.132, 2a00:1450:4013:c07::84\n", + "Connecting to doc-10-b4-docs.googleusercontent.com (doc-10-b4-docs.googleusercontent.com)|108.177.127.132|:443... connected.\n", + "HTTP request sent, awaiting response... 
200 OK\n", + "Length: 2166953 (2.1M) [text/csv]\n", + "Saving to: ‘/tmp/products.csv’\n", + "\n", + "/tmp/products.csv 100%[===================>] 2.07M --.-KB/s in 0.01s \n", + "\n", + "2022-01-08 03:58:45 (159 MB/s) - ‘/tmp/products.csv’ saved [2166953/2166953]\n", + "\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "productData = spark.read.option(\"header\", \"true\").csv(\"/tmp/products.csv\")" + ], + "metadata": { + "id": "oKodvLC6xwO6" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "productData.count()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "q4C7HS1LQDcE", + "outputId": "f0b73205-ae29-4218-bd0e-81eb89fc3c4e" + }, + "execution_count": null, + "outputs": [ { - "cell_type": "code", - "source": [ - "queries.show(truncate = False)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "pcUCCFxzQ02H", - "outputId": "8721ba75-f5d2-493e-a36c-d182e97a3bd0" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "+----------+-----------------------------+--------+-------------+\n", - "|product_id|product_name |aisle_id|department_id|\n", - "+----------+-----------------------------+--------+-------------+\n", - "|43572 |Alcaparrado Manzanilla Olives|110 |13 |\n", - "+----------+-----------------------------+--------+-------------+\n", - "\n" - ] - } + "output_type": "execute_result", + "data": { + "text/plain": [ + "49688" ] + }, + "metadata": {}, + "execution_count": 22 + } + ] + }, + { + "cell_type": "code", + "source": [ + "tokenizer = Tokenizer(inputCol=\"product_name\", outputCol=\"words\")\n", + "hashingTF = HashingTF(inputCol=\"words\", outputCol=\"rawFeatures\")\n", + "idf = IDF(inputCol=\"rawFeatures\", outputCol=\"features\")" + ], + "metadata": { + "id": "Zq2yRJevnRGS" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Create a simple TF / IDF model that turns product names into sparse word vectors and adds them to an exact knn index. 
\n", + "\n", + "An exact or brute force index will give 100% correct, will be quick to index but really slow to query and is only appropriate during development or for doing comparissons against an approximate index" + ], + "metadata": { + "id": "S3OkoohFo2IA" + } + }, + { + "cell_type": "code", + "source": "bruteforce = BruteForceSimilarity(identifierCol='product_id', k = 5, featuresCol='features', distanceFunction='cosine', numPartitions=10)", + "metadata": { + "id": "ReyTZSM1uT2q" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "exact_pipeline = Pipeline(stages=[tokenizer, hashingTF, idf, bruteforce])" + ], + "metadata": { + "id": "20wtg6ZhHpwx" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "exact_model = exact_pipeline.fit(productData)" + ], + "metadata": { + "id": "Ln1aIrdyJRoL" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Next create the same model but add the TF / IDF vectors to a HNSW index" + ], + "metadata": { + "id": "cot3ByIOpwwZ" + } + }, + { + "cell_type": "code", + "source": [ + "hnsw = HnswSimilarity(identifierCol='product_id', queryIdentifierCol='product_id', featuresCol='features',\n", + " distanceFunction='cosine', numPartitions=10, excludeSelf=True, k = 5)" + ], + "metadata": { + "id": "7zLQLVreqWRM" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "hnsw_pipeline = Pipeline(stages=[tokenizer, hashingTF, idf, hnsw])" + ], + "metadata": { + "id": "mUlvwo89qEJm" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "hnsw_model = hnsw_pipeline.fit(productData)" + ], + "metadata": { + "id": "dwOkEFmxqeR2" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Select a record to query" + ], + "metadata": { + "id": "MQSYgEgHlg65" + } + }, + { + "cell_type": "code", + "source": [ + "queries = productData.filter(col(\"product_id\") == 43572)" + ], + "metadata": { + "id": "vCag3tH-NUf-" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "queries.show(truncate = False)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "pcUCCFxzQ02H", + "outputId": "8721ba75-f5d2-493e-a36c-d182e97a3bd0" + }, + "execution_count": null, + "outputs": [ { - "cell_type": "markdown", - "source": [ - "Show the results from the exact model" - ], - "metadata": { - "id": "qbcUGq4irTFH" - } - }, - { - "cell_type": "code", - "source": [ - "exact_model.transform(queries) \\\n", - " .select(posexplode(col(\"prediction\")).alias(\"pos\", \"item\")) \\\n", - " .select(col(\"pos\"), col(\"item.neighbor\").alias(\"product_id\"), col(\"item.distance\").alias(\"distance\")) \\\n", - " .join(productData, [\"product_id\"]) \\\n", - " .show(truncate=False)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "q4wi29adOLRX", - "outputId": "1b06735b-8db4-4c4f-fe16-7d1aad02ea6d" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "+----------+---+-------------------+----------------------------------+--------+-------------+\n", - "|product_id|pos|distance |product_name |aisle_id|department_id|\n", - "+----------+---+-------------------+----------------------------------+--------+-------------+\n", - "|27806 |0 |0.2961162117528633 |Manzanilla Olives |110 |13 |\n", - 
"|25125 |1 |0.40715716898722976|Stuffed Manzanilla Olives |110 |13 |\n", - "|16721 |2 |0.40715716898722976|Manzanilla Stuffed Olives |110 |13 |\n", - "|39833 |3 |0.49516580877903393|Pimiento Sliced Manzanilla Olives |110 |13 |\n", - "|33495 |4 |0.514201828085252 |Manzanilla Pimiento Stuffed Olives|110 |13 |\n", - "+----------+---+-------------------+----------------------------------+--------+-------------+\n", - "\n" - ] - } - ] + "output_type": "stream", + "name": "stdout", + "text": [ + "+----------+-----------------------------+--------+-------------+\n", + "|product_id|product_name |aisle_id|department_id|\n", + "+----------+-----------------------------+--------+-------------+\n", + "|43572 |Alcaparrado Manzanilla Olives|110 |13 |\n", + "+----------+-----------------------------+--------+-------------+\n", + "\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "Show the results from the exact model" + ], + "metadata": { + "id": "qbcUGq4irTFH" + } + }, + { + "cell_type": "code", + "source": [ + "exact_model.transform(queries) \\\n", + " .select(posexplode(col(\"prediction\")).alias(\"pos\", \"item\")) \\\n", + " .select(col(\"pos\"), col(\"item.neighbor\").alias(\"product_id\"), col(\"item.distance\").alias(\"distance\")) \\\n", + " .join(productData, [\"product_id\"]) \\\n", + " .show(truncate=False)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "q4wi29adOLRX", + "outputId": "1b06735b-8db4-4c4f-fe16-7d1aad02ea6d" + }, + "execution_count": null, + "outputs": [ { - "cell_type": "markdown", - "source": [ - "Show the results from the hnsw model" - ], - "metadata": { - "id": "JxHQ10aAr0MQ" - } + "output_type": "stream", + "name": "stdout", + "text": [ + "+----------+---+-------------------+----------------------------------+--------+-------------+\n", + "|product_id|pos|distance |product_name |aisle_id|department_id|\n", + "+----------+---+-------------------+----------------------------------+--------+-------------+\n", + "|27806 |0 |0.2961162117528633 |Manzanilla Olives |110 |13 |\n", + "|25125 |1 |0.40715716898722976|Stuffed Manzanilla Olives |110 |13 |\n", + "|16721 |2 |0.40715716898722976|Manzanilla Stuffed Olives |110 |13 |\n", + "|39833 |3 |0.49516580877903393|Pimiento Sliced Manzanilla Olives |110 |13 |\n", + "|33495 |4 |0.514201828085252 |Manzanilla Pimiento Stuffed Olives|110 |13 |\n", + "+----------+---+-------------------+----------------------------------+--------+-------------+\n", + "\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "Show the results from the hnsw model" + ], + "metadata": { + "id": "JxHQ10aAr0MQ" + } + }, + { + "cell_type": "code", + "source": [ + "hnsw_model.transform(queries) \\\n", + " .select(posexplode(col(\"prediction\")).alias(\"pos\", \"item\")) \\\n", + " .select(col(\"pos\"), col(\"item.neighbor\").alias(\"product_id\"), col(\"item.distance\").alias(\"distance\")) \\\n", + " .join(productData, [\"product_id\"]) \\\n", + " .show(truncate=False)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "PupolEF6P0jc", + "outputId": "9c0ce36d-32ae-4494-d277-6a7246a17588" + }, + "execution_count": null, + "outputs": [ { - "cell_type": "code", - "source": [ - "hnsw_model.transform(queries) \\\n", - " .select(posexplode(col(\"prediction\")).alias(\"pos\", \"item\")) \\\n", - " .select(col(\"pos\"), col(\"item.neighbor\").alias(\"product_id\"), col(\"item.distance\").alias(\"distance\")) \\\n", - " .join(productData, [\"product_id\"]) \\\n", - " 
.show(truncate=False)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "PupolEF6P0jc", - "outputId": "9c0ce36d-32ae-4494-d277-6a7246a17588" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "+----------+---+-------------------+----------------------------------+--------+-------------+\n", - "|product_id|pos|distance |product_name |aisle_id|department_id|\n", - "+----------+---+-------------------+----------------------------------+--------+-------------+\n", - "|27806 |0 |0.2961162117528633 |Manzanilla Olives |110 |13 |\n", - "|25125 |1 |0.40715716898722976|Stuffed Manzanilla Olives |110 |13 |\n", - "|16721 |2 |0.40715716898722976|Manzanilla Stuffed Olives |110 |13 |\n", - "|33495 |3 |0.514201828085252 |Manzanilla Pimiento Stuffed Olives|110 |13 |\n", - "|41472 |4 |0.514201828085252 |Pimiento Stuffed Manzanilla Olives|110 |13 |\n", - "+----------+---+-------------------+----------------------------------+--------+-------------+\n", - "\n" - ] - } - ] + "output_type": "stream", + "name": "stdout", + "text": [ + "+----------+---+-------------------+----------------------------------+--------+-------------+\n", + "|product_id|pos|distance |product_name |aisle_id|department_id|\n", + "+----------+---+-------------------+----------------------------------+--------+-------------+\n", + "|27806 |0 |0.2961162117528633 |Manzanilla Olives |110 |13 |\n", + "|25125 |1 |0.40715716898722976|Stuffed Manzanilla Olives |110 |13 |\n", + "|16721 |2 |0.40715716898722976|Manzanilla Stuffed Olives |110 |13 |\n", + "|33495 |3 |0.514201828085252 |Manzanilla Pimiento Stuffed Olives|110 |13 |\n", + "|41472 |4 |0.514201828085252 |Pimiento Stuffed Manzanilla Olives|110 |13 |\n", + "+----------+---+-------------------+----------------------------------+--------+-------------+\n", + "\n" + ] } - ] -} \ No newline at end of file + ] + } + ] +} diff --git a/hnswlib-spark-examples/hnswlib-examples-pyspark-jupyter-notebook/similarity.ipynb b/hnswlib-spark-examples/hnswlib-examples-pyspark-jupyter-notebook/similarity.ipynb index ff2d6de4..53f17e1b 100644 --- a/hnswlib-spark-examples/hnswlib-examples-pyspark-jupyter-notebook/similarity.ipynb +++ b/hnswlib-spark-examples/hnswlib-examples-pyspark-jupyter-notebook/similarity.ipynb @@ -37,7 +37,9 @@ "from pyspark.ml.feature import VectorAssembler\n", "from pyspark_hnsw.conversion import VectorConverter\n", "from pyspark_hnsw.knn import *\n", - "from pyspark_hnsw.linalg import Normalizer" + "from pyspark_hnsw.linalg import Normalizer\n", + "\n", + "import multiprocessing" ] }, { @@ -408,9 +410,9 @@ "\n", "normalizer = Normalizer(inputCol='features', outputCol='normalized_features')\n", "\n", - "hnsw = HnswSimilarity(identifierCol='id', queryIdentifierCol='id', featuresCol='normalized_features', \n", + "hnsw = HnswSimilarity(identifierCol='id', featuresCol='normalized_features',\n", " distanceFunction='inner-product', m=48, ef=5, k=10, efConstruction=200, numPartitions=2, \n", - " excludeSelf=True, predictionCol='approximate', outputFormat='minimal')\n", + " numThreads=multiprocessing.cpu_count(), predictionCol='approximate')\n", " \n", "pipeline = Pipeline(stages=[vector_assembler, converter, normalizer, hnsw])\n", "\n", diff --git a/hnswlib-spark-examples/hnswlib-examples-pyspark-luigi/.gitignore b/hnswlib-spark-examples/hnswlib-examples-pyspark-luigi/.gitignore new file mode 100644 index 00000000..75b5abd0 --- /dev/null +++ 
b/hnswlib-spark-examples/hnswlib-examples-pyspark-luigi/.gitignore @@ -0,0 +1 @@ +.luigi-venv/ diff --git a/hnswlib-spark-examples/hnswlib-examples-pyspark-luigi/bruteforce_index.py b/hnswlib-spark-examples/hnswlib-examples-pyspark-luigi/bruteforce_index.py index 94848e57..43c7996a 100644 --- a/hnswlib-spark-examples/hnswlib-examples-pyspark-luigi/bruteforce_index.py +++ b/hnswlib-spark-examples/hnswlib-examples-pyspark-luigi/bruteforce_index.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - import argparse from pyspark.ml import Pipeline @@ -14,13 +12,15 @@ def main(spark): parser.add_argument('--model', type=str) parser.add_argument('--output', type=str) parser.add_argument('--num_partitions', type=int) + parser.add_argument('--num_threads', type=int) args = parser.parse_args() normalizer = Normalizer(inputCol='features', outputCol='normalized_features') bruteforce = BruteForceSimilarity(identifierCol='id', featuresCol='normalized_features', - distanceFunction='inner-product', numPartitions=args.num_partitions) + distanceFunction='inner-product', numPartitions=args.num_partitions, + numThreads=args.num_threads) pipeline = Pipeline(stages=[normalizer, bruteforce]) diff --git a/hnswlib-spark-examples/hnswlib-examples-pyspark-luigi/convert.py b/hnswlib-spark-examples/hnswlib-examples-pyspark-luigi/convert.py index f855f3c0..69855ef4 100644 --- a/hnswlib-spark-examples/hnswlib-examples-pyspark-luigi/convert.py +++ b/hnswlib-spark-examples/hnswlib-examples-pyspark-luigi/convert.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - import argparse from pyspark.ml.feature import VectorAssembler diff --git a/hnswlib-spark-examples/hnswlib-examples-pyspark-luigi/evaluate_performance.py b/hnswlib-spark-examples/hnswlib-examples-pyspark-luigi/evaluate_performance.py index d90aaeb6..06eb3fe3 100644 --- a/hnswlib-spark-examples/hnswlib-examples-pyspark-luigi/evaluate_performance.py +++ b/hnswlib-spark-examples/hnswlib-examples-pyspark-luigi/evaluate_performance.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - import argparse from pyspark.ml import PipelineModel @@ -14,7 +12,6 @@ def main(spark): parser.add_argument('--input', type=str) parser.add_argument('--output', type=str) parser.add_argument('--k', type=int) - parser.add_argument('--ef', type=int) parser.add_argument('--fraction', type=float) parser.add_argument('--seed', type=int) @@ -25,17 +22,14 @@ def main(spark): hnsw_model = PipelineModel.read().load(args.hnsw_model) hnsw_stage = hnsw_model.stages[-1] - hnsw_stage.setEf(args.ef) hnsw_stage.setK(args.k) hnsw_stage.setPredictionCol('approximate') - hnsw_stage.setOutputFormat('full') bruteforce_model = PipelineModel.read().load(args.bruteforce_model) bruteforce_stage = bruteforce_model.stages[-1] bruteforce_stage.setK(args.k) bruteforce_stage.setPredictionCol('exact') - bruteforce_stage.setOutputFormat('full') sample_results = bruteforce_model.transform(hnsw_model.transform(sample_query_items)) diff --git a/hnswlib-spark-examples/hnswlib-examples-pyspark-luigi/flow.py b/hnswlib-spark-examples/hnswlib-examples-pyspark-luigi/flow.py index cd4cbe7e..3f80952e 100644 --- a/hnswlib-spark-examples/hnswlib-examples-pyspark-luigi/flow.py +++ b/hnswlib-spark-examples/hnswlib-examples-pyspark-luigi/flow.py @@ -1,9 +1,8 @@ -# -*- coding: utf-8 -*- - import urllib.request import shutil import luigi +import multiprocessing from luigi import FloatParameter, IntParameter, LocalTarget, Parameter from luigi.contrib.spark import SparkSubmitTask from luigi.format import Nop @@ -11,6 +10,9 @@ # from luigi.contrib.hdfs import 
HdfsFlagTarget # from luigi.contrib.s3 import S3FlagTarget +JAR='/Users/jelmerkuperus/dev/3rdparty/hnswlib-spark-old/hnswlib-spark/target/scala-2.12/hnswlib-spark-uberjar_3_4-assembly-1.1.2+6-0a591e8c-SNAPSHOT.jar' + +num_cores=multiprocessing.cpu_count(), class Download(luigi.Task): """ @@ -63,13 +65,14 @@ class Convert(SparkSubmitTask): # executor_memory = '4g' - num_executors = IntParameter(default=2) + num_executors = IntParameter(default=1) name = 'Convert' app = 'convert.py' - packages = ['com.github.jelmerk:hnswlib-spark_2.4_2.11:1.1.0'] + # packages = ['com.github.jelmerk:hnswlib-spark_2.4_2.11:1.1.0'] + jars = [JAR] def requires(self): return Unzip() @@ -101,15 +104,18 @@ class HnswIndex(SparkSubmitTask): # executor_memory = '12g' - num_executors = IntParameter(default=2) + num_executors = IntParameter(default=1) - executor_cores = IntParameter(default=2) + executor_cores = IntParameter(default=num_cores) name = 'Hnsw index' app = 'hnsw_index.py' - packages = ['com.github.jelmerk:hnswlib-spark_2.4_2.11:1.1.0'] + env = { "SPARK_TESTING": "1" } # needs to be set when using local spark + + # packages = ['com.github.jelmerk:hnswlib-spark_2.4_2.11:1.1.0'] + jars = [JAR] m = IntParameter(default=16) @@ -117,15 +123,8 @@ class HnswIndex(SparkSubmitTask): @property def conf(self): - return {'spark.dynamicAllocation.enabled': 'false', - 'spark.speculation': 'false', - 'spark.serializer': 'org.apache.spark.serializer.KryoSerializer', - 'spark.kryo.registrator': 'com.github.jelmerk.spark.HnswLibKryoRegistrator', - 'spark.task.cpus': str(self.executor_cores), - 'spark.task.maxFailures': '1', - 'spark.scheduler.minRegisteredResourcesRatio': '1.0', - 'spark.scheduler.maxRegisteredResourcesWaitingTime': '3600s', - 'spark.hnswlib.settings.index.cache_folder': '/tmp'} + return {'spark.serializer': 'org.apache.spark.serializer.KryoSerializer', + 'spark.kryo.registrator': 'com.github.jelmerk.spark.HnswLibKryoRegistrator'} def requires(self): return Convert() @@ -136,7 +135,8 @@ def app_options(self): '--output', self.output().path, '--m', self.m, '--ef_construction', self.ef_construction, - '--num_partitions', str(self.num_executors) + '--num_partitions', 1, + '--num_threads', num_cores ] def output(self): @@ -160,11 +160,12 @@ class Query(SparkSubmitTask): # executor_memory = '10g' - num_executors = IntParameter(default=4) + num_executors = IntParameter(default=1) - executor_cores = IntParameter(default=2) + executor_cores = IntParameter(default=num_cores) - packages = ['com.github.jelmerk:hnswlib-spark_2.4_2.11:1.1.0'] + # packages = ['com.github.jelmerk:hnswlib-spark_2.4_2.11:1.1.0'] + jars = [JAR] name = 'Query index' @@ -172,20 +173,10 @@ class Query(SparkSubmitTask): k = IntParameter(default=10) - ef = IntParameter(default=100) - - num_replicas = IntParameter(default=1) - @property def conf(self): - return {'spark.dynamicAllocation.enabled': 'false', - 'spark.speculation': 'false', - 'spark.serializer': 'org.apache.spark.serializer.KryoSerializer', - 'spark.kryo.registrator': 'com.github.jelmerk.spark.HnswLibKryoRegistrator', - 'spark.task.cpus': str(self.executor_cores), - 'spark.task.maxFailures': '1', - 'spark.scheduler.minRegisteredResourcesRatio': '1.0', - 'spark.scheduler.maxRegisteredResourcesWaitingTime': '3600s'} + return {'spark.serializer': 'org.apache.spark.serializer.KryoSerializer', + 'spark.kryo.registrator': 'com.github.jelmerk.spark.HnswLibKryoRegistrator'} def requires(self): return {'vectors': Convert(), @@ -196,9 +187,7 @@ def app_options(self): '--input', 
self.input()['vectors'].path, '--model', self.input()['index'].path, '--output', self.output().path, - '--ef', self.ef, - '--k', self.k, - '--num_replicas', self.num_replicas + '--k', self.k ] def output(self): @@ -222,27 +211,23 @@ class BruteForceIndex(SparkSubmitTask): # executor_memory = '12g' - num_executors = IntParameter(default=2) + num_executors = IntParameter(default=1) - executor_cores = IntParameter(default=2) + executor_cores = IntParameter(default=num_cores) name = 'Brute force index' app = 'bruteforce_index.py' - packages = ['com.github.jelmerk:hnswlib-spark_2.4_2.11:1.1.0'] + # packages = ['com.github.jelmerk:hnswlib-spark_2.4_2.11:1.1.0'] + jars = [JAR] + + env = { "SPARK_TESTING": "1" } # needs to be set when using local spark @property def conf(self): - return {'spark.dynamicAllocation.enabled': 'false', - 'spark.speculation': 'false', - 'spark.serializer': 'org.apache.spark.serializer.KryoSerializer', - 'spark.kryo.registrator': 'com.github.jelmerk.spark.HnswLibKryoRegistrator', - 'spark.task.cpus': str(self.executor_cores), - 'spark.task.maxFailures': '1', - 'spark.scheduler.minRegisteredResourcesRatio': '1.0', - 'spark.scheduler.maxRegisteredResourcesWaitingTime': '3600s', - 'spark.hnswlib.settings.index.cache_folder': '/tmp'} + return {'spark.serializer': 'org.apache.spark.serializer.KryoSerializer', + 'spark.kryo.registrator': 'com.github.jelmerk.spark.HnswLibKryoRegistrator'} def requires(self): return Convert() @@ -251,7 +236,8 @@ def app_options(self): return [ '--input', self.input().path, '--output', self.output().path, - '--num_partitions', str(self.num_executors) + '--num_partitions', 1, + '--num_threads', num_cores ] def output(self): @@ -275,14 +261,12 @@ class Evaluate(SparkSubmitTask): # executor_memory = '12g' - num_executors = IntParameter(default=2) + num_executors = IntParameter(default=1) - executor_cores = IntParameter(default=2) + executor_cores = IntParameter(default=num_cores) k = IntParameter(default=10) - ef = IntParameter(default=100) - fraction = FloatParameter(default=0.0001) seed = IntParameter(default=123) @@ -291,18 +275,13 @@ class Evaluate(SparkSubmitTask): app = 'evaluate_performance.py' - packages = ['com.github.jelmerk:hnswlib-spark_2.4_2.11:1.1.0'] + # packages = ['com.github.jelmerk:hnswlib-spark_2.4_2.11:1.1.0'] + jars = [JAR] @property def conf(self): - return {'spark.dynamicAllocation.enabled': 'false', - 'spark.speculation': 'false', - 'spark.serializer': 'org.apache.spark.serializer.KryoSerializer', - 'spark.kryo.registrator': 'com.github.jelmerk.spark.HnswLibKryoRegistrator', - 'spark.task.cpus': str(self.executor_cores), - 'spark.task.maxFailures': '1', - 'spark.scheduler.minRegisteredResourcesRatio': '1.0', - 'spark.scheduler.maxRegisteredResourcesWaitingTime': '3600s'} + return {'spark.serializer': 'org.apache.spark.serializer.KryoSerializer', + 'spark.kryo.registrator': 'com.github.jelmerk.spark.HnswLibKryoRegistrator'} def requires(self): return {'vectors': Convert(), @@ -315,7 +294,6 @@ def app_options(self): '--output', self.output().path, '--hnsw_model', self.input()['hnsw_index'].path, '--bruteforce_model', self.input()['bruteforce_index'].path, - '--ef', self.ef, '--k', self.k, '--seed', self.seed, '--fraction', self.fraction, diff --git a/hnswlib-spark-examples/hnswlib-examples-pyspark-luigi/hnsw_index.py b/hnswlib-spark-examples/hnswlib-examples-pyspark-luigi/hnsw_index.py index 5a1df028..289bfaa0 100644 --- a/hnswlib-spark-examples/hnswlib-examples-pyspark-luigi/hnsw_index.py +++ 
b/hnswlib-spark-examples/hnswlib-examples-pyspark-luigi/hnsw_index.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - import argparse from pyspark.ml import Pipeline @@ -15,14 +13,15 @@ def main(spark): parser.add_argument('--m', type=int) parser.add_argument('--ef_construction', type=int) parser.add_argument('--num_partitions', type=int) + parser.add_argument('--num_threads', type=int) args = parser.parse_args() normalizer = Normalizer(inputCol='features', outputCol='normalized_features') - hnsw = HnswSimilarity(identifierCol='id', queryIdentifierCol='id', featuresCol='normalized_features', + hnsw = HnswSimilarity(identifierCol='id', featuresCol='normalized_features', distanceFunction='inner-product', m=args.m, efConstruction=args.ef_construction, - numPartitions=args.num_partitions, excludeSelf=True, outputFormat='minimal') + numPartitions=args.num_partitions, numThreads=args.num_threads) pipeline = Pipeline(stages=[normalizer, hnsw]) diff --git a/hnswlib-spark-examples/hnswlib-examples-pyspark-luigi/query.py b/hnswlib-spark-examples/hnswlib-examples-pyspark-luigi/query.py index fbc5a859..689d74ca 100644 --- a/hnswlib-spark-examples/hnswlib-examples-pyspark-luigi/query.py +++ b/hnswlib-spark-examples/hnswlib-examples-pyspark-luigi/query.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - import argparse from pyspark.ml import PipelineModel @@ -12,17 +10,13 @@ def main(spark): parser.add_argument('--model', type=str) parser.add_argument('--output', type=str) parser.add_argument('--k', type=int) - parser.add_argument('--ef', type=int) - parser.add_argument('--num_replicas', type=int) args = parser.parse_args() model = PipelineModel.read().load(args.model) hnsw_stage = model.stages[-1] - hnsw_stage.setEf(args.ef) hnsw_stage.setK(args.k) - hnsw_stage.setNumReplicas(args.num_replicas) query_items = spark.read.parquet(args.input) diff --git a/hnswlib-spark/src/main/protobuf/index.proto b/hnswlib-spark/src/main/protobuf/index.proto new file mode 100644 index 00000000..c81fb2f8 --- /dev/null +++ b/hnswlib-spark/src/main/protobuf/index.proto @@ -0,0 +1,65 @@ +syntax = "proto3"; + +import "scalapb/scalapb.proto"; + +package com.github.jelmerk.server; + +service IndexService { + rpc Search (stream SearchRequest) returns (stream SearchResponse) {} + rpc SaveIndex (SaveIndexRequest) returns (SaveIndexResponse) {} +} + +message SearchRequest { + + oneof vector { + FloatArrayVector float_array_vector = 4; + DoubleArrayVector double_array_vector = 5; + SparseVector sparse_vector = 6; + DenseVector dense_vector = 7; + } + + int32 k = 8; +} + +message SearchResponse { + repeated Result results = 1; +} + +message Result { + oneof id { + string string_id = 1; + int64 long_id = 2; + int32 int_id = 3; + } + + oneof distance { + float float_distance = 4; + double double_distance = 5; + } +} + +message FloatArrayVector { + repeated float values = 1 [(scalapb.field).collection_type="Array"]; +} + +message DoubleArrayVector { + repeated double values = 1 [(scalapb.field).collection_type="Array"]; +} + +message SparseVector { + int32 size = 1; + repeated int32 indices = 2 [(scalapb.field).collection_type="Array"]; + repeated double values = 3 [(scalapb.field).collection_type="Array"]; +} + +message DenseVector { + repeated double values = 1 [(scalapb.field).collection_type="Array"]; +} + +message SaveIndexRequest { + string path = 1; +} + +message SaveIndexResponse { + int64 bytes_written = 1; +} \ No newline at end of file diff --git a/hnswlib-spark/src/main/protobuf/registration.proto 
b/hnswlib-spark/src/main/protobuf/registration.proto new file mode 100644 index 00000000..ced9439c --- /dev/null +++ b/hnswlib-spark/src/main/protobuf/registration.proto @@ -0,0 +1,18 @@ +syntax = "proto3"; + +package com.github.jelmerk.server; + +service RegistrationService { + rpc Register (RegisterRequest) returns ( RegisterResponse) {} +} + +message RegisterRequest { + int32 partition_num = 1; + int32 replica_num = 2; + string host = 3; + int32 port = 4; +} + +message RegisterResponse { +} + diff --git a/hnswlib-spark/src/main/python/pyspark_hnsw/knn.py b/hnswlib-spark/src/main/python/pyspark_hnsw/knn.py index d1ac2f19..81f47216 100644 --- a/hnswlib-spark/src/main/python/pyspark_hnsw/knn.py +++ b/hnswlib-spark/src/main/python/pyspark_hnsw/knn.py @@ -1,3 +1,5 @@ +from typing import Any, Dict + from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import ( Params, @@ -10,8 +12,11 @@ # noinspection PyProtectedMember from pyspark import keyword_only + +# noinspection PyProtectedMember from pyspark.ml.util import JavaMLReadable, JavaMLWritable, MLReader, _jvm + __all__ = [ "HnswSimilarity", "HnswSimilarityModel", @@ -21,6 +26,12 @@ ] +class KnnModel(JavaModel): + def destroy(self): + assert self._java_obj is not None + return self._java_obj.destroy() + + class HnswLibMLReader(MLReader): """ Specialization of :py:class:`MLReader` for :py:class:`JavaParams` types @@ -31,6 +42,7 @@ def __init__(self, clazz, java_class): self._clazz = clazz self._jread = self._load_java_obj(java_class).read() + # noinspection PyProtectedMember def load(self, path): """Load the ML instance from the input path.""" java_obj = self._jread.load(path) @@ -52,13 +64,6 @@ class _KnnModelParams(HasFeaturesCol, HasPredictionCol): Params for knn models. """ - queryIdentifierCol = Param( - Params._dummy(), - "queryIdentifierCol", - "the column name for the query identifier", - typeConverter=TypeConverters.toString, - ) - queryPartitionsCol = Param( Params._dummy(), "queryPartitionsCol", @@ -66,13 +71,6 @@ class _KnnModelParams(HasFeaturesCol, HasPredictionCol): typeConverter=TypeConverters.toString, ) - parallelism = Param( - Params._dummy(), - "parallelism", - "number of threads to use", - typeConverter=TypeConverters.toInt, - ) - k = Param( Params._dummy(), "k", @@ -80,82 +78,18 @@ class _KnnModelParams(HasFeaturesCol, HasPredictionCol): typeConverter=TypeConverters.toInt, ) - numReplicas = Param( - Params._dummy(), - "numReplicas", - "number of index replicas to create when querying", - typeConverter=TypeConverters.toInt, - ) - - excludeSelf = Param( - Params._dummy(), - "excludeSelf", - "whether to include the row identifier as a candidate neighbor", - typeConverter=TypeConverters.toBoolean, - ) - - similarityThreshold = Param( - Params._dummy(), - "similarityThreshold", - "do not return neighbors further away than this distance", - typeConverter=TypeConverters.toFloat, - ) - - outputFormat = Param( - Params._dummy(), - "outputFormat", - "output format, one of full, minimal", - typeConverter=TypeConverters.toString, - ) - - def getQueryIdentifierCol(self): - """ - Gets the value of queryIdentifierCol or its default value. - """ - return self.getOrDefault(self.queryIdentifierCol) - def getQueryPartitionsCol(self): """ Gets the value of queryPartitionsCol or its default value. """ return self.getOrDefault(self.queryPartitionsCol) - def getParallelism(self): - """ - Gets the value of parallelism or its default value. 
- """ - return self.getOrDefault(self.parallelism) - def getK(self): """ Gets the value of k or its default value. """ return self.getOrDefault(self.k) - def getExcludeSelf(self): - """ - Gets the value of excludeSelf or its default value. - """ - return self.getOrDefault(self.excludeSelf) - - def getSimilarityThreshold(self): - """ - Gets the value of similarityThreshold or its default value. - """ - return self.getOrDefault(self.similarityThreshold) - - def getOutputFormat(self): - """ - Gets the value of outputFormat or its default value. - """ - return self.getOrDefault(self.outputFormat) - - def getNumReplicas(self): - """ - Gets the value of numReplicas or its default value. - """ - return self.getOrDefault(self.numReplicas) - # noinspection PyPep8Naming @inherit_doc @@ -164,6 +98,20 @@ class _KnnParams(_KnnModelParams): Params for knn algorithms. """ + numThreads = Param( + Params._dummy(), + "numThreads", + "number of threads to use", + typeConverter=TypeConverters.toInt, + ) + + numReplicas = Param( + Params._dummy(), + "numReplicas", + "number of index replicas to create when querying", + typeConverter=TypeConverters.toInt, + ) + identifierCol = Param( Params._dummy(), "identifierCol", @@ -201,6 +149,18 @@ class _KnnParams(_KnnModelParams): typeConverter=TypeConverters.toString, ) + def getNumThreads(self) -> int: + """ + Gets the value of numThreads. + """ + return self.getOrDefault(self.numThreads) + + def getNumReplicas(self) -> int: + """ + Gets the value of numReplicas or its default value. + """ + return self.getOrDefault(self.numReplicas) + def getIdentifierCol(self): """ Gets the value of identifierCol or its default value. @@ -239,19 +199,6 @@ class _HnswModelParams(_KnnModelParams): Params for :py:class:`Hnsw` and :py:class:`HnswModel`. """ - ef = Param( - Params._dummy(), - "ef", - "size of the dynamic list for the nearest neighbors (used during the search)", - typeConverter=TypeConverters.toInt, - ) - - def getEf(self): - """ - Gets the value of ef or its default value. - """ - return self.getOrDefault(self.ef) - # noinspection PyPep8Naming @inherit_doc @@ -260,6 +207,13 @@ class _HnswParams(_HnswModelParams, _KnnParams): Params for :py:class:`Hnsw`. """ + ef = Param( + Params._dummy(), + "ef", + "size of the dynamic list for the nearest neighbors (used during the search)", + typeConverter=TypeConverters.toInt, + ) + m = Param( Params._dummy(), "m", @@ -274,6 +228,12 @@ class _HnswParams(_HnswModelParams, _KnnParams): typeConverter=TypeConverters.toInt, ) + def getEf(self): + """ + Gets the value of ef or its default value. + """ + return self.getOrDefault(self.ef) + def getM(self): """ Gets the value of m or its default value. @@ -294,23 +254,22 @@ class BruteForceSimilarity(JavaEstimator, _KnnParams, JavaMLReadable, JavaMLWrit Exact nearest neighbour search. 
""" + _input_kwargs: Dict[str, Any] + + # noinspection PyUnusedLocal @keyword_only def __init__( self, identifierCol="id", partitionCol=None, - queryIdentifierCol=None, queryPartitionsCol=None, - parallelism=None, + numThreads=None, featuresCol="features", predictionCol="prediction", - numPartitions=1, + numPartitions=None, numReplicas=0, k=5, distanceFunction="cosine", - excludeSelf=False, - similarityThreshold=-1.0, - outputFormat="full", initialModelPath=None, ): super(BruteForceSimilarity, self).__init__() @@ -324,9 +283,6 @@ def __init__( numReplicas=0, k=5, distanceFunction="cosine", - excludeSelf=False, - similarityThreshold=-1.0, - outputFormat="full", ) kwargs = self._input_kwargs @@ -338,12 +294,6 @@ def setIdentifierCol(self, value): """ return self._set(identifierCol=value) - def setQueryIdentifierCol(self, value): - """ - Sets the value of :py:attr:`queryIdentifierCol`. - """ - return self._set(queryIdentifierCol=value) - def setPartitionCol(self, value): """ Sets the value of :py:attr:`partitionCol`. @@ -356,11 +306,11 @@ def setQueryPartitionsCol(self, value): """ return self._set(queryPartitionsCol=value) - def setParallelism(self, value): + def setNumThreads(self, value): """ - Sets the value of :py:attr:`parallelism`. + Sets the value of :py:attr:`numThreads`. """ - return self._set(parallelism=value) + return self._set(numThreads=value) def setNumPartitions(self, value): """ @@ -386,46 +336,25 @@ def setDistanceFunction(self, value): """ return self._set(distanceFunction=value) - def setExcludeSelf(self, value): - """ - Sets the value of :py:attr:`excludeSelf`. - """ - return self._set(excludeSelf=value) - - def setSimilarityThreshold(self, value): - """ - Sets the value of :py:attr:`similarityThreshold`. - """ - return self._set(similarityThreshold=value) - - def setOutputFormat(self, value): - """ - Sets the value of :py:attr:`outputFormat`. - """ - return self._set(outputFormat=value) - def setInitialModelPath(self, value): """ Sets the value of :py:attr:`initialModelPath`. """ return self._set(initialModelPath=value) + # noinspection PyUnusedLocal @keyword_only def setParams( self, identifierCol="id", - queryIdentifierCol=None, queryPartitionsCol=None, - parallelism=None, + numThreads=None, featuresCol="features", predictionCol="prediction", numPartitions=1, numReplicas=0, k=5, distanceFunction="cosine", - excludeSelf=False, - similarityThreshold=-1.0, - outputFormat="full", initialModelPath=None, ): kwargs = self._input_kwargs @@ -437,7 +366,7 @@ def _create_model(self, java_model): # noinspection PyPep8Naming class BruteForceSimilarityModel( - JavaModel, _KnnModelParams, JavaMLReadable, JavaMLWritable + KnnModel, _KnnModelParams, JavaMLReadable, JavaMLWritable ): """ Model fitted by BruteForce. @@ -447,54 +376,18 @@ class BruteForceSimilarityModel( "com.github.jelmerk.spark.knn.bruteforce.BruteForceSimilarityModel" ) - def setQueryIdentifierCol(self, value): - """ - Sets the value of :py:attr:`queryIdentifierCol`. - """ - return self._set(queryIdentifierCol=value) - def setQueryPartitionsCol(self, value): """ Sets the value of :py:attr:`queryPartitionsCol`. """ return self._set(queryPartitionsCol=value) - def setParallelism(self, value): - """ - Sets the value of :py:attr:`parallelism`. - """ - return self._set(parallelism=value) - def setK(self, value): """ Sets the value of :py:attr:`k`. """ return self._set(k=value) - def setExcludeSelf(self, value): - """ - Sets the value of :py:attr:`excludeSelf`. 
- """ - return self._set(excludeSelf=value) - - def setSimilarityThreshold(self, value): - """ - Sets the value of :py:attr:`similarityThreshold`. - """ - return self._set(similarityThreshold=value) - - def setOutputFormat(self, value): - """ - Sets the value of :py:attr:`outputFormat`. - """ - return self._set(outputFormat=value) - - def setNumReplicas(self, value): - """ - Sets the value of :py:attr:`numReplicas`. - """ - return self._set(numReplicas=value) - @classmethod def read(cls): return HnswLibMLReader(cls, cls._classpath_model) @@ -507,13 +400,13 @@ class HnswSimilarity(JavaEstimator, _HnswParams, JavaMLReadable, JavaMLWritable) Approximate nearest neighbour search. """ + # noinspection PyUnusedLocal @keyword_only def __init__( self, identifierCol="id", - queryIdentifierCol=None, queryPartitionsCol=None, - parallelism=None, + numThreads=None, featuresCol="features", predictionCol="prediction", m=16, @@ -523,9 +416,6 @@ def __init__( numReplicas=0, k=5, distanceFunction="cosine", - excludeSelf=False, - similarityThreshold=-1.0, - outputFormat="full", initialModelPath=None, ): super(HnswSimilarity, self).__init__() @@ -542,9 +432,6 @@ def __init__( numReplicas=0, k=5, distanceFunction="cosine", - excludeSelf=False, - similarityThreshold=-1.0, - outputFormat="full", initialModelPath=None, ) @@ -557,12 +444,6 @@ def setIdentifierCol(self, value): """ return self._set(identifierCol=value) - def setQueryIdentifierCol(self, value): - """ - Sets the value of :py:attr:`queryIdentifierCol`. - """ - return self._set(queryIdentifierCol=value) - def setPartitionCol(self, value): """ Sets the value of :py:attr:`partitionCol`. @@ -575,11 +456,11 @@ def setQueryPartitionsCol(self, value): """ return self._set(queryPartitionsCol=value) - def setParallelism(self, value): + def setNumThreads(self, value): """ - Sets the value of :py:attr:`parallelism`. + Sets the value of :py:attr:`numThreads`. """ - return self._set(parallelism=value) + return self._set(numThreads=value) def setNumPartitions(self, value): """ @@ -605,24 +486,6 @@ def setDistanceFunction(self, value): """ return self._set(distanceFunction=value) - def setExcludeSelf(self, value): - """ - Sets the value of :py:attr:`excludeSelf`. - """ - return self._set(excludeSelf=value) - - def setSimilarityThreshold(self, value): - """ - Sets the value of :py:attr:`similarityThreshold`. - """ - return self._set(similarityThreshold=value) - - def setOutputFormat(self, value): - """ - Sets the value of :py:attr:`outputFormat`. - """ - return self._set(outputFormat=value) - def setM(self, value): """ Sets the value of :py:attr:`m`. @@ -647,25 +510,22 @@ def setInitialModelPath(self, value): """ return self._set(initialModelPath=value) + # noinspection PyUnusedLocal @keyword_only def setParams( self, identifierCol="id", - queryIdentifierCol=None, queryPartitionsCol=None, - parallelism=None, + numThreads=None, featuresCol="features", predictionCol="prediction", m=16, ef=10, efConstruction=200, - numPartitions=1, + numPartitions=None, numReplicas=0, k=5, distanceFunction="cosine", - excludeSelf=False, - similarityThreshold=-1.0, - outputFormat="full", initialModelPath=None, ): kwargs = self._input_kwargs @@ -676,67 +536,25 @@ def _create_model(self, java_model): # noinspection PyPep8Naming -class HnswSimilarityModel(JavaModel, _HnswModelParams, JavaMLReadable, JavaMLWritable): +class HnswSimilarityModel(KnnModel, _HnswModelParams, JavaMLReadable, JavaMLWritable): """ Model fitted by Hnsw. 
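To illustrate the user-facing effect of the knn.py changes above: `parallelism` is renamed to `numThreads`, the `queryIdentifierCol`, `excludeSelf`, `similarityThreshold` and `outputFormat` params are dropped, and fitted models gain a `destroy()` method. The following is a minimal PySpark sketch of the redesigned API; the Spark session, example data and parameter values are illustrative only and assume the hnswlib-spark package and jar are available, they are not part of this patch.

# Illustrative usage sketch; data and parameter values are hypothetical.
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession

from pyspark_hnsw.knn import HnswSimilarity

spark = SparkSession.builder.getOrCreate()

items = spark.createDataFrame(
    [(1, Vectors.dense([0.1, 0.2, 0.3])), (2, Vectors.dense([0.2, 0.3, 0.4]))],
    ["id", "features"],
)

hnsw = HnswSimilarity(
    identifierCol="id",
    featuresCol="features",
    distanceFunction="cosine",
    m=16,
    ef=10,
    efConstruction=200,
    numPartitions=1,
    numReplicas=0,   # extra copies of each partition, for query throughput
    numThreads=2,    # replaces the old parallelism parameter
    k=5,
)

model = hnsw.fit(items)                      # builds the partitioned index
model.transform(items).show(truncate=False)  # adds the prediction column with k neighbors

model.destroy()                              # releases the resources backing the model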
""" _classpath_model = "com.github.jelmerk.spark.knn.hnsw.HnswSimilarityModel" - def setQueryIdentifierCol(self, value): - """ - Sets the value of :py:attr:`queryIdentifierCol`. - """ - return self._set(queryIdentifierCol=value) - def setQueryPartitionsCol(self, value): """ Sets the value of :py:attr:`queryPartitionsCol`. """ return self._set(queryPartitionsCol=value) - def setParallelism(self, value): - """ - Sets the value of :py:attr:`parallelism`. - """ - return self._set(parallelism=value) - def setK(self, value): """ Sets the value of :py:attr:`k`. """ return self._set(k=value) - def setEf(self, value): - """ - Sets the value of :py:attr:`ef`. - """ - return self._set(ef=value) - - def setExcludeSelf(self, value): - """ - Sets the value of :py:attr:`excludeSelf`. - """ - return self._set(excludeSelf=value) - - def setSimilarityThreshold(self, value): - """ - Sets the value of :py:attr:`similarityThreshold`. - """ - return self._set(similarityThreshold=value) - - def setOutputFormat(self, value): - """ - Sets the value of :py:attr:`outputFormat`. - """ - return self._set(outputFormat=value) - - def setNumReplicas(self, value): - """ - Sets the value of :py:attr:`numReplicas`. - """ - return self._set(numReplicas=value) - @classmethod def read(cls): return HnswLibMLReader(cls, cls._classpath_model) diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/registration/client/RegistrationClient.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/registration/client/RegistrationClient.scala new file mode 100644 index 00000000..b08219eb --- /dev/null +++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/registration/client/RegistrationClient.scala @@ -0,0 +1,42 @@ +package com.github.jelmerk.registration.client + +import java.net.{InetSocketAddress, SocketAddress} + +import scala.concurrent.Await +import scala.concurrent.duration.Duration + +import com.github.jelmerk.server.registration.{RegisterRequest, RegisterResponse, RegistrationServiceGrpc} +import io.grpc.netty.NettyChannelBuilder + +object RegistrationClient { + + def register( + server: SocketAddress, + partitionNo: Int, + replicaNo: Int, + indexServerAddress: InetSocketAddress + ): RegisterResponse = { + val channel = NettyChannelBuilder + .forAddress(server) + .usePlaintext + .build() + + try { + val client = RegistrationServiceGrpc.stub(channel) + + val request = RegisterRequest( + partitionNum = partitionNo, + replicaNum = replicaNo, + indexServerAddress.getHostName, + indexServerAddress.getPort + ) + + val response = client.register(request) + + Await.result(response, Duration.Inf) + } finally { + channel.shutdownNow() + } + + } +} diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/registration/server/DefaultRegistrationService.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/registration/server/DefaultRegistrationService.scala new file mode 100644 index 00000000..c766b150 --- /dev/null +++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/registration/server/DefaultRegistrationService.scala @@ -0,0 +1,29 @@ +package com.github.jelmerk.registration.server + +import java.net.InetSocketAddress +import java.util.concurrent.{ConcurrentHashMap, CountDownLatch} + +import scala.concurrent.Future + +import com.github.jelmerk.server.registration.{RegisterRequest, RegisterResponse} +import com.github.jelmerk.server.registration.RegistrationServiceGrpc.RegistrationService + +class DefaultRegistrationService(val registrationLatch: CountDownLatch) extends RegistrationService { + + val registrations = new 
ConcurrentHashMap[PartitionAndReplica, InetSocketAddress]() + + override def register(request: RegisterRequest): Future[RegisterResponse] = { + + val key = PartitionAndReplica(request.partitionNum, request.replicaNum) + val previousValue = registrations.put(key, new InetSocketAddress(request.host, request.port)) + + if (previousValue == null) { + registrationLatch.countDown() + } + + Future.successful(RegisterResponse()) + } + +} + +case class PartitionAndReplica(partitionNum: Int, replicaNum: Int) diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/registration/server/RegistrationServerFactory.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/registration/server/RegistrationServerFactory.scala new file mode 100644 index 00000000..484d5192 --- /dev/null +++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/registration/server/RegistrationServerFactory.scala @@ -0,0 +1,47 @@ +package com.github.jelmerk.registration.server + +import java.net.InetSocketAddress +import java.util.concurrent.{CountDownLatch, Executors} + +import scala.concurrent.ExecutionContext +import scala.jdk.CollectionConverters._ +import scala.util.Try + +import com.github.jelmerk.server.registration.RegistrationServiceGrpc +import io.grpc.netty.NettyServerBuilder + +class RegistrationServer(host: String, numPartitions: Int, numReplicas: Int) { + + private val executor = Executors.newSingleThreadExecutor() + + private val executionContext: ExecutionContext = ExecutionContext.fromExecutor(executor) + + private val registrationLatch = new CountDownLatch(numPartitions + (numReplicas * numPartitions)) + private val service = new DefaultRegistrationService(registrationLatch) + + private val server = NettyServerBuilder + .forAddress(new InetSocketAddress(host, 0)) + .addService(RegistrationServiceGrpc.bindService(service, executionContext)) + .build() + + def start(): Unit = server.start() + + def address: InetSocketAddress = server.getListenSockets.get(0).asInstanceOf[InetSocketAddress] // TODO CLEANUP + + def awaitRegistrations(): Map[PartitionAndReplica, InetSocketAddress] = { + service.registrationLatch.await() + service.registrations.asScala.toMap + } + + def shutdown(): Unit = { + Try(server.shutdown()) + Try(executor.shutdown()) + } + +} + +object RegistrationServerFactory { + + def create(host: String, numPartitions: Int, numReplicas: Int): RegistrationServer = + new RegistrationServer(host, numPartitions, numReplicas) +} diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/serving/client/IndexClientFactory.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/serving/client/IndexClientFactory.scala new file mode 100644 index 00000000..50178f6f --- /dev/null +++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/serving/client/IndexClientFactory.scala @@ -0,0 +1,197 @@ +package com.github.jelmerk.serving.client + +import java.net.InetSocketAddress +import java.util.concurrent.{Executors, LinkedBlockingQueue} +import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger} + +import scala.concurrent.{Await, Future} +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.duration.Duration +import scala.language.implicitConversions +import scala.util.Random + +import com.github.jelmerk.registration.server.PartitionAndReplica +import com.github.jelmerk.server.index._ +import io.grpc.netty.NettyChannelBuilder +import io.grpc.stub.StreamObserver +import org.apache.spark.sql.Row + +class IndexClient[TId, TVector, TDistance]( + indexAddresses: Map[PartitionAndReplica, 
InetSocketAddress], + vectorConverter: TVector => SearchRequest.Vector, + idExtractor: Result => TId, + distanceExtractor: Result => TDistance, + distanceOrdering: Ordering[TDistance] +) { + + private val random = new Random() + + private val (channels, grpcClients) = indexAddresses.map { case (key, address) => + val channel = NettyChannelBuilder + .forAddress(address) + .usePlaintext + .build() + + (channel, (key, IndexServiceGrpc.stub(channel))) + }.unzip + + private val partitionClients = grpcClients.toList + .sortBy { case (partitionAndReplica, _) => (partitionAndReplica.partitionNum, partitionAndReplica.replicaNum) } + .foldLeft(Map.empty[Int, Seq[IndexServiceGrpc.IndexServiceStub]]) { + case (acc, (PartitionAndReplica(partitionNum, _), client)) => + val old = acc.getOrElse(partitionNum, Seq.empty[IndexServiceGrpc.IndexServiceStub]) + acc.updated(partitionNum, old :+ client) + } + + private val allPartitions = indexAddresses.map(_._1.partitionNum).toSeq.distinct // TODO not very nice + + private val threadPool = Executors.newFixedThreadPool(1) + + def search( + vectorColumn: String, + queryPartitionsColumn: Option[String], + batch: Seq[Row], + k: Int + ): Iterator[Row] = { + + val queries = batch.map { row => + val partitions = queryPartitionsColumn.fold(allPartitions) { name => row.getAs[Seq[Int]](name) } + val vector = row.getAs[TVector](vectorColumn) + (partitions, vector) + } + + // TODO should i use a random client or a client + val randomClient = partitionClients.map { case (_, clients) => clients(random.nextInt(clients.size)) } + + val (requestObservers, responseIterators) = randomClient.zipWithIndex.toArray.map { case (client, partition) => + // TODO this is kind of inefficient + val partitionCount = queries.count { case (partitions, _) => partitions.contains(partition) } + + val responseStreamObserver = new StreamObserverAdapter[SearchResponse](partitionCount) + val requestStreamObserver = client.search(responseStreamObserver) + + (requestStreamObserver, responseStreamObserver: Iterator[SearchResponse]) + }.unzip + + threadPool.submit(new Runnable { + override def run(): Unit = { + val queriesIterator = queries.iterator + + for { + (queryPartitions, vector) <- queriesIterator + last = !queriesIterator.hasNext + (observer, observerPartition) <- requestObservers.zipWithIndex + } { + + if (queryPartitions.contains(observerPartition)) { + val request = SearchRequest( + vector = vectorConverter(vector), + k = k + ) + observer.onNext(request) + } + if (last) { + observer.onCompleted() + } + } + } + }) + + val expectations = batch.zip(queries).map { case (row, (partitions, _)) => partitions -> row }.iterator + + new ResultsIterator(expectations, responseIterators: Array[Iterator[SearchResponse]], k) + } + + def saveIndex(path: String): Unit = { + val futures = partitionClients.flatMap { case (partition, clients) => + // only the primary replica saves the index + clients.headOption.map { client => + val request = SaveIndexRequest(s"$path/$partition") + client.saveIndex(request) + } + } + + val responses = Await.result(Future.sequence(futures), Duration.Inf) // TODO not sure if inf is smart + responses.foreach(println) // TODO remove + } + + def shutdown(): Unit = { + channels.foreach(_.shutdown()) + threadPool.shutdown() + } + + private class StreamObserverAdapter[T](expected: Int) extends StreamObserver[T] with Iterator[T] { + + private val queue = new LinkedBlockingQueue[Either[Throwable, T]] + private val counter = new AtomicInteger() + private val done = new AtomicBoolean(false) 
+ + // ======================================== StreamObserver ======================================== + + override def onNext(value: T): Unit = { + queue.add(Right(value)) + counter.incrementAndGet() + } + + override def onError(t: Throwable): Unit = { + queue.add(Left(t)) + done.set(true) + } + + override def onCompleted(): Unit = { + done.set(true) + } + + // ========================================== Iterator ========================================== + + override def hasNext: Boolean = { + !queue.isEmpty || (counter.get() < expected && !done.get()) + } + + override def next(): T = queue.take() match { + case Right(value) => value + case Left(t) => throw t + } + } + + private class ResultsIterator( + iterator: Iterator[(Seq[Int], Row)], + partitionIterators: Array[Iterator[SearchResponse]], + k: Int + ) extends Iterator[Row] { + + override def hasNext: Boolean = iterator.hasNext + + override def next(): Row = { + val (partitions, row) = iterator.next() + + val responses = partitions.map(partitionIterators.apply).map(_.next()) + + val allResults = for { + response <- responses + result <- response.results + } yield idExtractor(result) -> distanceExtractor(result) + + val results = allResults + .sortBy(_._2)(distanceOrdering) + .take(k) + .map { case (id, distance) => Row(id, distance) } + + Row.fromSeq(row.toSeq :+ results) + } + } + +} + +class IndexClientFactory[TId, TVector, TDistance]( + vectorConverter: TVector => SearchRequest.Vector, + idExtractor: Result => TId, + distanceExtractor: Result => TDistance, + distanceOrdering: Ordering[TDistance] +) extends Serializable { + + def create(servers: Map[PartitionAndReplica, InetSocketAddress]): IndexClient[TId, TVector, TDistance] = { + new IndexClient(servers, vectorConverter, idExtractor, distanceExtractor, distanceOrdering) + } + +} diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/serving/server/DefaultIndexService.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/serving/server/DefaultIndexService.scala new file mode 100644 index 00000000..dcf0e37f --- /dev/null +++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/serving/server/DefaultIndexService.scala @@ -0,0 +1,65 @@ +package com.github.jelmerk.serving.server + +import scala.concurrent.{ExecutionContext, Future} +import scala.language.implicitConversions + +import com.github.jelmerk.knn.scalalike.{Index, Item} +import com.github.jelmerk.server.index._ +import com.github.jelmerk.server.index.IndexServiceGrpc.IndexService +import io.grpc.stub.StreamObserver +import org.apache.commons.io.output.CountingOutputStream +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + +class DefaultIndexService[TId, TVector, TItem <: Item[TId, TVector], TDistance]( + index: Index[TId, TVector, TItem, TDistance], + hadoopConfiguration: Configuration, + vectorExtractor: SearchRequest => TVector, + resultIdConverter: TId => Result.Id, + resultDistanceConverter: TDistance => Result.Distance +)(implicit + executionContext: ExecutionContext +) extends IndexService { + + override def search(responseObserver: StreamObserver[SearchResponse]): StreamObserver[SearchRequest] = { + new StreamObserver[SearchRequest] { + override def onNext(request: SearchRequest): Unit = { + + val vector = vectorExtractor(request) + val nearest = index.findNearest(vector, request.k) + + val results = nearest.map { searchResult => + val id = resultIdConverter(searchResult.item.id) + val distance = resultDistanceConverter(searchResult.distance) + + Result(id, distance) + } + + val 
response = SearchResponse(results) + responseObserver.onNext(response) + } + + override def onError(t: Throwable): Unit = { + responseObserver.onError(t) + } + + override def onCompleted(): Unit = { + responseObserver.onCompleted() + } + } + } + + override def saveIndex(request: SaveIndexRequest): Future[SaveIndexResponse] = Future { + val path = new Path(request.path) + val fileSystem = path.getFileSystem(hadoopConfiguration) + + val outputStream = fileSystem.create(path) + + val countingOutputStream = new CountingOutputStream(outputStream) + + index.save(countingOutputStream) + + SaveIndexResponse(bytesWritten = countingOutputStream.getCount) + } + +} diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/serving/server/IndexServerFactory.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/serving/server/IndexServerFactory.scala new file mode 100644 index 00000000..3429e939 --- /dev/null +++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/serving/server/IndexServerFactory.scala @@ -0,0 +1,84 @@ +package com.github.jelmerk.serving.server + +import java.net.{InetAddress, InetSocketAddress} +import java.util.concurrent.{LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit} + +import scala.concurrent.ExecutionContext +import scala.util.Try + +import com.github.jelmerk.knn.scalalike.{Index, Item} +import com.github.jelmerk.server.index.{IndexServiceGrpc, Result, SearchRequest} +import io.grpc.netty.NettyServerBuilder +import org.apache.hadoop.conf.Configuration + +class IndexServer[TId, TVector, TItem <: Item[TId, TVector] with Product, TDistance]( + host: String, + vectorExtractor: SearchRequest => TVector, + resultIdConverter: TId => Result.Id, + resultDistanceConverter: TDistance => Result.Distance, + index: Index[TId, TVector, TItem, TDistance], + hadoopConfig: Configuration, + threads: Int +) { + private val executor = new ThreadPoolExecutor( + threads, + threads, + 0L, + TimeUnit.MILLISECONDS, + new LinkedBlockingQueue[Runnable]() + ) + + private val executionContext: ExecutionContext = ExecutionContext.fromExecutor(executor) + + private implicit val ec: ExecutionContext = ExecutionContext.global + private val service = + new DefaultIndexService(index, hadoopConfig, vectorExtractor, resultIdConverter, resultDistanceConverter) + + // Build the gRPC server + private val server = NettyServerBuilder + .forAddress(new InetSocketAddress(host, 0)) + .addService(IndexServiceGrpc.bindService(service, executionContext)) + .build() + + def start(): Unit = server.start() + + def address: InetSocketAddress = server.getListenSockets.get(0).asInstanceOf[InetSocketAddress] // TODO CLEANUP + + def awaitTermination(): Unit = { + server.awaitTermination() + } + + def isTerminated(): Boolean = { + server.isTerminated + } + + def shutdown(): Unit = { + Try(server.shutdown()) + Try(executor.shutdown()) + } +} + +class IndexServerFactory[TId, TVector, TItem <: Item[TId, TVector] with Product, TDistance]( + vectorExtractor: SearchRequest => TVector, + resultIdConverter: TId => Result.Id, + resultDistanceConverter: TDistance => Result.Distance +) extends Serializable { + + def create( + host: String, + index: Index[TId, TVector, TItem, TDistance], + hadoopConfig: Configuration, + threads: Int + ): IndexServer[TId, TVector, TItem, TDistance] = { + new IndexServer[TId, TVector, TItem, TDistance]( + host, + vectorExtractor, + resultIdConverter, + resultDistanceConverter, + index, + hadoopConfig, + threads + ) + + } +} diff --git 
a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/KnnAlgorithm.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/KnnAlgorithm.scala index c51e9dc0..e237704a 100644 --- a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/KnnAlgorithm.scala +++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/KnnAlgorithm.scala @@ -1,33 +1,22 @@ package com.github.jelmerk.spark.knn import java.io.InputStream -import java.net.InetAddress -import java.util.concurrent.{ - CountDownLatch, - ExecutionException, - FutureTask, - LinkedBlockingQueue, - ThreadLocalRandom, - ThreadPoolExecutor, - TimeUnit -} +import java.net.{InetAddress, InetSocketAddress} -import scala.Seq -import scala.annotation.tailrec import scala.language.{higherKinds, implicitConversions} import scala.reflect.ClassTag import scala.reflect.runtime.universe._ -import scala.util.Try import scala.util.control.NonFatal -import com.github.jelmerk.knn.{Jdk17DistanceFunctions, ObjectSerializer} +import com.github.jelmerk.knn.ObjectSerializer import com.github.jelmerk.knn.scalalike._ -import com.github.jelmerk.knn.scalalike.jdk17DistanceFunctions._ -import com.github.jelmerk.knn.util.NamedThreadFactory -import com.github.jelmerk.spark.linalg.functions.VectorDistanceFunctions +import com.github.jelmerk.registration.client.RegistrationClient +import com.github.jelmerk.registration.server.{PartitionAndReplica, RegistrationServerFactory} +import com.github.jelmerk.serving.client.IndexClientFactory +import com.github.jelmerk.serving.server.IndexServerFactory import com.github.jelmerk.spark.util.SerializableConfiguration -import org.apache.hadoop.fs.{FileUtil, Path} -import org.apache.spark.{Partitioner, TaskContext} +import org.apache.hadoop.fs.Path +import org.apache.spark.{Partitioner, SparkContext, SparkEnv, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.linalg.SQLDataTypes.VectorType @@ -35,156 +24,186 @@ import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol} import org.apache.spark.ml.util.{MLReader, MLWriter} -import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} +import org.apache.spark.rdd.RDD +import org.apache.spark.resource.{ResourceProfileBuilder, TaskResourceRequests} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.json4s._ -import org.json4s.jackson.JsonMethods._ - -private[knn] case class IntDoubleArrayIndexItem(id: Int, vector: Array[Double]) extends Item[Int, Array[Double]] { - override def dimensions: Int = vector.length -} - -private[knn] case class LongDoubleArrayIndexItem(id: Long, vector: Array[Double]) extends Item[Long, Array[Double]] { - override def dimensions: Int = vector.length -} - -private[knn] case class StringDoubleArrayIndexItem(id: String, vector: Array[Double]) - extends Item[String, Array[Double]] { - override def dimensions: Int = vector.length -} - -private[knn] case class IntFloatArrayIndexItem(id: Int, vector: Array[Float]) extends Item[Int, Array[Float]] { - override def dimensions: Int = vector.length -} - -private[knn] case class LongFloatArrayIndexItem(id: Long, vector: Array[Float]) extends Item[Long, 
Array[Float]] { - override def dimensions: Int = vector.length -} - -private[knn] case class StringFloatArrayIndexItem(id: String, vector: Array[Float]) extends Item[String, Array[Float]] { - override def dimensions: Int = vector.length -} - -private[knn] case class IntVectorIndexItem(id: Int, vector: Vector) extends Item[Int, Vector] { - override def dimensions: Int = vector.size -} +import org.json4s.jackson.Serialization.{read, write} + +private[knn] case class ModelMetaData( + `class`: String, + timestamp: Long, + sparkVersion: String, + uid: String, + identifierType: String, + vectorType: String, + numPartitions: Int, + numReplicas: Int, + numThreads: Int, + paramMap: Map[String, Any] +) + +private[knn] trait IndexType { + + /** Type of index. */ + protected type TIndex[TId, TVector, TItem <: Item[TId, TVector], TDistance] <: Index[TId, TVector, TItem, TDistance] -private[knn] case class LongVectorIndexItem(id: Long, vector: Vector) extends Item[Long, Vector] { - override def dimensions: Int = vector.size + protected implicit def indexClassTag[TId: ClassTag, TVector: ClassTag, TItem <: Item[ + TId, + TVector + ]: ClassTag, TDistance: ClassTag]: ClassTag[TIndex[TId, TVector, TItem, TDistance]] } -private[knn] case class StringVectorIndexItem(id: String, vector: Vector) extends Item[String, Vector] { - override def dimensions: Int = vector.size -} +private[knn] trait IndexCreator extends IndexType { -/** Neighbor of an item. - * - * @param neighbor - * identifies the neighbor - * @param distance - * distance to the item - * - * @tparam TId - * type of the index item identifier - * @tparam TDistance - * type of distance - */ -private[knn] case class Neighbor[TId, TDistance](neighbor: TId, distance: TDistance) - -/** Common params for KnnAlgorithm and KnnModel. - */ -private[knn] trait KnnModelParams extends Params with HasFeaturesCol with HasPredictionCol { - - /** Param for the column name for the query identifier. + /** Create the index used to do the nearest neighbor search. * - * @group param - */ - final val queryIdentifierCol = new Param[String](this, "queryIdentifierCol", "column name for the query identifier") - - /** @group getParam */ - final def getQueryIdentifierCol: String = $(queryIdentifierCol) - - /** Param for the column name for the query partitions. 
+ * @param dimensions + * dimensionality of the items stored in the index + * @param maxItemCount + * maximum number of items the index can hold + * @param distanceFunction + * the distance function + * @param distanceOrdering + * the distance ordering + * @param idSerializer + * invoked for serializing ids when saving the index + * @param itemSerializer + * invoked for serializing items when saving items * - * @group param + * @tparam TId + * type of the index item identifier + * @tparam TVector + * type of the index item vector + * @tparam TItem + * type of the index item + * @tparam TDistance + * type of distance between items + * @return + * the created index */ - final val queryPartitionsCol = new Param[String](this, "queryPartitionsCol", "column name for the query partitions") - - /** @group getParam */ - final def getQueryPartitionsCol: String = $(queryPartitionsCol) + protected def createIndex[ + TId, + TVector, + TItem <: Item[TId, TVector] with Product, + TDistance + ](dimensions: Int, maxItemCount: Int, distanceFunction: DistanceFunction[TVector, TDistance])(implicit + distanceOrdering: Ordering[TDistance], + idSerializer: ObjectSerializer[TId], + itemSerializer: ObjectSerializer[TItem] + ): TIndex[TId, TVector, TItem, TDistance] - /** Param for number of neighbors to find (> 0). Default: 5 + /** Create an immutable empty index. * - * @group param + * @tparam TId + * type of the index item identifier + * @tparam TVector + * type of the index item vector + * @tparam TItem + * type of the index item + * @tparam TDistance + * type of distance between items + * @return + * the created index */ - final val k = new IntParam(this, "k", "number of neighbors to find", ParamValidators.gt(0)) + protected def emptyIndex[ + TId, + TVector, + TItem <: Item[TId, TVector] with Product, + TDistance + ]: TIndex[TId, TVector, TItem, TDistance] +} - /** @group getParam */ - final def getK: Int = $(k) +private[knn] trait IndexLoader extends IndexType { - /** Param that indicates whether to not return the a candidate when it's identifier equals the query identifier - * Default: false + /** Load an index * - * @group param - */ - final val excludeSelf = - new BooleanParam(this, "excludeSelf", "whether to include the row identifier as a candidate neighbor") - - /** @group getParam */ - final def getExcludeSelf: Boolean = $(excludeSelf) - - /** Param for the threshold value for inclusion. -1 indicates no threshold Default: -1 + * @param inputStream + * InputStream to restore the index from + * @param minCapacity + * loaded index needs to have space for at least this man additional items * - * @group param + * @tparam TId + * type of the index item identifier + * @tparam TVector + * type of the index item vector + * @tparam TItem + * type of the index item + * @tparam TDistance + * type of distance between items + * @return + * create an index */ - final val similarityThreshold = - new DoubleParam(this, "similarityThreshold", "do not return neighbors further away than this distance") + protected def loadIndex[TId, TVector, TItem <: Item[TId, TVector] with Product, TDistance]( + inputStream: InputStream, + minCapacity: Int + ): TIndex[TId, TVector, TItem, TDistance] +} - /** @group getParam */ - final def getSimilarityThreshold: Double = $(similarityThreshold) +private[knn] trait ModelCreator[TModel <: KnnModelBase[TModel]] { - /** Param that specifies the number of index replicas to create when querying the index. 
More replicas means you can - * execute more queries in parallel at the expense of increased resource usage. Default: 0 + /** Creates the model to be returned from fitting the data. * - * @group param + * @param uid + * identifier + * @param indices + * map of index servers + * @tparam TId + * type of the index item identifier + * @tparam TVector + * type of the index item vector + * @tparam TItem + * type of the index item + * @tparam TDistance + * type of distance between items + * @return + * model */ - final val numReplicas = new IntParam(this, "numReplicas", "number of index replicas to create when querying") + protected def createModel[ + TId: TypeTag, + TVector: TypeTag, + TItem <: Item[TId, TVector] with Product: TypeTag, + TDistance: TypeTag + ]( + uid: String, + numPartitions: Int, + numReplicas: Int, + numThreads: Int, + sparkContext: SparkContext, + indices: Map[PartitionAndReplica, InetSocketAddress], + clientFactory: IndexClientFactory[TId, TVector, TDistance] + ): TModel +} - /** @group getParam */ - final def getNumReplicas: Int = $(numReplicas) +/** Common params for KnnAlgorithm and KnnModel. */ +private[knn] trait KnnModelParams extends Params with HasFeaturesCol with HasPredictionCol { - /** Param that specifies the number of threads to use. Default: number of processors available to the Java virtual - * machine + /** Param for the column name for the query partitions. * * @group param */ - final val parallelism = new IntParam(this, "parallelism", "number of threads to use") + final val queryPartitionsCol = new Param[String](this, "queryPartitionsCol", "column name for the query partitions") /** @group getParam */ - final def getParallelism: Int = $(parallelism) + final def getQueryPartitionsCol: String = $(queryPartitionsCol) - /** Param for the output format to produce. One of "full", "minimal" Setting this to minimal is more efficient when - * all you need is the identifier with its neighbors - * - * Default: "full" + /** Param for number of neighbors to find (> 0). 
Default: 5 * * @group param */ - final val outputFormat = new Param[String](this, "outputFormat", "output format to produce") + final val k = new IntParam(this, "k", "number of neighbors to find", ParamValidators.gt(0)) /** @group getParam */ - final def getOutputFormat: String = $(outputFormat) + final def getK: Int = $(k) setDefault( - k -> 5, - predictionCol -> "prediction", - featuresCol -> "features", - excludeSelf -> false, - similarityThreshold -> -1, - outputFormat -> "full" + k -> 5, + predictionCol -> "prediction", + featuresCol -> "features" ) protected def validateAndTransformSchema(schema: StructType, identifierDataType: DataType): StructType = { @@ -200,20 +219,11 @@ private[knn] trait KnnModelParams extends Params with HasFeaturesCol with HasPre val neighborsField = StructField(getPredictionCol, new ArrayType(predictionStruct, containsNull = false)) - getOutputFormat match { - case "minimal" if !isSet(queryIdentifierCol) => - throw new IllegalArgumentException("queryIdentifierCol must be set when using outputFormat minimal.") - case "minimal" => - new StructType() - .add(schema(getQueryIdentifierCol)) - .add(neighborsField) - case _ => - if (schema.fieldNames.contains(getPredictionCol)) { - throw new IllegalArgumentException(s"Output column $getPredictionCol already exists.") - } - schema - .add(neighborsField) + if (schema.fieldNames.contains(getPredictionCol)) { + throw new IllegalArgumentException(s"Output column $getPredictionCol already exists.") } + schema + .add(neighborsField) } } @@ -230,13 +240,33 @@ private[knn] trait KnnAlgorithmParams extends KnnModelParams { /** @group getParam */ final def getIdentifierCol: String = $(identifierCol) - /** Number of partitions (default: 1) + /** Number of partitions */ final val numPartitions = new IntParam(this, "numPartitions", "number of partitions", ParamValidators.gt(0)) /** @group getParam */ final def getNumPartitions: Int = $(numPartitions) + /** Param that specifies the number of index replicas to create when querying the index. More replicas means you can + * execute more queries in parallel at the expense of increased resource usage. Default: 0 + * + * @group param + */ + final val numReplicas = + new IntParam(this, "numReplicas", "number of index replicas to create when querying", ParamValidators.gtEq(0)) + + /** @group getParam */ + final def getNumReplicas: Int = $(numReplicas) + + /** Param that specifies the number of threads to use. + * + * @group param + */ + final val numThreads = new IntParam(this, "numThreads", "number of threads to use per index", ParamValidators.gt(0)) + + /** @group getParam */ + final def getNumThreads: Int = $(numThreads) + /** Param for the distance function to use. One of "bray-curtis", "canberra", "cosine", "correlation", "euclidean", * "inner-product", "manhattan" or the fully qualified classname of a distance function Default: "cosine" * @@ -254,14 +284,18 @@ private[knn] trait KnnAlgorithmParams extends KnnModelParams { /** @group getParam */ final def getPartitionCol: String = $(partitionCol) - /** Param to the initial model. All the vectors from the initial model will included in the final output model. + /** Param to the initial model. All the vectors from the initial model will be included in the final output model. 
*/ final val initialModelPath = new Param[String](this, "initialModelPath", "path to the initial model") /** @group getParam */ final def getInitialModelPath: String = $(initialModelPath) - setDefault(identifierCol -> "id", distanceFunction -> "cosine", numPartitions -> 1, numReplicas -> 0) + setDefault(identifierCol -> "id", distanceFunction -> "cosine", numReplicas -> 0) +} + +object KnnModelWriter { + private implicit val format: Formats = DefaultFormats.withLong } /** Persists a knn model. @@ -286,67 +320,47 @@ private[knn] class KnnModelWriter[ TModel <: KnnModelBase[TModel], TId: TypeTag, TVector: TypeTag, - TItem <: Item[TId, TVector] with Product: TypeTag, - TDistance: TypeTag, + TItem <: Item[TId, TVector] with Product, + TDistance, TIndex <: Index[TId, TVector, TItem, TDistance] ](instance: TModel with KnnModelOps[TModel, TId, TVector, TItem, TDistance, TIndex]) extends MLWriter { + import KnnModelWriter._ + override protected def saveImpl(path: String): Unit = { - val params = JObject( - instance - .extractParamMap() - .toSeq - .toList - // cannot use parse because of incompatibilities between json4s 3.2.11 used by spark 2.3 and 3.6.6 used by spark 2.4 - .map { case ParamPair(param, value) => - val fieldName = param.name - val fieldValue = mapper.readValue(param.jsonEncode(value), classOf[JValue]) - JField(fieldName, fieldValue) - } - ) - val metaData = JObject( - List( - JField("class", JString(instance.getClass.getName)), - JField("timestamp", JLong(System.currentTimeMillis())), - JField("sparkVersion", JString(sc.version)), - JField("uid", JString(instance.uid)), - JField("identifierType", JString(typeDescription[TId])), - JField("vectorType", JString(typeDescription[TVector])), - JField("partitions", JInt(instance.getNumPartitions)), - JField("paramMap", params) - ) + val metadata = ModelMetaData( + `class` = instance.getClass.getName, + timestamp = System.currentTimeMillis(), + sparkVersion = sc.version, + uid = instance.uid, + identifierType = typeDescription[TId], + vectorType = typeDescription[TVector], + numPartitions = instance.numPartitions, + numReplicas = instance.numReplicas, + numThreads = instance.numThreads, + paramMap = toMap(instance.extractParamMap()) ) val metadataPath = new Path(path, "metadata").toString - sc.parallelize(Seq(compact(metaData)), numSlices = 1).saveAsTextFile(metadataPath) + sc.parallelize(Seq(write(metadata)), numSlices = 1).saveAsTextFile(metadataPath) val indicesPath = new Path(path, "indices").toString - val modelOutputDir = instance.outputDir - - val serializableConfiguration = new SerializableConfiguration(sc.hadoopConfiguration) - - sc.range(start = 0, end = instance.getNumPartitions).foreach { partitionId => - val originPath = new Path(modelOutputDir, partitionId.toString) - val originFileSystem = originPath.getFileSystem(serializableConfiguration.value) - - if (originFileSystem.exists(originPath)) { - val destinationPath = new Path(indicesPath, partitionId.toString) - val destinationFileSystem = destinationPath.getFileSystem(serializableConfiguration.value) - FileUtil.copy( - originFileSystem, - originPath, - destinationFileSystem, - destinationPath, - false, - serializableConfiguration.value - ) - } + val client = instance.clientFactory.create(instance.indexAddresses) + try { + client.saveIndex(indicesPath) + } finally { + client.shutdown() } } + private def toMap(paramMap: ParamMap): Map[String, Any] = + paramMap.toSeq.map { case ParamPair(param, value) => param.name -> value }.toMap + + // TODO should i make this an implicit like 
elsewhere + private def typeDescription[T: TypeTag] = typeOf[T] match { case t if t =:= typeOf[Int] => "int" case t if t =:= typeOf[Long] => "long" @@ -358,17 +372,23 @@ private[knn] class KnnModelWriter[ } } +object KnnModelReader { + private implicit val format: Formats = DefaultFormats.withLong +} + /** Reads a knn model from persistent storage. * - * @param ev - * classtag * @tparam TModel * type of model */ -private[knn] abstract class KnnModelReader[TModel <: KnnModelBase[TModel]](implicit ev: ClassTag[TModel]) - extends MLReader[TModel] { +private[knn] abstract class KnnModelReader[TModel <: KnnModelBase[TModel]] + extends MLReader[TModel] + with IndexLoader + with IndexServing + with ModelCreator[TModel] + with Serializable { - private implicit val format: Formats = DefaultFormats + import KnnModelReader._ override def load(path: String): TModel = { @@ -376,83 +396,95 @@ private[knn] abstract class KnnModelReader[TModel <: KnnModelBase[TModel]](impli val metadataStr = sc.textFile(metadataPath, 1).first() - // cannot use parse because of incompatibilities between json4s 3.2.11 used by spark 2.3 and 3.6.6 used by spark 2.4 - val metadata = mapper.readValue(metadataStr, classOf[JValue]) - - val uid = (metadata \ "uid").extract[String] - - val identifierType = (metadata \ "identifierType").extract[String] - val vectorType = (metadata \ "vectorType").extract[String] - val partitions = (metadata \ "partitions").extract[Int] - - val paramMap = (metadata \ "paramMap").extract[JObject] + val metadata = read[ModelMetaData](metadataStr) - val indicesPath = new Path(path, "indices").toString - - val model = (identifierType, vectorType) match { + val model = (metadata.identifierType, metadata.vectorType) match { case ("int", "float_array") => - createModel[Int, Array[Float], IntFloatArrayIndexItem, Float](uid, indicesPath, partitions) + typedLoad[Int, Array[Float], IntFloatArrayIndexItem, Float](path, metadata) case ("int", "double_array") => - createModel[Int, Array[Double], IntDoubleArrayIndexItem, Double](uid, indicesPath, partitions) - case ("int", "vector") => createModel[Int, Vector, IntVectorIndexItem, Double](uid, indicesPath, partitions) + typedLoad[Int, Array[Double], IntDoubleArrayIndexItem, Double](path, metadata) + case ("int", "vector") => + typedLoad[Int, Vector, IntVectorIndexItem, Double](path, metadata) case ("long", "float_array") => - createModel[Long, Array[Float], LongFloatArrayIndexItem, Float](uid, indicesPath, partitions) + typedLoad[Long, Array[Float], LongFloatArrayIndexItem, Float](path, metadata) case ("long", "double_array") => - createModel[Long, Array[Double], LongDoubleArrayIndexItem, Double](uid, indicesPath, partitions) - case ("long", "vector") => createModel[Long, Vector, LongVectorIndexItem, Double](uid, indicesPath, partitions) + typedLoad[Long, Array[Double], LongDoubleArrayIndexItem, Double](path, metadata) + case ("long", "vector") => + typedLoad[Long, Vector, LongVectorIndexItem, Double](path, metadata) case ("string", "float_array") => - createModel[String, Array[Float], StringFloatArrayIndexItem, Float](uid, indicesPath, partitions) + typedLoad[String, Array[Float], StringFloatArrayIndexItem, Float](path, metadata) case ("string", "double_array") => - createModel[String, Array[Double], StringDoubleArrayIndexItem, Double](uid, indicesPath, partitions) + typedLoad[String, Array[Double], StringDoubleArrayIndexItem, Double](path, metadata) case ("string", "vector") => - createModel[String, Vector, StringVectorIndexItem, Double](uid, indicesPath, 
partitions) - case _ => + typedLoad[String, Vector, StringVectorIndexItem, Double](path, metadata) + case (identifierType, vectorType) => throw new IllegalStateException( s"Cannot create model for identifier type $identifierType and vector type $vectorType." ) } - paramMap.obj.foreach { case (paramName, jsonValue) => + model + + } + + private def typedLoad[ + TId: TypeTag: ClassTag, + TVector: TypeTag: ClassTag, + TItem <: Item[TId, TVector] with Product: TypeTag: ClassTag, + TDistance: TypeTag: ClassTag + ](path: String, metadata: ModelMetaData)(implicit + indexServerFactory: IndexServerFactory[TId, TVector, TItem, TDistance], + clientFactory: IndexClientFactory[TId, TVector, TDistance] + ): TModel = { + + val indicesPath = new Path(path, "indices") + + val taskReqs = new TaskResourceRequests().cpus(metadata.numThreads) + val profile = new ResourceProfileBuilder().require(taskReqs).build() + + val serializableConfiguration = new SerializableConfiguration(sc.hadoopConfiguration) + + val partitionPaths = (0 until metadata.numPartitions).map { partitionId => + partitionId -> new Path(indicesPath, partitionId.toString) + } + + val indexRdd = sc + .makeRDD(partitionPaths) + .partitionBy(new PartitionIdPartitioner(metadata.numPartitions)) + .mapPartitions { it => + val (partitionId, indexPath) = it.next() + val fs = indexPath.getFileSystem(serializableConfiguration.value) + + logInfo(partitionId, s"Loading index from $indexPath") + val inputStream = fs.open(indexPath) + val index = loadIndex[TId, TVector, TItem, TDistance](inputStream, 0) + logInfo(partitionId, s"Finished loading index from $indexPath") + Iterator(index) + } + .withResources(profile) + + val servers = serve(metadata.uid, indexRdd, metadata.numPartitions, metadata.numReplicas, metadata.numThreads) + + val model = createModel( + metadata.uid, + metadata.numPartitions, + metadata.numReplicas, + metadata.numThreads, + sc, + servers, + clientFactory + ) + + metadata.paramMap.foreach { case (paramName, value) => val param = model.getParam(paramName) - model.set(param, param.jsonDecode(compact(render(jsonValue)))) + model.set(param, value) } model } - /** Creates the model to be returned from fitting the data. - * - * @param uid - * identifier - * @param outputDir - * directory containing the persisted indices - * @param numPartitions - * number of index partitions - * - * @tparam TId - * type of the index item identifier - * @tparam TVector - * type of the index item vector - * @tparam TItem - * type of the index item - * @tparam TDistance - * type of distance between items - * @return - * model - */ - protected def createModel[ - TId: TypeTag, - TVector: TypeTag, - TItem <: Item[TId, TVector] with Product: TypeTag, - TDistance: TypeTag - ](uid: String, outputDir: String, numPartitions: Int)(implicit - ev: ClassTag[TId], - evVector: ClassTag[TVector], - distanceNumeric: Numeric[TDistance] - ): TModel - } /** Base class for nearest neighbor search models. 
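To show how the new persistence path above is exercised from PySpark: the writer delegates saving to the index servers (only the primary replica of each partition writes its index under indices/<partition>), the reader reloads each partition inside a resource-profile-pinned task and serves it again, and destroy() cancels the job group that keeps those servers alive. A rough sketch follows, reusing the items DataFrame and fitted model from the earlier example; the path is hypothetical.

# `items` and `model` come from the previous sketch; the path is hypothetical.
from pyspark_hnsw.knn import HnswSimilarityModel

path = "/tmp/hnsw-model"

model.write().overwrite().save(path)         # partition indices end up under <path>/indices/<partition>

restored = HnswSimilarityModel.load(path)    # reloads each partition and starts serving it again
restored.transform(items).show(truncate=False)

restored.destroy()                           # cancels the job group backing the restored model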
@@ -462,12 +494,8 @@ private[knn] abstract class KnnModelReader[TModel <: KnnModelBase[TModel]](impli */ private[knn] abstract class KnnModelBase[TModel <: KnnModelBase[TModel]] extends Model[TModel] with KnnModelParams { - private[knn] def outputDir: String - - def getNumPartitions: Int - - /** @group setParam */ - def setQueryIdentifierCol(value: String): this.type = set(queryIdentifierCol, value) + private[knn] def sparkContext: SparkContext + @volatile private[knn] var destroyed: Boolean = false /** @group setParam */ def setQueryPartitionsCol(value: String): this.type = set(queryPartitionsCol, value) @@ -481,21 +509,14 @@ private[knn] abstract class KnnModelBase[TModel <: KnnModelBase[TModel]] extends /** @group setParam */ def setK(value: Int): this.type = set(k, value) - /** @group setParam */ - def setExcludeSelf(value: Boolean): this.type = set(excludeSelf, value) - - /** @group setParam */ - def setSimilarityThreshold(value: Double): this.type = set(similarityThreshold, value) - - /** @group setParam */ - def setNumReplicas(value: Int): this.type = set(numReplicas, value) - - /** @group setParam */ - def setParallelism(value: Int): this.type = set(parallelism, value) - - /** @group setParam */ - def setOutputFormat(value: String): this.type = set(outputFormat, value) + override def finalize(): Unit = { + destroy() + } + def destroy(): Unit = { + sparkContext.cancelJobGroup(uid) + destroyed = true + } } /** Contains the core knn search logic @@ -523,294 +544,62 @@ private[knn] trait KnnModelOps[ ] { this: TModel with KnnModelParams => - protected def loadIndex(in: InputStream): TIndex - - protected def typedTransform(dataset: Dataset[_])(implicit - tId: TypeTag[TId], - tVector: TypeTag[TVector], - tDistance: TypeTag[TDistance], - evId: ClassTag[TId], - evVector: ClassTag[TVector], - distanceNumeric: Numeric[TDistance] - ): DataFrame = { - - if (!isSet(queryIdentifierCol) && getExcludeSelf) { - throw new IllegalArgumentException("QueryIdentifierCol must be defined when excludeSelf is true.") - } - - if (isSet(queryIdentifierCol)) typedTransformWithQueryCol[TId](dataset, getQueryIdentifierCol) - else - typedTransformWithQueryCol[Long](dataset.withColumn("_query_id", monotonically_increasing_id), "_query_id") - .drop("_query_id") - } - - protected def typedTransformWithQueryCol[TQueryId](dataset: Dataset[_], queryIdCol: String)(implicit - tId: TypeTag[TId], - tVector: TypeTag[TVector], - tDistance: TypeTag[TDistance], - tQueryId: TypeTag[TQueryId], - evId: ClassTag[TId], - evVector: ClassTag[TVector], - evQueryId: ClassTag[TQueryId], - distanceNumeric: Numeric[TDistance] - ): DataFrame = { - import dataset.sparkSession.implicits._ - import distanceNumeric._ - - implicit val encoder: Encoder[TQueryId] = ExpressionEncoder() - implicit val neighborOrdering: Ordering[Neighbor[TId, TDistance]] = Ordering.by(_.distance) - - val serializableHadoopConfiguration = new SerializableConfiguration( - dataset.sparkSession.sparkContext.hadoopConfiguration - ) - - // construct the queries to the distributed indices. 
when query partitions are specified we only query those partitions - // otherwise we query all partitions - val logicalPartitionAndQueries = - if (isDefined(queryPartitionsCol)) - dataset - .select( - col(getQueryPartitionsCol), - col(queryIdCol), - col(getFeaturesCol) - ) - .as[(Seq[Int], TQueryId, TVector)] - .rdd - .flatMap { case (queryPartitions, queryId, vector) => - queryPartitions.map { partition => (partition, (queryId, vector)) } - } - else - dataset - .select( - col(queryIdCol), - col(getFeaturesCol) - ) - .as[(TQueryId, TVector)] - .rdd - .flatMap { case (queryId, vector) => - Range(0, getNumPartitions).map { partition => - (partition, (queryId, vector)) - } - } - - val numPartitionCopies = getNumReplicas + 1 - - val physicalPartitionAndQueries = logicalPartitionAndQueries - .map { case (partition, (queryId, vector)) => - val randomCopy = ThreadLocalRandom.current().nextInt(numPartitionCopies) - val physicalPartition = (partition * numPartitionCopies) + randomCopy - (physicalPartition, (queryId, vector)) - } - .partitionBy(new PartitionIdPassthrough(getNumPartitions * numPartitionCopies)) - - val numThreads = - if (isSet(parallelism) && getParallelism <= 0) sys.runtime.availableProcessors - else if (isSet(parallelism)) getParallelism - else dataset.sparkSession.sparkContext.getConf.getInt("spark.task.cpus", defaultValue = 1) - - val neighborsOnAllQueryPartitions = physicalPartitionAndQueries - .mapPartitions { queriesWithPartition => - val queries = queriesWithPartition.map(_._2) + implicit protected def idTypeTag: TypeTag[TId] - // load the partitioned index and execute all queries. + implicit protected def vectorTypeTag: TypeTag[TVector] - val physicalPartitionId = TaskContext.getPartitionId() + private[knn] def numPartitions: Int - val logicalPartitionId = physicalPartitionId / numPartitionCopies - val replica = physicalPartitionId % numPartitionCopies + private[knn] def numReplicas: Int - val indexPath = new Path(outputDir, logicalPartitionId.toString) + private[knn] def numThreads: Int - val fileSystem = indexPath.getFileSystem(serializableHadoopConfiguration.value) + private[knn] def indexAddresses: Map[PartitionAndReplica, InetSocketAddress] - if (!fileSystem.exists(indexPath)) Iterator.empty - else { + private[knn] def clientFactory: IndexClientFactory[TId, TVector, TDistance] - logInfo( - logicalPartitionId, - replica, - s"started loading index from $indexPath on host ${InetAddress.getLocalHost.getHostName}" - ) - val index = loadIndex(fileSystem.open(indexPath)) - logInfo( - logicalPartitionId, - replica, - s"finished loading index from $indexPath on host ${InetAddress.getLocalHost.getHostName}" - ) - - // execute queries in parallel on multiple threads - new Iterator[(TQueryId, Seq[Neighbor[TId, TDistance]])] { - - private[this] var first = true - private[this] var count = 0 - - private[this] val batchSize = 1000 - private[this] val queue = - new LinkedBlockingQueue[(TQueryId, Seq[Neighbor[TId, TDistance]])](batchSize * numThreads) - private[this] val executorService = new ThreadPoolExecutor( - numThreads, - numThreads, - 60L, - TimeUnit.SECONDS, - new LinkedBlockingQueue[Runnable], - new NamedThreadFactory("searcher-%d") - ) { - override def afterExecute(r: Runnable, t: Throwable): Unit = { - super.afterExecute(r, t) - - Option(t) - .orElse { - r match { - case t: FutureTask[_] => - Try(t.get()).failed.toOption.map { - case e: ExecutionException => e.getCause - case e: InterruptedException => - Thread.currentThread().interrupt() - e - case NonFatal(e) => e - } - 
case _ => None - } - } - .foreach { e => - logError("Error in worker.", e) - } - } - } - executorService.allowCoreThreadTimeOut(true) - - private[this] val activeWorkers = new CountDownLatch(numThreads) - Range(0, numThreads).map(_ => new Worker(queries, activeWorkers, batchSize)).foreach(executorService.submit) - - override def hasNext: Boolean = { - if (!queue.isEmpty) true - else if (queries.synchronized { queries.hasNext }) true - else { - // in theory all workers could have just picked up the last new work but not started processing any of it. - if (!activeWorkers.await(2, TimeUnit.MINUTES)) { - throw new IllegalStateException("Workers failed to complete.") - } - !queue.isEmpty - } - } - - override def next(): (TQueryId, Seq[Neighbor[TId, TDistance]]) = { - if (first) { - logInfo( - logicalPartitionId, - replica, - s"started querying on host ${InetAddress.getLocalHost.getHostName} with ${sys.runtime.availableProcessors} available processors." - ) - first = false - } - - val value = queue.poll(1, TimeUnit.MINUTES) - - count += 1 - - if (!hasNext) { - logInfo( - logicalPartitionId, - replica, - s"finished querying $count items on host ${InetAddress.getLocalHost.getHostName}" - ) - - executorService.shutdown() - } - - value - } - - class Worker(queries: Iterator[(TQueryId, TVector)], activeWorkers: CountDownLatch, batchSize: Int) - extends Runnable { - - private[this] var work = List.empty[(TQueryId, TVector)] - - private[this] val fetchSize = - if (getExcludeSelf) getK + 1 - else getK - - @tailrec final override def run(): Unit = { - - work.foreach { case (id, vector) => - val neighbors = index - .findNearest(vector, fetchSize) - .collect { - case SearchResult(item, distance) - if (!getExcludeSelf || item.id != id) && (getSimilarityThreshold < 0 || distance.toDouble < getSimilarityThreshold) => - Neighbor[TId, TDistance](item.id, distance) - } + protected def loadIndex(in: InputStream): TIndex - queue.put(id -> neighbors) - } + override def transform(dataset: Dataset[_]): DataFrame = { + if (destroyed) { + throw new IllegalStateException("Model is destroyed.") + } - work = queries.synchronized { - queries.take(batchSize).toList - } + val localIndexAddr = indexAddresses + val localClientFactory = clientFactory + val k = getK + val featuresCol = getFeaturesCol + val partitionsColOpt = if (isDefined(queryPartitionsCol)) Some(getQueryPartitionsCol) else None - if (work.nonEmpty) { - run() - } else { - activeWorkers.countDown() - } - } - } - } - } - } - .toDS() + val outputSchema = transformSchema(dataset.schema) - // take the best k results from all partitions + implicit val encoder: Encoder[Row] = ExpressionEncoder(RowEncoder.encoderFor(outputSchema, lenient = false)) - val topNeighbors = neighborsOnAllQueryPartitions - .groupByKey { case (queryId, _) => queryId } - .flatMapGroups { (queryId, groups) => - val allNeighbors = groups.flatMap { case (_, neighbors) => neighbors }.toList - Iterator.single(queryId -> allNeighbors.sortBy(_.distance).take(getK)) + dataset.toDF + .mapPartitions { it => + new QueryIterator(localIndexAddr, localClientFactory, it, batchSize = 100, k, featuresCol, partitionsColOpt) } - .toDF(queryIdCol, getPredictionCol) - - if (getOutputFormat == "minimal") topNeighbors - else dataset.join(topNeighbors, Seq(queryIdCol)) } - protected def typedTransformSchema[T: TypeTag](schema: StructType): StructType = { - val idDataType = typeOf[T] match { - case t if t =:= typeOf[Int] => IntegerType - case t if t =:= typeOf[Long] => LongType - case _ => StringType - } + override def 
transformSchema(schema: StructType): StructType = { + val idDataType = ScalaReflection.encoderFor[TId].dataType validateAndTransformSchema(schema, idDataType) } - private def logInfo(partition: Int, replica: Int, message: String): Unit = - logInfo(f"partition $partition%04d replica $replica%04d: $message") - } private[knn] abstract class KnnAlgorithm[TModel <: KnnModelBase[TModel]](override val uid: String) extends Estimator[TModel] - with KnnAlgorithmParams { - - /** Type of index. - * - * @tparam TId - * Type of the external identifier of an item - * @tparam TVector - * Type of the vector to perform distance calculation on - * @tparam TItem - * Type of items stored in the index - * @tparam TDistance - * Type of distance between items (expect any numeric type: float, double, int, ..) - */ - protected type TIndex[TId, TVector, TItem <: Item[TId, TVector], TDistance] <: Index[TId, TVector, TItem, TDistance] + with ModelLogging + with KnnAlgorithmParams + with IndexCreator + with IndexLoader + with IndexServing + with ModelCreator[TModel] { /** @group setParam */ def setIdentifierCol(value: String): this.type = set(identifierCol, value) - /** @group setParam */ - def setQueryIdentifierCol(value: String): this.type = set(queryIdentifierCol, value) - /** @group setParam */ def setPartitionCol(value: String): this.type = set(partitionCol, value) @@ -832,20 +621,11 @@ private[knn] abstract class KnnAlgorithm[TModel <: KnnModelBase[TModel]](overrid /** @group setParam */ def setDistanceFunction(value: String): this.type = set(distanceFunction, value) - /** @group setParam */ - def setExcludeSelf(value: Boolean): this.type = set(excludeSelf, value) - - /** @group setParam */ - def setSimilarityThreshold(value: Double): this.type = set(similarityThreshold, value) - /** @group setParam */ def setNumReplicas(value: Int): this.type = set(numReplicas, value) /** @group setParam */ - def setParallelism(value: Int): this.type = set(parallelism, value) - - /** @group setParam */ - def setOutputFormat(value: String): this.type = set(outputFormat, value) + def setNumThreads(value: Int): this.type = set(numThreads, value) def setInitialModelPath(value: String): this.type = set(initialModelPath, value) @@ -884,126 +664,26 @@ private[knn] abstract class KnnAlgorithm[TModel <: KnnModelBase[TModel]](overrid override def copy(extra: ParamMap): Estimator[TModel] = defaultCopy(extra) - /** Create the index used to do the nearest neighbor search. 
- * - * @param dimensions - * dimensionality of the items stored in the index - * @param maxItemCount - * maximum number of items the index can hold - * @param distanceFunction - * the distance function - * @param distanceOrdering - * the distance ordering - * @param idSerializer - * invoked for serializing ids when saving the index - * @param itemSerializer - * invoked for serializing items when saving items - * - * @tparam TId - * type of the index item identifier - * @tparam TVector - * type of the index item vector - * @tparam TItem - * type of the index item - * @tparam TDistance - * type of distance between items - * @return - * create an index - */ - protected def createIndex[ - TId, - TVector, - TItem <: Item[TId, TVector] with Product, - TDistance - ](dimensions: Int, maxItemCount: Int, distanceFunction: DistanceFunction[TVector, TDistance])(implicit - distanceOrdering: Ordering[TDistance], - idSerializer: ObjectSerializer[TId], - itemSerializer: ObjectSerializer[TItem] - ): TIndex[TId, TVector, TItem, TDistance] - - /** Load an index - * - * @param inputStream - * InputStream to restore the index from - * @param minCapacity - * loaded index needs to have space for at least this man additional items - * - * @tparam TId - * type of the index item identifier - * @tparam TVector - * type of the index item vector - * @tparam TItem - * type of the index item - * @tparam TDistance - * type of distance between items - * @return - * create an index - */ - protected def loadIndex[TId, TVector, TItem <: Item[TId, TVector] with Product, TDistance]( - inputStream: InputStream, - minCapacity: Int - ): TIndex[TId, TVector, TItem, TDistance] - - /** Creates the model to be returned from fitting the data. - * - * @param uid - * identifier - * @param outputDir - * directory containing the persisted indices - * @param numPartitions - * number of index partitions - * - * @tparam TId - * type of the index item identifier - * @tparam TVector - * type of the index item vector - * @tparam TItem - * type of the index item - * @tparam TDistance - * type of distance between items - * @return - * model - */ - protected def createModel[ - TId: TypeTag, - TVector: TypeTag, - TItem <: Item[TId, TVector] with Product: TypeTag, - TDistance: TypeTag - ](uid: String, outputDir: String, numPartitions: Int)(implicit - ev: ClassTag[TId], - evVector: ClassTag[TVector], - distanceNumeric: Numeric[TDistance] - ): TModel - private def typedFit[ - TId: TypeTag, - TVector: TypeTag, - TItem <: Item[TId, TVector] with Product: TypeTag, - TDistance: TypeTag + TId: TypeTag: ClassTag, + TVector: TypeTag: ClassTag, + TItem <: Item[TId, TVector] with Product: TypeTag: ClassTag, + TDistance: TypeTag: ClassTag ](dataset: Dataset[_])(implicit - ev: ClassTag[TId], - evVector: ClassTag[TVector], - evItem: ClassTag[TItem], distanceNumeric: Numeric[TDistance], distanceFunctionFactory: String => DistanceFunction[TVector, TDistance], idSerializer: ObjectSerializer[TId], - itemSerializer: ObjectSerializer[TItem] + itemSerializer: ObjectSerializer[TItem], + indexServerFactory: IndexServerFactory[TId, TVector, TItem, TDistance], + indexClientFactory: IndexClientFactory[TId, TVector, TDistance] ): TModel = { val sc = dataset.sparkSession val sparkContext = sc.sparkContext - val serializableHadoopConfiguration = new SerializableConfiguration(sparkContext.hadoopConfiguration) - import sc.implicits._ - val cacheFolder = sparkContext.getConf.get(key = "spark.hnswlib.settings.index.cache_folder", defaultValue = "/tmp") - - val outputDir = new 
Path(cacheFolder, s"${uid}_${System.currentTimeMillis()}").toString - - sparkContext.addSparkListener(new CleanupListener(outputDir, serializableHadoopConfiguration)) - - // read the id and vector from the input dataset and and repartition them over numPartitions amount of partitions. + // read the id and vector from the input dataset, repartition them over numPartitions amount of partitions. // if the data is pre-partitioned by the user repartition the input data by the user defined partition key, use a // hash of the item id otherwise. val partitionedIndexItems = { @@ -1015,23 +695,20 @@ private[knn] abstract class KnnAlgorithm[TModel <: KnnModelBase[TModel]](overrid ) .as[(Int, TItem)] .rdd - .partitionBy(new PartitionIdPassthrough(getNumPartitions)) + .partitionBy(new PartitionIdPartitioner(getNumPartitions)) .values - .toDS else dataset .select(col(getIdentifierCol).as("id"), col(getFeaturesCol).as("vector")) .as[TItem] - .repartition(getNumPartitions, $"id") + .rdd + .repartition(getNumPartitions) } // On each partition collect all the items into memory and construct the HNSW indices. // Save these indices to the hadoop filesystem - val numThreads = - if (isSet(parallelism) && getParallelism <= 0) sys.runtime.availableProcessors - else if (isSet(parallelism)) getParallelism - else dataset.sparkSession.sparkContext.getConf.getInt("spark.task.cpus", defaultValue = 1) + val numThreads = getNumThreads val initialModelOutputDir = if (isSet(initialModelPath)) Some(new Path(getInitialModelPath, "indices").toString) @@ -1039,14 +716,16 @@ private[knn] abstract class KnnAlgorithm[TModel <: KnnModelBase[TModel]](overrid val serializableConfiguration = new SerializableConfiguration(sparkContext.hadoopConfiguration) - partitionedIndexItems - .foreachPartition { it: Iterator[TItem] => - if (it.hasNext) { - val partitionId = TaskContext.getPartitionId() + val taskReqs = new TaskResourceRequests().cpus(numThreads) + val profile = new ResourceProfileBuilder().require(taskReqs).build() + + val indexRdd = partitionedIndexItems + .mapPartitions( + (it: Iterator[TItem]) => { val items = it.toSeq - logInfo(partitionId, s"started indexing ${items.size} items on host ${InetAddress.getLocalHost.getHostName}") + val partitionId = TaskContext.getPartitionId() val existingIndexOption = initialModelOutputDir .flatMap { dir => @@ -1054,6 +733,7 @@ private[knn] abstract class KnnAlgorithm[TModel <: KnnModelBase[TModel]](overrid val fs = indexPath.getFileSystem(serializableConfiguration.value) if (fs.exists(indexPath)) Some { + logInfo(partitionId, s"Loading existing index from $indexPath") val inputStream = fs.open(indexPath) loadIndex[TId, TVector, TItem, TDistance](inputStream, items.size) } @@ -1063,13 +743,18 @@ private[knn] abstract class KnnAlgorithm[TModel <: KnnModelBase[TModel]](overrid } } + logInfo(partitionId, s"started indexing ${items.size} items on host ${InetAddress.getLocalHost.getHostName}") + val index = existingIndexOption + .filter(_.nonEmpty) .getOrElse( - createIndex[TId, TVector, TItem, TDistance]( - items.head.dimensions, - items.size, - distanceFunctionFactory(getDistanceFunction) - ) + items.headOption.fold(emptyIndex[TId, TVector, TItem, TDistance]) { item => + createIndex[TId, TVector, TItem, TDistance]( + item.dimensions, + items.size, + distanceFunctionFactory(getDistanceFunction) + ) + } ) index.addAll( @@ -1079,97 +764,179 @@ private[knn] abstract class KnnAlgorithm[TModel <: KnnModelBase[TModel]](overrid numThreads = numThreads ) - logInfo(partitionId, s"finished indexing 
${items.size} items on host ${InetAddress.getLocalHost.getHostName}") + logInfo(partitionId, s"finished indexing ${items.size} items") - val path = new Path(outputDir, partitionId.toString) - val fileSystem = path.getFileSystem(serializableHadoopConfiguration.value) + Iterator(index) + }, + preservesPartitioning = true + ) + .withResources(profile) - val outputStream = fileSystem.create(path) + val modelUid = uid + "_" + System.currentTimeMillis().toString - logInfo(partitionId, s"started saving index to $path on host ${InetAddress.getLocalHost.getHostName}") + val registrations = + serve[TId, TVector, TItem, TDistance](modelUid, indexRdd, getNumPartitions, getNumReplicas, numThreads) - index.save(outputStream) + logInfo("All index replicas have successfully registered.") - logInfo(partitionId, s"finished saving index to $path on host ${InetAddress.getLocalHost.getHostName}") - } + registrations.toList + .sortBy { case (pnr, _) => (pnr.partitionNum, pnr.replicaNum) } + .foreach { case (pnr, a) => + logInfo(pnr.partitionNum, pnr.replicaNum, s"Index registered as ${a.getHostName}:${a.getPort}") } - createModel[TId, TVector, TItem, TDistance](uid, outputDir, getNumPartitions) + createModel[TId, TVector, TItem, TDistance]( + modelUid, + getNumPartitions, + getNumReplicas, + getNumThreads, + sparkContext, + registrations, + indexClientFactory + ) } - private def logInfo(partition: Int, message: String): Unit = logInfo(f"partition $partition%04d: $message") - - implicit private def floatArrayDistanceFunction(name: String): DistanceFunction[Array[Float], Float] = - (name, vectorApiAvailable) match { - case ("bray-curtis", true) => vectorFloat128BrayCurtisDistance - case ("bray-curtis", _) => floatBrayCurtisDistance - case ("canberra", true) => vectorFloat128CanberraDistance - case ("canberra", _) => floatCanberraDistance - case ("correlation", _) => floatCorrelationDistance - case ("cosine", true) => vectorFloat128CosineDistance - case ("cosine", _) => floatCosineDistance - case ("euclidean", true) => vectorFloat128EuclideanDistance - case ("euclidean", _) => floatEuclideanDistance - case ("inner-product", true) => vectorFloat128InnerProduct - case ("inner-product", _) => floatInnerProduct - case ("manhattan", true) => vectorFloat128ManhattanDistance - case ("manhattan", _) => floatManhattanDistance - case (value, _) => userDistanceFunction(value) - } +} - implicit private def doubleArrayDistanceFunction(name: String): DistanceFunction[Array[Double], Double] = name match { - case "bray-curtis" => doubleBrayCurtisDistance - case "canberra" => doubleCanberraDistance - case "correlation" => doubleCorrelationDistance - case "cosine" => doubleCosineDistance - case "euclidean" => doubleEuclideanDistance - case "inner-product" => doubleInnerProduct - case "manhattan" => doubleManhattanDistance - case value => userDistanceFunction(value) - } +/** Partitioner that uses precomputed partitions. Each partition id is its own partition + * + * @param numPartitions + * number of partitions + */ +private[knn] class PartitionIdPartitioner(override val numPartitions: Int) extends Partitioner { + override def getPartition(key: Any): Int = key.asInstanceOf[Int] +} + +/** Partitioner that uses precomputed partitions. 
Each unique partition and replica combination gets its own partition + * + * @param partitions + * the total number of partitions + * @param replicas + * the total number of replicas + */ +private[knn] class PartitionReplicaIdPartitioner(partitions: Int, replicas: Int) extends Partitioner { + + override def numPartitions: Int = partitions + (replicas * partitions) - implicit private def vectorDistanceFunction(name: String): DistanceFunction[Vector, Double] = name match { - case "bray-curtis" => VectorDistanceFunctions.brayCurtisDistance - case "canberra" => VectorDistanceFunctions.canberraDistance - case "correlation" => VectorDistanceFunctions.correlationDistance - case "cosine" => VectorDistanceFunctions.cosineDistance - case "euclidean" => VectorDistanceFunctions.euclideanDistance - case "inner-product" => VectorDistanceFunctions.innerProduct - case "manhattan" => VectorDistanceFunctions.manhattanDistance - case value => userDistanceFunction(value) + override def getPartition(key: Any): Int = { + val (partition, replica) = key.asInstanceOf[(Int, Int)] + partition + (replica * partitions) } +} - private def vectorApiAvailable: Boolean = try { - val _ = Jdk17DistanceFunctions.VECTOR_FLOAT_128_COSINE_DISTANCE - true - } catch { - case _: Throwable => false +private[knn] class IndexRunnable(uid: String, sparkContext: SparkContext, indexRdd: RDD[_]) extends Runnable { + + override def run(): Unit = { + sparkContext.setJobGroup(uid, "job group that holds the indices") + try { + indexRdd.count() + } catch { + case NonFatal(_) => + () + } finally { + sparkContext.clearJobGroup() + } } - private def userDistanceFunction[TVector, TDistance](name: String): DistanceFunction[TVector, TDistance] = - Try(Class.forName(name).getDeclaredConstructor().newInstance()).toOption - .collect { case f: DistanceFunction[TVector @unchecked, TDistance @unchecked] => f } - .getOrElse(throw new IllegalArgumentException(s"$name is not a valid distance functions.")) } -private[knn] class CleanupListener(dir: String, serializableConfiguration: SerializableConfiguration) - extends SparkListener - with Logging { - override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { +private[knn] trait ModelLogging extends Logging { + protected def logInfo(partition: Int, message: String): Unit = logInfo(f"partition $partition%04d: $message") - val path = new Path(dir) - val fileSystem = path.getFileSystem(serializableConfiguration.value) - - logInfo(s"Deleting files below $dir") - fileSystem.delete(path, true) - } + protected def logInfo(partition: Int, replica: Int, message: String): Unit = logInfo( + f"partition $partition%04d replica $replica%04d: $message" + ) } -/** Partitioner that uses precomputed partitions - * - * @param numPartitions - * number of partitions - */ -private[knn] class PartitionIdPassthrough(override val numPartitions: Int) extends Partitioner { - override def getPartition(key: Any): Int = key.asInstanceOf[Int] +private[knn] trait IndexServing extends ModelLogging with IndexType { + + protected def serve[ + TId: ClassTag, + TVector: ClassTag, + TItem <: Item[TId, TVector] with Product: ClassTag, + TDistance: ClassTag + ]( + uid: String, + indexRdd: RDD[TIndex[TId, TVector, TItem, TDistance]], + numPartitions: Int, + numReplicas: 
Int, + numThreads: Int + )(implicit + indexServerFactory: IndexServerFactory[TId, TVector, TItem, TDistance] + ): Map[PartitionAndReplica, InetSocketAddress] = { + + val sparkContext = indexRdd.sparkContext + val serializableConfiguration = new SerializableConfiguration(sparkContext.hadoopConfiguration) + + val keyedIndexRdd = indexRdd.flatMap { index => + Range.inclusive(0, numReplicas).map { replica => (TaskContext.getPartitionId(), replica) -> index } + } + + val replicaRdd = + if (numReplicas > 0) keyedIndexRdd.partitionBy(new PartitionReplicaIdPartitioner(numPartitions, numReplicas)) + else keyedIndexRdd + + val driverHost = sparkContext.getConf.get("spark.driver.host") + val server = RegistrationServerFactory.create(driverHost, numPartitions, numReplicas) + server.start() + try { + val registrationAddress = server.address + + logInfo(s"Started registration server on ${registrationAddress.getHostName}:${registrationAddress.getPort}") + + val serverRdd = replicaRdd + .map { case ((partitionNum, replicaNum), index) => + val executorHost = SparkEnv.get.blockManager.blockManagerId.host + val server = indexServerFactory.create(executorHost, index, serializableConfiguration.value, numThreads) + + server.start() + + val serverAddress = server.address + + logInfo( + partitionNum, + replicaNum, + s"started index server on host ${serverAddress.getHostName}:${serverAddress.getPort}" + ) + + logInfo( + partitionNum, + replicaNum, + s"registering replica at ${registrationAddress.getHostName}:${registrationAddress.getPort}" + ) + try { + RegistrationClient.register(registrationAddress, partitionNum, replicaNum, serverAddress) + + logInfo(partitionNum, replicaNum, "awaiting requests") + + while (!TaskContext.get().isInterrupted() && !server.isTerminated()) { + Thread.sleep(500) + } + + logInfo( + partitionNum, + replicaNum, + s"Task canceled" + ) + } finally { + server.shutdown() + } + + true + } + .withResources(indexRdd.getResourceProfile()) + + // the count will never complete because the tasks start the index server + val thread = new Thread(new IndexRunnable(uid, sparkContext, serverRdd), s"knn-index-thread-$uid") + thread.setDaemon(true) + thread.start() + + server.awaitRegistrations() + } finally { + server.shutdown() + } + + } } diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/QueryIterator.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/QueryIterator.scala new file mode 100644 index 00000000..92ee8937 --- /dev/null +++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/QueryIterator.scala @@ -0,0 +1,52 @@ +package com.github.jelmerk.spark.knn + +import java.net.InetSocketAddress + +import scala.util.Try + +import com.github.jelmerk.registration.server.PartitionAndReplica +import com.github.jelmerk.serving.client.IndexClientFactory +import org.apache.spark.sql.Row + +class QueryIterator[TId, TVector, TDistance]( + indices: Map[PartitionAndReplica, InetSocketAddress], + indexClientFactory: IndexClientFactory[TId, TVector, TDistance], + records: Iterator[Row], + batchSize: Int, + k: Int, + vectorCol: String, + partitionsCol: Option[String] +) extends Iterator[Row] { + + private var failed = false + private val client = indexClientFactory.create(indices) + + private val delegate = + if (records.isEmpty) Iterator[Row]() + else + records + .grouped(batchSize) + .map(batch => client.search(vectorCol, partitionsCol, batch, k)) + .reduce((a, b) => a ++ b) + + override def hasNext: Boolean = delegate.hasNext + + override def next(): Row = { + if 
(failed) { + throw new IllegalStateException("Client shutdown.") + } + try { + val result = delegate.next() + + if (!hasNext) { + client.shutdown() + } + result + } catch { + case t: Throwable => + Try(client.shutdown()) + failed = true + throw t + } + } +} diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/bruteforce/BruteForceSimilarity.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/bruteforce/BruteForceSimilarity.scala index 9d758a86..d4fc19ea 100644 --- a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/bruteforce/BruteForceSimilarity.scala +++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/bruteforce/BruteForceSimilarity.scala @@ -1,6 +1,7 @@ package com.github.jelmerk.spark.knn.bruteforce import java.io.InputStream +import java.net.InetSocketAddress import scala.reflect.ClassTag import scala.reflect.runtime.universe._ @@ -8,52 +9,89 @@ import scala.reflect.runtime.universe._ import com.github.jelmerk.knn.ObjectSerializer import com.github.jelmerk.knn.scalalike.{DistanceFunction, Item} import com.github.jelmerk.knn.scalalike.bruteforce.BruteForceIndex +import com.github.jelmerk.registration.server.PartitionAndReplica +import com.github.jelmerk.serving.client.IndexClientFactory import com.github.jelmerk.spark.knn._ +import org.apache.spark.SparkContext import org.apache.spark.ml.param._ import org.apache.spark.ml.util.{Identifiable, MLReadable, MLReader, MLWritable, MLWriter} -import org.apache.spark.sql.{DataFrame, Dataset} -import org.apache.spark.sql.types.StructType -/** Companion class for BruteForceSimilarityModel. - */ -object BruteForceSimilarityModel extends MLReadable[BruteForceSimilarityModel] { +private[bruteforce] trait BruteForceIndexType extends IndexType { + protected override type TIndex[TId, TVector, TItem <: Item[TId, TVector], TDistance] = + BruteForceIndex[TId, TVector, TItem, TDistance] - private[knn] class BruteForceModelReader extends KnnModelReader[BruteForceSimilarityModel] { + protected override implicit def indexClassTag[TId: ClassTag, TVector: ClassTag, TItem <: Item[ + TId, + TVector + ]: ClassTag, TDistance: ClassTag]: ClassTag[TIndex[TId, TVector, TItem, TDistance]] = + ClassTag(classOf[TIndex[TId, TVector, TItem, TDistance]]) +} - override protected def createModel[ - TId: TypeTag, - TVector: TypeTag, - TItem <: Item[TId, TVector] with Product: TypeTag, - TDistance: TypeTag - ](uid: String, outputDir: String, numPartitions: Int)(implicit - evId: ClassTag[TId], - evVector: ClassTag[TVector], - distanceNumeric: Numeric[TDistance] - ): BruteForceSimilarityModel = - new BruteForceSimilarityModelImpl[TId, TVector, TItem, TDistance](uid, outputDir, numPartitions) +private[bruteforce] trait BruteForceIndexLoader extends IndexLoader with BruteForceIndexType { + protected def loadIndex[TId, TVector, TItem <: Item[TId, TVector] with Product, TDistance]( + inputStream: InputStream, + minCapacity: Int + ): BruteForceIndex[TId, TVector, TItem, TDistance] = BruteForceIndex.loadFromInputStream(inputStream) +} - } +private[bruteforce] trait BruteForceModelCreator extends ModelCreator[BruteForceSimilarityModel] { + protected def createModel[ + TId: TypeTag, + TVector: TypeTag, + TItem <: Item[TId, TVector] with Product: TypeTag, + TDistance: TypeTag + ]( + uid: String, + numPartitions: Int, + numReplicas: Int, + numThreads: Int, + sparkContext: SparkContext, + indices: Map[PartitionAndReplica, InetSocketAddress], + clientFactory: IndexClientFactory[TId, TVector, TDistance] + ): BruteForceSimilarityModel = + 
new BruteForceSimilarityModelImpl[TId, TVector, TItem, TDistance]( + uid, + numPartitions, + numReplicas, + numThreads, + sparkContext, + indices, + clientFactory + ) +} + +/** Companion class for BruteForceSimilarityModel. */ +object BruteForceSimilarityModel extends MLReadable[BruteForceSimilarityModel] { + + private[knn] class BruteForceModelReader + extends KnnModelReader[BruteForceSimilarityModel] + with BruteForceModelCreator + with BruteForceIndexLoader override def read: MLReader[BruteForceSimilarityModel] = new BruteForceModelReader } -/** Model produced by `BruteForceSimilarity`. - */ +/** Model produced by `BruteForceSimilarity`. */ abstract class BruteForceSimilarityModel extends KnnModelBase[BruteForceSimilarityModel] with KnnModelParams with MLWritable private[knn] class BruteForceSimilarityModelImpl[ - TId: TypeTag, - TVector: TypeTag, - TItem <: Item[TId, TVector] with Product: TypeTag, - TDistance: TypeTag -](override val uid: String, val outputDir: String, numPartitions: Int)(implicit - evId: ClassTag[TId], - evVector: ClassTag[TVector], - distanceNumeric: Numeric[TDistance] -) extends BruteForceSimilarityModel + TId, + TVector, + TItem <: Item[TId, TVector] with Product, + TDistance +]( + override val uid: String, + val numPartitions: Int, + val numReplicas: Int, + val numThreads: Int, + val sparkContext: SparkContext, + val indexAddresses: Map[PartitionAndReplica, InetSocketAddress], + val clientFactory: IndexClientFactory[TId, TVector, TDistance] +)(implicit val idTypeTag: TypeTag[TId], val vectorTypeTag: TypeTag[TVector]) + extends BruteForceSimilarityModel with KnnModelOps[ BruteForceSimilarityModel, TId, @@ -63,17 +101,21 @@ private[knn] class BruteForceSimilarityModelImpl[ BruteForceIndex[TId, TVector, TItem, TDistance] ] { - override def getNumPartitions: Int = numPartitions - - override def transform(dataset: Dataset[_]): DataFrame = typedTransform(dataset) +// override implicit protected def idTypeTag: TypeTag[TId] = typeTag[TId] override def copy(extra: ParamMap): BruteForceSimilarityModel = { - val copied = new BruteForceSimilarityModelImpl[TId, TVector, TItem, TDistance](uid, outputDir, numPartitions) + val copied = new BruteForceSimilarityModelImpl[TId, TVector, TItem, TDistance]( + uid, + numPartitions, + numReplicas, + numThreads, + sparkContext, + indexAddresses, + clientFactory + ) copyValues(copied, extra).setParent(parent) } - override def transformSchema(schema: StructType): StructType = typedTransformSchema[TId](schema) - override def write: MLWriter = new KnnModelWriter[ BruteForceSimilarityModel, TId, @@ -94,10 +136,10 @@ private[knn] class BruteForceSimilarityModelImpl[ * @param uid * identifier */ -class BruteForceSimilarity(override val uid: String) extends KnnAlgorithm[BruteForceSimilarityModel](uid) { - - override protected type TIndex[TId, TVector, TItem <: Item[TId, TVector], TDistance] = - BruteForceIndex[TId, TVector, TItem, TDistance] +class BruteForceSimilarity(override val uid: String) + extends KnnAlgorithm[BruteForceSimilarityModel](uid) + with BruteForceModelCreator + with BruteForceIndexLoader { def this() = this(Identifiable.randomUID("brute_force")) @@ -110,26 +152,9 @@ class BruteForceSimilarity(override val uid: String) extends KnnAlgorithm[BruteF idSerializer: ObjectSerializer[TId], itemSerializer: ObjectSerializer[TItem] ): BruteForceIndex[TId, TVector, TItem, TDistance] = - BruteForceIndex[TId, TVector, TItem, TDistance]( - dimensions, - distanceFunction - ) - - override protected def loadIndex[TId, TVector, TItem <: 
Item[TId, TVector] with Product, TDistance]( - inputStream: InputStream, - minCapacity: Int - ): BruteForceIndex[TId, TVector, TItem, TDistance] = BruteForceIndex.loadFromInputStream(inputStream) - - override protected def createModel[ - TId: TypeTag, - TVector: TypeTag, - TItem <: Item[TId, TVector] with Product: TypeTag, - TDistance: TypeTag - ](uid: String, outputDir: String, numPartitions: Int)(implicit - evId: ClassTag[TId], - evVector: ClassTag[TVector], - distanceNumeric: Numeric[TDistance] - ): BruteForceSimilarityModel = - new BruteForceSimilarityModelImpl[TId, TVector, TItem, TDistance](uid, outputDir, numPartitions) + BruteForceIndex[TId, TVector, TItem, TDistance](dimensions, distanceFunction) + override protected def emptyIndex[TId, TVector, TItem <: Item[TId, TVector] with Product, TDistance] + : BruteForceIndex[TId, TVector, TItem, TDistance] = + BruteForceIndex.empty[TId, TVector, TItem, TDistance] } diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/hnsw/HnswSimilarity.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/hnsw/HnswSimilarity.scala index 310d5baa..3e03cf81 100644 --- a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/hnsw/HnswSimilarity.scala +++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/hnsw/HnswSimilarity.scala @@ -1,6 +1,7 @@ package com.github.jelmerk.spark.knn.hnsw import java.io.InputStream +import java.net.InetSocketAddress import scala.reflect.ClassTag import scala.reflect.runtime.universe._ @@ -8,14 +9,78 @@ import scala.reflect.runtime.universe._ import com.github.jelmerk.knn import com.github.jelmerk.knn.scalalike.{DistanceFunction, Item} import com.github.jelmerk.knn.scalalike.hnsw._ +import com.github.jelmerk.registration.server.PartitionAndReplica +import com.github.jelmerk.serving.client.IndexClientFactory import com.github.jelmerk.spark.knn._ +import org.apache.spark.SparkContext import org.apache.spark.ml.param._ -import org.apache.spark.ml.util.{Identifiable, MLReadable, MLReader, MLWritable, MLWriter} -import org.apache.spark.sql.{DataFrame, Dataset} -import org.apache.spark.sql.types.StructType +import org.apache.spark.ml.util._ + +private[hnsw] trait HnswIndexType extends IndexType { + protected override type TIndex[TId, TVector, TItem <: Item[TId, TVector], TDistance] = + HnswIndex[TId, TVector, TItem, TDistance] + + protected override implicit def indexClassTag[TId: ClassTag, TVector: ClassTag, TItem <: Item[ + TId, + TVector + ]: ClassTag, TDistance: ClassTag]: ClassTag[TIndex[TId, TVector, TItem, TDistance]] = + ClassTag(classOf[TIndex[TId, TVector, TItem, TDistance]]) + +} + +private[hnsw] trait HnswIndexLoader extends IndexLoader with HnswIndexType { + protected override def loadIndex[TId, TVector, TItem <: Item[TId, TVector] with Product, TDistance]( + inputStream: InputStream, + minCapacity: Int + ): HnswIndex[TId, TVector, TItem, TDistance] = { + val index = HnswIndex.loadFromInputStream[TId, TVector, TItem, TDistance](inputStream) + index.resize(index.maxItemCount + minCapacity) + index + } +} + +private[hnsw] trait HnswModelCreator extends ModelCreator[HnswSimilarityModel] { + protected def createModel[ + TId: TypeTag, + TVector: TypeTag, + TItem <: Item[TId, TVector] with Product: TypeTag, + TDistance: TypeTag + ]( + uid: String, + numPartitions: Int, + numReplicas: Int, + numThreads: Int, + sparkContext: SparkContext, + indices: Map[PartitionAndReplica, InetSocketAddress], + clientFactory: IndexClientFactory[TId, TVector, TDistance] + ): HnswSimilarityModel = + new 
HnswSimilarityModelImpl[TId, TVector, TItem, TDistance]( + uid, + numPartitions, + numReplicas, + numThreads, + sparkContext, + indices, + clientFactory + ) +} private[hnsw] trait HnswParams extends KnnAlgorithmParams with HnswModelParams { + /** Size of the dynamic list for the nearest neighbors (used during the search). Default: 10 + * + * @group param + */ + final val ef = new IntParam( + this, + "ef", + "size of the dynamic list for the nearest neighbors (used during the search)", + ParamValidators.gt(0) + ) + + /** @group getParam */ + final def getEf: Int = $(ef) + /** The number of bi-directional links created for every new element during construction. * * Default: 16 @@ -46,48 +111,21 @@ private[hnsw] trait HnswParams extends KnnAlgorithmParams with HnswModelParams { /** @group getParam */ final def getEfConstruction: Int = $(efConstruction) - setDefault(m -> 16, efConstruction -> 200) + setDefault(m -> 16, efConstruction -> 200, ef -> 10) } /** Common params for Hnsw and HnswModel. */ -private[hnsw] trait HnswModelParams extends KnnModelParams { - - /** Size of the dynamic list for the nearest neighbors (used during the search). Default: 10 - * - * @group param - */ - final val ef = new IntParam( - this, - "ef", - "size of the dynamic list for the nearest neighbors (used during the search)", - ParamValidators.gt(0) - ) - - /** @group getParam */ - final def getEf: Int = $(ef) - - setDefault(ef -> 10) -} +private[hnsw] trait HnswModelParams extends KnnModelParams /** Companion class for HnswSimilarityModel. */ object HnswSimilarityModel extends MLReadable[HnswSimilarityModel] { - private[hnsw] class HnswModelReader extends KnnModelReader[HnswSimilarityModel] { - - override protected def createModel[ - TId: TypeTag, - TVector: TypeTag, - TItem <: Item[TId, TVector] with Product: TypeTag, - TDistance: TypeTag - ](uid: String, outputDir: String, numPartitions: Int)(implicit - evId: ClassTag[TId], - evVector: ClassTag[TVector], - distanceNumeric: Numeric[TDistance] - ): HnswSimilarityModel = - new HnswSimilarityModelImpl[TId, TVector, TItem, TDistance](uid, outputDir, numPartitions) - } + private[hnsw] class HnswModelReader + extends KnnModelReader[HnswSimilarityModel] + with HnswIndexLoader + with HnswModelCreator override def read: MLReader[HnswSimilarityModel] = new HnswModelReader @@ -95,46 +133,46 @@ object HnswSimilarityModel extends MLReadable[HnswSimilarityModel] { /** Model produced by `HnswSimilarity`. 
*/ -abstract class HnswSimilarityModel extends KnnModelBase[HnswSimilarityModel] with HnswModelParams with MLWritable { - - /** @group setParam */ - def setEf(value: Int): this.type = set(ef, value) - -} +abstract class HnswSimilarityModel extends KnnModelBase[HnswSimilarityModel] with HnswModelParams with MLWritable private[knn] class HnswSimilarityModelImpl[ - TId: TypeTag, - TVector: TypeTag, - TItem <: Item[TId, TVector] with Product: TypeTag, - TDistance: TypeTag -](override val uid: String, val outputDir: String, numPartitions: Int)(implicit - evId: ClassTag[TId], - evVector: ClassTag[TVector], - distanceNumeric: Numeric[TDistance] -) extends HnswSimilarityModel + TId, + TVector, + TItem <: Item[TId, TVector] with Product, + TDistance +]( + override val uid: String, + val numPartitions: Int, + val numReplicas: Int, + val numThreads: Int, + val sparkContext: SparkContext, + val indexAddresses: Map[PartitionAndReplica, InetSocketAddress], + val clientFactory: IndexClientFactory[TId, TVector, TDistance] +)(implicit val idTypeTag: TypeTag[TId], val vectorTypeTag: TypeTag[TVector]) + extends HnswSimilarityModel with KnnModelOps[HnswSimilarityModel, TId, TVector, TItem, TDistance, HnswIndex[TId, TVector, TItem, TDistance]] { - override def getNumPartitions: Int = numPartitions - - override def transform(dataset: Dataset[_]): DataFrame = typedTransform(dataset) - override def copy(extra: ParamMap): HnswSimilarityModel = { - val copied = new HnswSimilarityModelImpl[TId, TVector, TItem, TDistance](uid, outputDir, numPartitions) + val copied = new HnswSimilarityModelImpl[TId, TVector, TItem, TDistance]( + uid, + numPartitions, + numReplicas, + numThreads, + sparkContext, + indexAddresses, + clientFactory + ) copyValues(copied, extra).setParent(parent) } - override def transformSchema(schema: StructType): StructType = typedTransformSchema[TId](schema) - override def write: MLWriter = new KnnModelWriter[HnswSimilarityModel, TId, TVector, TItem, TDistance, HnswIndex[TId, TVector, TItem, TDistance]]( this ) - override protected def loadIndex(in: InputStream): HnswIndex[TId, TVector, TItem, TDistance] = { - val index = HnswIndex.loadFromInputStream[TId, TVector, TItem, TDistance](in) - index.ef = getEf - index - } + override protected def loadIndex(in: InputStream): HnswIndex[TId, TVector, TItem, TDistance] = + HnswIndex.loadFromInputStream[TId, TVector, TItem, TDistance](in) + } /** Nearest neighbor search using the approximative hnsw algorithm. 
@@ -142,10 +180,11 @@ private[knn] class HnswSimilarityModelImpl[ * @param uid * identifier */ -class HnswSimilarity(override val uid: String) extends KnnAlgorithm[HnswSimilarityModel](uid) with HnswParams { - - override protected type TIndex[TId, TVector, TItem <: Item[TId, TVector], TDistance] = - HnswIndex[TId, TVector, TItem, TDistance] +class HnswSimilarity(override val uid: String) + extends KnnAlgorithm[HnswSimilarityModel](uid) + with HnswParams + with HnswIndexLoader + with HnswModelCreator { def this() = this(Identifiable.randomUID("hnsw")) @@ -179,24 +218,7 @@ class HnswSimilarity(override val uid: String) extends KnnAlgorithm[HnswSimilari itemSerializer ) - override protected def loadIndex[TId, TVector, TItem <: Item[TId, TVector] with Product, TDistance]( - inputStream: InputStream, - minCapacity: Int - ): HnswIndex[TId, TVector, TItem, TDistance] = { - val index = HnswIndex.loadFromInputStream[TId, TVector, TItem, TDistance](inputStream) - index.resize(index.maxItemCount + minCapacity) - index - } - - override protected def createModel[ - TId: TypeTag, - TVector: TypeTag, - TItem <: Item[TId, TVector] with Product: TypeTag, - TDistance: TypeTag - ](uid: String, outputDir: String, numPartitions: Int)(implicit - evId: ClassTag[TId], - evVector: ClassTag[TVector], - distanceNumeric: Numeric[TDistance] - ): HnswSimilarityModel = - new HnswSimilarityModelImpl[TId, TVector, TItem, TDistance](uid, outputDir, numPartitions) + override protected def emptyIndex[TId, TVector, TItem <: Item[TId, TVector] with Product, TDistance] + : HnswIndex[TId, TVector, TItem, TDistance] = + HnswIndex.empty[TId, TVector, TItem, TDistance] } diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/knn.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/knn.scala index 39bdd78a..69636a75 100644 --- a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/knn.scala +++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/knn.scala @@ -2,11 +2,90 @@ package com.github.jelmerk.spark import java.io.{ObjectInput, ObjectOutput} -import com.github.jelmerk.knn.scalalike.ObjectSerializer -import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} +import scala.language.implicitConversions +import scala.util.Try + +import com.github.jelmerk.knn.Jdk17DistanceFunctions +import com.github.jelmerk.knn.scalalike.{ + doubleBrayCurtisDistance, + doubleCanberraDistance, + doubleCorrelationDistance, + doubleCosineDistance, + doubleEuclideanDistance, + doubleInnerProduct, + doubleManhattanDistance, + floatBrayCurtisDistance, + floatCanberraDistance, + floatCorrelationDistance, + floatCosineDistance, + floatEuclideanDistance, + floatInnerProduct, + floatManhattanDistance, + DistanceFunction, + Item, + ObjectSerializer +} +import com.github.jelmerk.knn.scalalike.jdk17DistanceFunctions.{ + vectorFloat128BrayCurtisDistance, + vectorFloat128CanberraDistance, + vectorFloat128CosineDistance, + vectorFloat128EuclideanDistance, + vectorFloat128InnerProduct, + vectorFloat128ManhattanDistance +} +import com.github.jelmerk.server.index.{ + DenseVector, + DoubleArrayVector, + FloatArrayVector, + Result, + SearchRequest, + SparseVector +} +import com.github.jelmerk.serving.client.IndexClientFactory +import com.github.jelmerk.serving.server.IndexServerFactory +import com.github.jelmerk.spark.linalg.functions.VectorDistanceFunctions +import org.apache.spark.ml.linalg.{DenseVector => SparkDenseVector, SparseVector => SparkSparseVector, Vector, Vectors} package object knn { + 
private[knn] case class IntDoubleArrayIndexItem(id: Int, vector: Array[Double]) extends Item[Int, Array[Double]] { + override def dimensions: Int = vector.length + } + + private[knn] case class LongDoubleArrayIndexItem(id: Long, vector: Array[Double]) extends Item[Long, Array[Double]] { + override def dimensions: Int = vector.length + } + + private[knn] case class StringDoubleArrayIndexItem(id: String, vector: Array[Double]) + extends Item[String, Array[Double]] { + override def dimensions: Int = vector.length + } + + private[knn] case class IntFloatArrayIndexItem(id: Int, vector: Array[Float]) extends Item[Int, Array[Float]] { + override def dimensions: Int = vector.length + } + + private[knn] case class LongFloatArrayIndexItem(id: Long, vector: Array[Float]) extends Item[Long, Array[Float]] { + override def dimensions: Int = vector.length + } + + private[knn] case class StringFloatArrayIndexItem(id: String, vector: Array[Float]) + extends Item[String, Array[Float]] { + override def dimensions: Int = vector.length + } + + private[knn] case class IntVectorIndexItem(id: Int, vector: Vector) extends Item[Int, Vector] { + override def dimensions: Int = vector.size + } + + private[knn] case class LongVectorIndexItem(id: Long, vector: Vector) extends Item[Long, Vector] { + override def dimensions: Int = vector.size + } + + private[knn] case class StringVectorIndexItem(id: String, vector: Vector) extends Item[String, Vector] { + override def dimensions: Int = vector.size + } + private[knn] implicit object StringSerializer extends ObjectSerializer[String] { override def write(item: String, out: ObjectOutput): Unit = out.writeUTF(item) override def read(in: ObjectInput): String = in.readUTF() @@ -58,12 +137,12 @@ package object knn { private[knn] implicit object VectorSerializer extends ObjectSerializer[Vector] { override def write(item: Vector, out: ObjectOutput): Unit = item match { - case v: DenseVector => + case v: SparkDenseVector => out.writeBoolean(true) out.writeInt(v.size) v.values.foreach(out.writeDouble) - case v: SparseVector => + case v: SparkSparseVector => out.writeBoolean(false) out.writeInt(v.size) out.writeInt(v.indices.length) @@ -220,4 +299,231 @@ package object knn { } } + private[knn] implicit object IntVectorIndexServerFactory + extends IndexServerFactory[Int, Vector, IntVectorIndexItem, Double]( + extractVector, + convertIntId, + convertDoubleDistance + ) + + private[knn] implicit object LongVectorIndexServerFactory + extends IndexServerFactory[Long, Vector, LongVectorIndexItem, Double]( + extractVector, + convertLongId, + convertDoubleDistance + ) + + private[knn] implicit object StringVectorIndexServerFactory + extends IndexServerFactory[String, Vector, StringVectorIndexItem, Double]( + extractVector, + convertStringId, + convertDoubleDistance + ) + + private[knn] implicit object IntFloatArrayIndexServerFactory + extends IndexServerFactory[Int, Array[Float], IntFloatArrayIndexItem, Float]( + extractFloatArray, + convertIntId, + convertFloatDistance + ) + + private[knn] implicit object LongFloatArrayIndexServerFactory + extends IndexServerFactory[Long, Array[Float], LongFloatArrayIndexItem, Float]( + extractFloatArray, + convertLongId, + convertFloatDistance + ) + + private[knn] implicit object StringFloatArrayIndexServerFactory + extends IndexServerFactory[String, Array[Float], StringFloatArrayIndexItem, Float]( + extractFloatArray, + convertStringId, + convertFloatDistance + ) + + private[knn] implicit object IntDoubleArrayIndexServerFactory + extends 
IndexServerFactory[Int, Array[Double], IntDoubleArrayIndexItem, Double]( + extractDoubleArray, + convertIntId, + convertDoubleDistance + ) + + private[knn] implicit object LongDoubleArrayIndexServerFactory + extends IndexServerFactory[Long, Array[Double], LongDoubleArrayIndexItem, Double]( + extractDoubleArray, + convertLongId, + convertDoubleDistance + ) + + private[knn] implicit object StringDoubleArrayIndexServerFactory + extends IndexServerFactory[String, Array[Double], StringDoubleArrayIndexItem, Double]( + extractDoubleArray, + convertStringId, + convertDoubleDistance + ) + + private[knn] implicit object IntVectorIndexClientFactory + extends IndexClientFactory[Int, Vector, Double]( + convertVector, + extractIntId, + extractDoubleDistance, + implicitly[Ordering[Double]] + ) + + private[knn] implicit object LongVectorIndexClientFactory + extends IndexClientFactory[Long, Vector, Double]( + convertVector, + extractLongId, + extractDoubleDistance, + implicitly[Ordering[Double]] + ) + + private[knn] implicit object StringVectorIndexClientFactory + extends IndexClientFactory[String, Vector, Double]( + convertVector, + extractStringId, + extractDoubleDistance, + implicitly[Ordering[Double]] + ) + + private[knn] implicit object IntFloatArrayIndexClientFactory + extends IndexClientFactory[Int, Array[Float], Float]( + convertFloatArray, + extractIntId, + extractFloatDistance, + implicitly[Ordering[Float]] + ) + + private[knn] implicit object LongFloatArrayIndexClientFactory + extends IndexClientFactory[Long, Array[Float], Float]( + convertFloatArray, + extractLongId, + extractFloatDistance, + implicitly[Ordering[Float]] + ) + + private[knn] implicit object StringFloatArrayIndexClientFactory + extends IndexClientFactory[String, Array[Float], Float]( + convertFloatArray, + extractStringId, + extractFloatDistance, + implicitly[Ordering[Float]] + ) + + private[knn] implicit object IntDoubleArrayIndexClientFactory + extends IndexClientFactory[Int, Array[Double], Double]( + convertDoubleArray, + extractIntId, + extractDoubleDistance, + implicitly[Ordering[Double]] + ) + + private[knn] implicit object LongDoubleArrayIndexClientFactory + extends IndexClientFactory[Long, Array[Double], Double]( + convertDoubleArray, + extractLongId, + extractDoubleDistance, + implicitly[Ordering[Double]] + ) + + private[knn] implicit object StringDoubleArrayIndexClientFactory + extends IndexClientFactory[String, Array[Double], Double]( + convertDoubleArray, + extractStringId, + extractDoubleDistance, + implicitly[Ordering[Double]] + ) + + private[knn] def convertFloatArray(vector: Array[Float]): SearchRequest.Vector = + SearchRequest.Vector.FloatArrayVector(FloatArrayVector(vector)) + + private[knn] def convertDoubleArray(vector: Array[Double]): SearchRequest.Vector = + SearchRequest.Vector.DoubleArrayVector(DoubleArrayVector(vector)) + + private[knn] def convertVector(vector: Vector): SearchRequest.Vector = vector match { + case v: SparkDenseVector => SearchRequest.Vector.DenseVector(DenseVector(v.values)) + case v: SparkSparseVector => SearchRequest.Vector.SparseVector(SparseVector(vector.size, v.indices, v.values)) + } + + private[knn] def extractDoubleDistance(result: Result): Double = result.getDoubleDistance + + private[knn] def extractFloatDistance(result: Result): Float = result.getFloatDistance + + private[knn] def extractStringId(result: Result): String = result.getStringId + + private[knn] def extractLongId(result: Result): Long = result.getLongId + + private[knn] def extractIntId(result: Result): Int = 
result.getIntId + + private[knn] def extractFloatArray(request: SearchRequest): Array[Float] = request.vector.floatArrayVector + .map(_.values) + .orNull + + private[knn] def extractDoubleArray(request: SearchRequest): Array[Double] = request.vector.doubleArrayVector + .map(_.values) + .orNull + + private[knn] def extractVector(request: SearchRequest): Vector = + if (request.vector.isDenseVector) request.vector.denseVector.map { v => new SparkDenseVector(v.values) }.orNull + else request.vector.sparseVector.map { v => new SparkSparseVector(v.size, v.indices, v.values) }.orNull + + private[knn] def convertStringId(value: String): Result.Id = Result.Id.StringId(value) + private[knn] def convertLongId(value: Long): Result.Id = Result.Id.LongId(value) + private[knn] def convertIntId(value: Int): Result.Id = Result.Id.IntId(value) + + private[knn] def convertFloatDistance(value: Float): Result.Distance = Result.Distance.FloatDistance(value) + private[knn] def convertDoubleDistance(value: Double): Result.Distance = Result.Distance.DoubleDistance(value) + + implicit private[knn] def floatArrayDistanceFunction(name: String): DistanceFunction[Array[Float], Float] = + (name, vectorApiAvailable) match { + case ("bray-curtis", true) => vectorFloat128BrayCurtisDistance + case ("bray-curtis", _) => floatBrayCurtisDistance + case ("canberra", true) => vectorFloat128CanberraDistance + case ("canberra", _) => floatCanberraDistance + case ("correlation", _) => floatCorrelationDistance + case ("cosine", true) => vectorFloat128CosineDistance + case ("cosine", _) => floatCosineDistance + case ("euclidean", true) => vectorFloat128EuclideanDistance + case ("euclidean", _) => floatEuclideanDistance + case ("inner-product", true) => vectorFloat128InnerProduct + case ("inner-product", _) => floatInnerProduct + case ("manhattan", true) => vectorFloat128ManhattanDistance + case ("manhattan", _) => floatManhattanDistance + case (value, _) => userDistanceFunction(value) + } + + implicit private[knn] def doubleArrayDistanceFunction(name: String): DistanceFunction[Array[Double], Double] = + name match { + case "bray-curtis" => doubleBrayCurtisDistance + case "canberra" => doubleCanberraDistance + case "correlation" => doubleCorrelationDistance + case "cosine" => doubleCosineDistance + case "euclidean" => doubleEuclideanDistance + case "inner-product" => doubleInnerProduct + case "manhattan" => doubleManhattanDistance + case value => userDistanceFunction(value) + } + + implicit private[knn] def vectorDistanceFunction(name: String): DistanceFunction[Vector, Double] = name match { + case "bray-curtis" => VectorDistanceFunctions.brayCurtisDistance + case "canberra" => VectorDistanceFunctions.canberraDistance + case "correlation" => VectorDistanceFunctions.correlationDistance + case "cosine" => VectorDistanceFunctions.cosineDistance + case "euclidean" => VectorDistanceFunctions.euclideanDistance + case "inner-product" => VectorDistanceFunctions.innerProduct + case "manhattan" => VectorDistanceFunctions.manhattanDistance + case value => userDistanceFunction(value) + } + + private def vectorApiAvailable: Boolean = try { + val _ = Jdk17DistanceFunctions.VECTOR_FLOAT_128_COSINE_DISTANCE + true + } catch { + case _: Throwable => false + } + + private def userDistanceFunction[TVector, TDistance](name: String): DistanceFunction[TVector, TDistance] = + Try(Class.forName(name).getDeclaredConstructor().newInstance()).toOption + .collect { case f: DistanceFunction[TVector @unchecked, TDistance @unchecked] => f } + .getOrElse(throw new 
IllegalArgumentException(s"$name is not a valid distance functions.")) } diff --git a/hnswlib-spark/src/test/python/test_bruteforce.py b/hnswlib-spark/src/test/python/test_bruteforce.py index e405713d..a8ee5dd0 100644 --- a/hnswlib-spark/src/test/python/test_bruteforce.py +++ b/hnswlib-spark/src/test/python/test_bruteforce.py @@ -12,9 +12,8 @@ def test_bruteforce(spark): [3, Vectors.dense([0.2, 0.1])], ], ['row_id', 'features']) - bruteforce = BruteForceSimilarity(identifierCol='row_id', queryIdentifierCol='row_id', featuresCol='features', - distanceFunction='cosine', numPartitions=100, excludeSelf=False, - similarityThreshold=-1.0) + bruteforce = BruteForceSimilarity(identifierCol='row_id', featuresCol='features', + distanceFunction='cosine', numPartitions=2, numThreads=1) model = bruteforce.fit(df) diff --git a/hnswlib-spark/src/test/python/test_hnsw.py b/hnswlib-spark/src/test/python/test_hnsw.py index cb960588..58fc9dc0 100644 --- a/hnswlib-spark/src/test/python/test_hnsw.py +++ b/hnswlib-spark/src/test/python/test_hnsw.py @@ -13,7 +13,7 @@ def test_hnsw(spark): ], ['row_id', 'features']) hnsw = HnswSimilarity(identifierCol='row_id', featuresCol='features', distanceFunction='cosine', m=32, ef=5, k=5, - efConstruction=200, numPartitions=100, excludeSelf=False, similarityThreshold=-1.0) + efConstruction=200, numPartitions=2, numThreads=1) model = hnsw.fit(df) diff --git a/hnswlib-spark/src/test/python/test_integration.py b/hnswlib-spark/src/test/python/test_integration.py index 3e2b9eea..dc49d13e 100644 --- a/hnswlib-spark/src/test/python/test_integration.py +++ b/hnswlib-spark/src/test/python/test_integration.py @@ -11,7 +11,7 @@ def test_incremental_models(spark, tmp_path): [1, Vectors.dense([0.1, 0.2, 0.3])] ], ['id', 'features']) - hnsw1 = HnswSimilarity() + hnsw1 = HnswSimilarity(numPartitions=2, numThreads=1) model1 = hnsw1.fit(df1) @@ -21,7 +21,7 @@ def test_incremental_models(spark, tmp_path): [2, Vectors.dense([0.9, 0.1, 0.2])] ], ['id', 'features']) - hnsw2 = HnswSimilarity(initialModelPath=tmp_path.as_posix()) + hnsw2 = HnswSimilarity(numPartitions=2, numThreads=1, initialModelPath=tmp_path.as_posix()) model2 = hnsw2.fit(df2) diff --git a/hnswlib-spark/src/test/scala/com/github/jelmerk/spark/knn/hnsw/HnswSimilaritySpec.scala b/hnswlib-spark/src/test/scala/com/github/jelmerk/spark/knn/hnsw/HnswSimilaritySpec.scala index e72b4550..2580e54c 100644 --- a/hnswlib-spark/src/test/scala/com/github/jelmerk/spark/knn/hnsw/HnswSimilaritySpec.scala +++ b/hnswlib-spark/src/test/scala/com/github/jelmerk/spark/knn/hnsw/HnswSimilaritySpec.scala @@ -38,9 +38,11 @@ case class MinimalOutputRow[TId, TDistance](id: TId, neighbors: Seq[Neighbor[TId class HnswSimilaritySpec extends AnyFunSuite with DataFrameSuiteBase { - // for some reason kryo cannot serialize the hnswindex so configure it to make sure it never gets serialized override def conf: SparkConf = super.conf .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + .set("spark.kryo.registrator", "com.github.jelmerk.spark.HnswLibKryoRegistrator") + .set("spark.speculation", "false") + .set("spark.ui.enabled", "true") test("prepartitioned data") { @@ -49,12 +51,12 @@ class HnswSimilaritySpec extends AnyFunSuite with DataFrameSuiteBase { val hnsw = new HnswSimilarity() .setIdentifierCol("id") - .setQueryIdentifierCol("id") .setFeaturesCol("vector") .setPartitionCol("partition") .setQueryPartitionsCol("partitions") .setNumPartitions(2) .setNumReplicas(3) + .setNumThreads(1) .setK(10) val indexItems = Seq( @@ -63,7 +65,7 @@ 
class HnswSimilaritySpec extends AnyFunSuite with DataFrameSuiteBase { PrePartitionedInputRow(partition = 1, id = 3000000, vector = Vectors.dense(0.4300, 0.9891)) ).toDF() - val model = hnsw.fit(indexItems).setPredictionCol("neighbors").setEf(10) + val model = hnsw.fit(indexItems).setPredictionCol("neighbors") val queries = Seq( QueryRow(partitions = Seq(0), id = 123, vector = Vectors.dense(0.2400, 0.3891)) @@ -76,6 +78,8 @@ class HnswSimilaritySpec extends AnyFunSuite with DataFrameSuiteBase { .head result.neighbors.size should be(2) // it couldn't see 3000000 because we only query partition 0 + + model.destroy() } test("find neighbors") { @@ -188,33 +192,37 @@ class HnswSimilaritySpec extends AnyFunSuite with DataFrameSuiteBase { val scenarios = Table[String, Boolean, Double, DataFrame, DataFrame => Unit]( ("outputFormat", "excludeSelf", "similarityThreshold", "input", "validator"), - ("full", false, 1, denseVectorInput, denseVectorScenarioValidator), - ("minimal", false, 1, denseVectorInput, minimalDenseVectorScenarioValidator), - ("full", false, 0.1, denseVectorInput, similarityThresholdScenarioValidator), - ("full", false, 0.1, floatArrayInput, floatArraySimilarityThresholdScenarioValidator), - ("full", false, noSimilarityThreshold, doubleArrayInput, doubleArrayScenarioValidator), - ("full", false, noSimilarityThreshold, floatArrayInput, floatArrayScenarioValidator), - ("full", true, noSimilarityThreshold, denseVectorInput, excludeSelfScenarioValidator), - ("full", true, 1, sparseVectorInput, sparseVectorScenarioValidator) +// ("full", false, 1, denseVectorInput, denseVectorScenarioValidator), +// ("minimal", false, 1, denseVectorInput, minimalDenseVectorScenarioValidator), + ("full", false, 0.1, denseVectorInput, similarityThresholdScenarioValidator) // , +// ("full", false, 0.1, floatArrayInput, floatArraySimilarityThresholdScenarioValidator), +// ("full", false, noSimilarityThreshold, doubleArrayInput, doubleArrayScenarioValidator), +// ("full", false, noSimilarityThreshold, floatArrayInput, floatArrayScenarioValidator), +// ("full", true, noSimilarityThreshold, denseVectorInput, excludeSelfScenarioValidator), +// ("full", true, 1, sparseVectorInput, sparseVectorScenarioValidator) ) forAll(scenarios) { case (outputFormat, excludeSelf, similarityThreshold, input, validator) => val hnsw = new HnswSimilarity() .setIdentifierCol("id") - .setQueryIdentifierCol("id") .setFeaturesCol("vector") - .setNumPartitions(5) - .setNumReplicas(3) + .setNumPartitions(2) + .setNumReplicas(1) + .setNumThreads(1) .setK(10) - .setExcludeSelf(excludeSelf) - .setSimilarityThreshold(similarityThreshold) - .setOutputFormat(outputFormat) - val model = hnsw.fit(input).setPredictionCol("neighbors").setEf(10) + val model = hnsw.fit(input).setPredictionCol("neighbors") + + try { + val result = model.transform(input) - val result = model.transform(input) + result.show(false) - validator(result) + validator(result) + } finally { + model.destroy() + Thread.sleep(5000) + } } } @@ -225,10 +233,10 @@ class HnswSimilaritySpec extends AnyFunSuite with DataFrameSuiteBase { val hnsw = new HnswSimilarity() .setIdentifierCol("id") - .setQueryIdentifierCol("id") .setFeaturesCol("vector") .setPredictionCol("neighbors") - .setOutputFormat("minimal") + .setNumThreads(1) + .setNumPartitions(1) val items = Seq( InputRow(1000000, Array(0.0110f, 0.2341f)), diff --git a/project/plugins.sbt b/project/plugins.sbt index 7f9f2dcd..d5bf29cf 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -3,3 +3,6 @@ 
addSbtPlugin("com.github.sbt" % "sbt-dynver" % "5.0.1") addSbtPlugin("com.github.sbt" % "sbt-pgp" % "2.2.1") addSbtPlugin("org.xerial.sbt" % "sbt-sonatype" % "3.10.0") addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.4.6") +addSbtPlugin("com.thesamet" % "sbt-protoc" % "1.0.6") + +libraryDependencies += "com.thesamet.scalapb" %% "compilerplugin" % "0.11.11" \ No newline at end of file