From f8c874469793c26039399df056447bcc85c3ed9c Mon Sep 17 00:00:00 2001 From: Jacob Murphy Date: Thu, 23 Jan 2025 17:33:48 +0000 Subject: [PATCH 1/4] Add integration testing framework Signed-off-by: Jacob Murphy --- MODULE.bazel | 19 + testing/integration/BUILD | 53 ++ testing/integration/requirements.txt | 3 + testing/integration/stability_runner.py | 368 ++++++++ testing/integration/stability_test.py | 391 ++++++++ testing/integration/utils.py | 869 ++++++++++++++++++ .../vector_search_integration_test.py | 505 ++++++++++ 7 files changed, 2208 insertions(+) create mode 100644 testing/integration/BUILD create mode 100644 testing/integration/requirements.txt create mode 100755 testing/integration/stability_runner.py create mode 100644 testing/integration/stability_test.py create mode 100644 testing/integration/utils.py create mode 100644 testing/integration/vector_search_integration_test.py diff --git a/MODULE.bazel b/MODULE.bazel index 9acd783..c4c63f5 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -2,6 +2,8 @@ module( name = "com_google_valkeysearch", version = "1.0.0", ) + +bazel_dep(name = "protoc-gen-validate", version = "1.0.4.bcr.2") bazel_dep(name = "rules_cc", version = "0.0.16") bazel_dep(name = "abseil-cpp", version = "20240722.0.bcr.1", repo_name = "com_google_absl") bazel_dep(name = "protobuf", version = "29.2", repo_name = "com_google_protobuf") @@ -36,3 +38,20 @@ git_override( # Replace the commit hash (above) with the latest (https://github.com/hedronvision/bazel-compile-commands-extractor/commits/main). # Even better, set up Renovate and let it do the work for you (see "Suggestion: Updates" in the README). ) + +# Integration test dependencies +bazel_dep(name = "rules_python", version = "0.40.0", dev_dependency = True) +python = use_extension( + "@rules_python//python/extensions:python.bzl", + "python", + dev_dependency = True +) +python.toolchain(python_version = "3.12", is_default=True) + +pip = use_extension("@rules_python//python/extensions:pip.bzl", "pip", dev_dependency = True) +pip.parse( + hub_name = "pip", + requirements_lock = "//testing/integration:requirements.txt", + python_version = "3.12" +) +use_repo(pip, "pip") \ No newline at end of file diff --git a/testing/integration/BUILD b/testing/integration/BUILD new file mode 100644 index 0000000..799bbc8 --- /dev/null +++ b/testing/integration/BUILD @@ -0,0 +1,53 @@ +load("@pip//:requirements.bzl", "requirement") + +package( + default_applicable_licenses=["//:license"], + default_visibility=["//testing:__subpackages__"], +) + +py_test( + name="stability_test", + srcs=["stability_test.py"], + deps=[ + ":stability_runner", + ":utils", + requirement("valkey"), + requirement("absl-py"), + ], + data=[ + "//src:valkeysearch", + ], + timeout="eternal", # 3600 seconds, for slow machines +) + +py_test( + name="vector_search_integration_test", + srcs=["vector_search_integration_test.py"], + deps=[ + ":utils", + requirement("valkey"), + requirement("absl-py"), + requirement("numpy"), + ], + data=[ + "//src:valkeysearch", + ], +) + +py_library( + name="stability_runner", + srcs=["stability_runner.py"], + deps=[ + ":utils", + requirement("valkey"), + ], +) + +py_library( + name="utils", + srcs=["utils.py"], + deps=[ + requirement("valkey"), + requirement("numpy"), + ], +) diff --git a/testing/integration/requirements.txt b/testing/integration/requirements.txt new file mode 100644 index 0000000..bf5f8db --- /dev/null +++ b/testing/integration/requirements.txt @@ -0,0 +1,3 @@ +valkey==6.0.2 +absl-py==2.1.0 +numpy==2.2.2 diff 
--git a/testing/integration/stability_runner.py b/testing/integration/stability_runner.py new file mode 100755 index 0000000..92c2415 --- /dev/null +++ b/testing/integration/stability_runner.py @@ -0,0 +1,368 @@ +"""ValkeyQuery stability test core.""" + +import logging +import os +import sys +import threading +import time +from typing import NamedTuple +import valkey +from testing.integration import utils + + +class MemtierProcessRunResult(NamedTuple): + """Results for a single memtier process run.""" + + name: str + total_ops: int + failures: int + halted: bool + runtime: float + + +class BackgroundTaskRunResult(NamedTuple): + """Results for a single background thread run.""" + + name: str + total_ops: int + failures: int + + +class StabilityRunResult(NamedTuple): + """Results for a single stability test run.""" + + # False if the test was unable to be performed. + successful_run: bool + memtier_results: list[MemtierProcessRunResult] + background_task_results: list[BackgroundTaskRunResult] + + +class StabilityTestConfig(NamedTuple): + """Configuration for a single stability test run.""" + + index_name: str + ports: tuple[int, ...] + index_type: str + vector_dimensions: int + bgsave_interval_sec: int + ftcreate_interval_sec: int + ftdropindex_interval_sec: int + flushdb_interval_sec: int + randomize_bg_job_intervals: bool + num_memtier_threads: int + num_memtier_clients: int + num_search_clients: int + insertion_mode: str + test_time_sec: int + test_timeout: int + keyspace_size: int + use_coordinator: bool + replica_count: int + repl_diskless_load: str + memtier_path: str = "" + + +class StabilityRunner: + """Stability test runner. + + Attributes: + config: The configuration for the test. + """ + + def __init__(self, config: StabilityTestConfig): + self.config = config + logging.basicConfig( + handlers=[ + logging.StreamHandler(stream=sys.stdout), + ], + level="DEBUG", + format=( + "%(asctime)s [%(levelname)s] (%(name)s) %(funcName)s: %(message)s" + ), + ) + + def run(self) -> StabilityRunResult: + """Runs the stability test, sending memtier commands and running background threads that perform valkey operations. 
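+
+    Connects to the cluster, (re)creates the index, starts the periodic
+    background tasks (BGSAVE, FT.CREATE, FT.DROPINDEX, FLUSHDB), spawns the
+    memtier workload processes, and polls them until they all finish or the
+    configured timeout elapses.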
+ + Returns: + + Raises: + ValueError: + """ + try: + r = valkey.ValkeyCluster( + host="localhost", + port=self.config.ports[0], + startup_nodes=[ + valkey.cluster.ClusterNode("localhost", port) + for port in self.config.ports + ], + require_full_coverage=True, + socket_timeout=10, + ) + except valkey.exceptions.ConnectionError as e: + logging.error("Unable to connect to valkey, %s", e) + return StabilityRunResult( + successful_run=False, + memtier_results=[], + background_task_results=[], + ) + + try: + utils.drop_index(r=r, index_name=self.config.index_name) + except valkey.exceptions.ValkeyError: + pass + + if self.config.index_type == "HNSW": + utils.create_hnsw_index( + r=r, + index_name=self.config.index_name, + vector_dimensions=self.config.vector_dimensions, + ) + else: + utils.create_flat_index( + r=r, + index_name=self.config.index_name, + vector_dimensions=self.config.vector_dimensions, + ) + + threads: list[utils.RandomIntervalTask] = [] + index_state = utils.IndexState( + index_lock=threading.Lock(), ft_created=True + ) + if self.config.bgsave_interval_sec != 0: + threads.append( + utils.periodic_bgsave( + r, + self.config.bgsave_interval_sec, + self.config.randomize_bg_job_intervals, + ) + ) + + if self.config.ftcreate_interval_sec != 0: + threads.append( + utils.periodic_ftcreate( + r, + self.config.ftcreate_interval_sec, + self.config.randomize_bg_job_intervals, + self.config.index_name, + self.config.vector_dimensions, + self.config.index_type == "HNSW", + index_state, + ) + ) + + if self.config.ftdropindex_interval_sec != 0: + threads.append( + utils.periodic_ftdrop( + r, + self.config.ftdropindex_interval_sec, + self.config.randomize_bg_job_intervals, + self.config.index_name, + index_state, + ) + ) + + if self.config.flushdb_interval_sec != 0: + threads.append( + utils.periodic_flushdb( + r, + self.config.flushdb_interval_sec, + self.config.randomize_bg_job_intervals, + index_state, + self.config.use_coordinator, + ) + ) + + memtier_output_dir = os.environ["TEST_UNDECLARED_OUTPUTS_DIR"] + + insert_command = ( + f"{self.config.memtier_path}" + " --cluster-mode" + " -s localhost" + f" -p {self.config.ports[0]}" + f" -t {self.config.num_memtier_threads}" + f" -c {self.config.num_memtier_clients}" + " --random-data" + " -" + " --command='HSET __key__ embedding __data__ tag my_tag numeric 10'" + " --command-key-pattern=P" + f" -d {self.config.vector_dimensions*4}" + " --json-out-file" + f" {memtier_output_dir}/{self.config.index_name}_memtier_insert.json" + ) + delete_command = ( + f"{self.config.memtier_path}" + " --cluster-mode" + " -s localhost" + f" -p {self.config.ports[0]}" + f" -t {self.config.num_memtier_threads}" + f" -c {self.config.num_memtier_clients}" + " --random-data" + " -" + " --command='DEL __key__'" + " --command-key-pattern=P" + f" -d {self.config.vector_dimensions*4}" + " --json-out-file" + f" {memtier_output_dir}/{self.config.index_name}_memtier_del.json" + ) + expire_command = ( + f"{self.config.memtier_path}" + " --cluster-mode" + " -s localhost" + f" -p {self.config.ports[0]}" + f" -t {self.config.num_memtier_threads}" + f" -c {self.config.num_memtier_clients}" + " --random-data" + " -" + " --command='EXPIRE __key__ 1'" + " --command-key-pattern=P" + f" -d {self.config.vector_dimensions*4}" + " --json-out-file" + f" {memtier_output_dir}/{self.config.index_name}_memtier_expire.json" + ) + + if self.config.insertion_mode == "request_count": + keys_per_client = int( + self.config.keyspace_size + / self.config.num_memtier_clients + / 
self.config.num_memtier_threads + ) + logging.debug("%d keys per client needed", keys_per_client) + insert_command += f" -n {keys_per_client}" + delete_command += f" -n {keys_per_client}" + expire_command += f" -n {keys_per_client}" + elif self.config.insertion_mode == "time_interval": + insert_command += f" --test-time {self.config.test_time_sec}" + delete_command += f" --test-time {self.config.test_time_sec}" + expire_command += f" --test-time {self.config.test_time_sec}" + else: + raise ValueError( + f"Unknown insertion mode: {self.config.insertion_mode}" + ) + + search_command = ( + f"{self.config.memtier_path}" + " --cluster-mode" + " -s localhost" + f" -p {self.config.ports[0]}" + f" -t {self.config.num_memtier_threads}" + f" -c {self.config.num_search_clients}" + " -" + " --command='FT.SEARCH" + f" {self.config.index_name} " + '"(@tag:{my_tag} @numeric:[0 100])=>[KNN 3 @embedding $query_vector]"' + ' NOCONTENT PARAMS 2 "query_vector" __data__ DIALECT 2\' ' + f" --test-time={self.config.test_time_sec}" + f" -d {self.config.vector_dimensions*4}" + " --json-out-file" + f" {memtier_output_dir}/{self.config.index_name}_memtier_search.json" + ) + + ft_info_command = ( + f"{self.config.memtier_path}" + " --cluster-mode" + " -s localhost" + f" -p {self.config.ports[0]}" + f" -t {self.config.num_memtier_threads}" + f" -c {self.config.num_search_clients}" + " -" + f" --command='FT.INFO {self.config.index_name}'" + f" --test-time={self.config.test_time_sec}" + f" -d {self.config.vector_dimensions*4}" + f" --json-out-file" + f" {memtier_output_dir}/{self.config.index_name}_memtier_ftinfo.json" + ) + + ft_list_command = ( + f"{self.config.memtier_path}" + " --cluster-mode" + " -s localhost" + f" -p {self.config.ports[0]}" + f" -t {self.config.num_memtier_threads}" + f" -c {self.config.num_search_clients}" + " -" + " --command='FT._LIST'" + f" --test-time={self.config.test_time_sec}" + f" -d {self.config.vector_dimensions*4}" + " --json-out-file" + f" {memtier_output_dir}/{self.config.index_name}_memtier_ftlist.json" + ) + + logging.debug("insert_command: %s", insert_command) + logging.debug("delete_command: %s", delete_command) + logging.debug("expire_command: %s", expire_command) + logging.debug("search_command: %s", search_command) + logging.debug("ft_info_command: %s", ft_info_command) + logging.debug("ft_list_command: %s", ft_list_command) + + processes: list[utils.MemtierProcess] = [] + processes.append( + utils.MemtierProcess(command=insert_command, name="HSET") + ) + processes.append( + utils.MemtierProcess(command=delete_command, name="DEL") + ) + processes.append( + utils.MemtierProcess(command=expire_command, name="EXPIRE") + ) + processes.append( + utils.MemtierProcess( + command=search_command, + name="FT.SEARCH", + error_predicate=lambda err: err + != f"-Index with name '{self.config.index_name}' not found", + ) + ) + processes.append( + utils.MemtierProcess( + command=ft_info_command, + name="FT.INFO", + error_predicate=lambda err: err + != f"-Index with name '{self.config.index_name}' not found", + ) + ) + processes.append( + utils.MemtierProcess(command=ft_list_command, name="FT._LIST") + ) + + timeout_start = time.time() + while time.time() - timeout_start < self.config.test_timeout: + if all(p.done for p in processes): + logging.info("---===All processes finished===---") + break + for process in processes: + process.process_logs() + process.print_status() + time.sleep(1) + else: + logging.error("Timed out waiting for processes to finish") + logging.info("killing processes...") 
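+            # Kill (rather than terminate) the stragglers: a benchmark that
+            # has stopped making progress may never exit cleanly on its own.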
+ for process in processes: + process.process.kill() + logging.error("Processes killed") + + for thread in threads: + thread.stop() + + return StabilityRunResult( + successful_run=True, + memtier_results=[ + MemtierProcessRunResult( + name=process.name, + total_ops=process.total_ops, + failures=process.failures, + halted=process.halted, + runtime=process.runtime, + ) + for process in processes + ], + background_task_results=[ + BackgroundTaskRunResult( + name=thread.name, + total_ops=thread.ops, + failures=thread.failures, + ) + for thread in threads + ], + ) diff --git a/testing/integration/stability_test.py b/testing/integration/stability_test.py new file mode 100644 index 0000000..55c468b --- /dev/null +++ b/testing/integration/stability_test.py @@ -0,0 +1,391 @@ +import logging +import os +import time + +import valkey +import valkey.cluster + +from absl import flags +from absl.testing import absltest +from absl.testing import parameterized +from testing.integration import utils +from testing.integration import stability_runner + +FLAGS = flags.FLAGS +flags.DEFINE_string("valkey_server_path", None, "Path to the Valkey server") +flags.DEFINE_string("valkey_cli_path", None, "Path to the Valkey CLI") +flags.DEFINE_string("memtier_path", None, "Path to the Memtier binary") + + +class StabilityTests(parameterized.TestCase): + + def setUp(self): + super().setUp() + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(message)s", + level=logging.DEBUG, + ) + self.valkey_server = None + + def tearDown(self): + if self.valkey_server: + for _, process in self.valkey_server.items(): + process.terminate() + super().tearDown() + + @parameterized.named_parameters( + dict( + testcase_name="hnsw_no_backfill_coordinator", + config=stability_runner.StabilityTestConfig( + index_name="hnsw_no_backfill", + ports=(8000, 8001, 8002), + index_type="HNSW", + vector_dimensions=100, + bgsave_interval_sec=15, + ftcreate_interval_sec=0, + ftdropindex_interval_sec=0, + flushdb_interval_sec=0, + randomize_bg_job_intervals=True, + num_memtier_threads=10, + num_memtier_clients=10, + num_search_clients=10, + insertion_mode="time_interval", + test_time_sec=60, + test_timeout=120, + keyspace_size=1000000, + use_coordinator=True, + replica_count=0, + repl_diskless_load="swapdb", + ), + ), + dict( + testcase_name="hnsw_with_backfill_coordinator", + config=stability_runner.StabilityTestConfig( + index_name="hnsw_with_backfill", + ports=(8003, 8004, 8005), + index_type="HNSW", + vector_dimensions=100, + bgsave_interval_sec=15, + ftcreate_interval_sec=10, + ftdropindex_interval_sec=10, + flushdb_interval_sec=20, + randomize_bg_job_intervals=True, + num_memtier_threads=10, + num_memtier_clients=10, + num_search_clients=10, + insertion_mode="time_interval", + test_time_sec=60, + test_timeout=120, + keyspace_size=1000000, + use_coordinator=True, + replica_count=0, + repl_diskless_load="swapdb", + ), + ), + dict( + testcase_name="flat_no_backfill_coordinator", + config=stability_runner.StabilityTestConfig( + index_name="flat_no_backfill", + ports=(8006, 8007, 8008), + index_type="FLAT", + vector_dimensions=100, + bgsave_interval_sec=15, + ftcreate_interval_sec=0, + ftdropindex_interval_sec=0, + flushdb_interval_sec=0, + randomize_bg_job_intervals=True, + num_memtier_threads=10, + num_memtier_clients=10, + num_search_clients=10, + insertion_mode="time_interval", + test_time_sec=60, + test_timeout=120, + keyspace_size=1000000, + use_coordinator=True, + replica_count=0, + repl_diskless_load="swapdb", + ), + ), + dict( + 
testcase_name="flat_with_backfill_coordinator", + config=stability_runner.StabilityTestConfig( + index_name="flat_with_backfill", + ports=(8009, 8010, 8011), + index_type="FLAT", + vector_dimensions=100, + bgsave_interval_sec=15, + ftcreate_interval_sec=10, + ftdropindex_interval_sec=10, + flushdb_interval_sec=20, + randomize_bg_job_intervals=True, + num_memtier_threads=10, + num_memtier_clients=10, + num_search_clients=10, + insertion_mode="time_interval", + test_time_sec=60, + test_timeout=120, + keyspace_size=1000000, + use_coordinator=True, + replica_count=0, + repl_diskless_load="swapdb", + ), + ), + dict( + testcase_name="hnsw_with_backfill_no_coordinator", + config=stability_runner.StabilityTestConfig( + index_name="hnsw_with_backfill", + ports=(8012, 8013, 8014), + index_type="HNSW", + vector_dimensions=100, + bgsave_interval_sec=15, + ftcreate_interval_sec=10, + ftdropindex_interval_sec=10, + flushdb_interval_sec=20, + randomize_bg_job_intervals=True, + num_memtier_threads=10, + num_memtier_clients=10, + num_search_clients=10, + insertion_mode="time_interval", + test_time_sec=60, + test_timeout=120, + keyspace_size=1000000, + use_coordinator=False, + replica_count=0, + repl_diskless_load="swapdb", + ), + ), + dict( + testcase_name="hnsw_no_backfill_no_coordinator", + config=stability_runner.StabilityTestConfig( + index_name="hnsw_no_backfill", + ports=(8015, 8016, 8017), + index_type="HNSW", + vector_dimensions=100, + bgsave_interval_sec=15, + ftcreate_interval_sec=0, + ftdropindex_interval_sec=0, + flushdb_interval_sec=0, + randomize_bg_job_intervals=True, + num_memtier_threads=10, + num_memtier_clients=10, + num_search_clients=10, + insertion_mode="time_interval", + test_time_sec=60, + test_timeout=120, + keyspace_size=1000000, + use_coordinator=False, + replica_count=0, + repl_diskless_load="swapdb", + ), + ), + dict( + testcase_name="hnsw_with_backfill_coordinator_replica", + config=stability_runner.StabilityTestConfig( + index_name="hnsw_with_backfill", + ports=(8018, 8019, 8020, 8021, 8022, 8023), + index_type="HNSW", + vector_dimensions=100, + bgsave_interval_sec=15, + ftcreate_interval_sec=10, + ftdropindex_interval_sec=10, + flushdb_interval_sec=20, + randomize_bg_job_intervals=True, + num_memtier_threads=10, + num_memtier_clients=10, + num_search_clients=10, + insertion_mode="time_interval", + test_time_sec=60, + test_timeout=120, + keyspace_size=1000000, + use_coordinator=True, + replica_count=1, + repl_diskless_load="swapdb", + ), + ), + dict( + testcase_name="hnsw_with_backfill_no_coordinator_replica", + config=stability_runner.StabilityTestConfig( + index_name="hnsw_with_backfill", + ports=(8024, 8025, 8026, 8027, 8028, 8029), + index_type="HNSW", + vector_dimensions=100, + bgsave_interval_sec=15, + ftcreate_interval_sec=10, + ftdropindex_interval_sec=10, + flushdb_interval_sec=20, + randomize_bg_job_intervals=True, + num_memtier_threads=10, + num_memtier_clients=10, + num_search_clients=10, + insertion_mode="time_interval", + test_time_sec=60, + test_timeout=120, + keyspace_size=1000000, + use_coordinator=False, + replica_count=1, + repl_diskless_load="swapdb", + ), + ), + dict( + testcase_name="hnsw_with_backfill_coordinator_repl_diskless_disabled", + config=stability_runner.StabilityTestConfig( + index_name="hnsw_with_backfill", + ports=(8030, 8031, 8032, 8033, 8034, 8035), + index_type="HNSW", + vector_dimensions=100, + bgsave_interval_sec=15, + ftcreate_interval_sec=10, + ftdropindex_interval_sec=10, + flushdb_interval_sec=20, + randomize_bg_job_intervals=True, 
+ num_memtier_threads=10, + num_memtier_clients=10, + num_search_clients=10, + insertion_mode="time_interval", + test_time_sec=60, + test_timeout=120, + keyspace_size=1000000, + use_coordinator=True, + replica_count=1, + repl_diskless_load="disabled", + ), + ), + dict( + testcase_name=( + "hnsw_with_backfill_no_coordinator_repl_diskless_disabled" + ), + config=stability_runner.StabilityTestConfig( + index_name="hnsw_with_backfill", + ports=(8036, 8037, 8038, 8039, 8040, 8041), + index_type="HNSW", + vector_dimensions=100, + bgsave_interval_sec=15, + ftcreate_interval_sec=10, + ftdropindex_interval_sec=10, + flushdb_interval_sec=20, + randomize_bg_job_intervals=True, + num_memtier_threads=10, + num_memtier_clients=10, + num_search_clients=10, + insertion_mode="time_interval", + test_time_sec=60, + test_timeout=120, + keyspace_size=1000000, + use_coordinator=False, + replica_count=1, + repl_diskless_load="disabled", + ), + ), + ) + def test_valkeyquery_stability(self, config): + valkey_server_stdout_dir = os.environ["TEST_UNDECLARED_OUTPUTS_DIR"] + + if FLAGS.valkey_server_path is None: + raise ValueError( + "--test_arg=--valkey_server_path=/path/to/valkey_server " + "is required" + ) + if FLAGS.valkey_cli_path is None: + raise ValueError( + "--test_arg=--valkey_cli_path=/path/to/valkey_cli is required" + ) + if FLAGS.memtier_path is None: + raise ValueError( + "--test_arg=--memtier_path=/path/to/memtier is required" + ) + valkey_module_path = os.path.join( + os.environ["PWD"], + "src/libvalkeysearch.so", + ) + config = config._replace(memtier_path=FLAGS.memtier_path) + + self.valkey_server = utils.start_valkey_cluster( + FLAGS.valkey_server_path, + FLAGS.valkey_cli_path, + config.ports, + os.environ["TEST_TMPDIR"], + valkey_server_stdout_dir, + { + "loglevel": "debug", + "enable-debug-command": "yes", + "repl-diskless-load": config.repl_diskless_load, + # tripled, for slow machines + "cluster-node-timeout": "45000", + }, + { + f"{valkey_module_path}": "--threads 2 --log-level notice" + + (" --use-coordinator" if config.use_coordinator else "") + }, + config.replica_count, + ) + connected = False + for _ in range(10): + try: + valkey_conn = valkey.ValkeyCluster( + host="localhost", + port=config.ports[0], + startup_nodes=[ + valkey.cluster.ClusterNode("localhost", port) + for port in config.ports + ], + require_full_coverage=True, + ) + valkey_conn.ping() + connected = True + break + except valkey.exceptions.ConnectionError: + time.sleep(1) + + if not connected: + self.fail("Failed to connect to valkey server") + + results = stability_runner.StabilityRunner(config).run() + + if results is None: + self.fail("Failed to run stability test") + + for port, process in self.valkey_server.items(): + if process.poll(): + self.fail("a process died during test, port: %d", port) + + self.assertTrue( + results.successful_run, + msg="Expected stability test to be performed successfully", + ) + for result in results.memtier_results: + self.assertGreater( + result.total_ops, + 0, + msg=f"Expected positive total ops for memtier run {result.name}", + ) + self.assertEqual( + result.failures, + 0, + f"Expected zero failures for memtier run {result.name}", + ) + self.assertFalse( + result.halted, + msg=( + f"Expected memtier run {result.name} to not be halted (didn't " + "make progress for >10sec)" + ), + ) + for result in results.background_task_results: + self.assertGreater( + result.total_ops, + 0, + msg=f"Expected positive total ops for background task {result.name}", + ) + # BGSAVE will fail if another is 
ongoing. + if result.name == "BGSAVE": + pass + else: + self.assertEqual( + result.failures, + 0, + f"Expected zero failures for background task {result.name}", + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/testing/integration/utils.py b/testing/integration/utils.py new file mode 100644 index 0000000..61a501f --- /dev/null +++ b/testing/integration/utils.py @@ -0,0 +1,869 @@ +"""Utilities for ValkeySearch testing.""" + +import fcntl +import logging +import os +import random +import re +import subprocess +import threading +import time +from typing import Any, Callable, Dict, List, NamedTuple, TextIO +import numpy as np +import valkey +import valkey.exceptions + + +def start_valkey_process( + valkey_server_path: str, + port: int, + directory: str, + stdout_file: TextIO, + args: dict[str, str], + modules: dict[str, str], + password: str | None = None, +) -> subprocess.Popen[Any]: + command = f"{valkey_server_path} --port {port} --dir {directory}" + modules_args = [f'"--loadmodule {k} {v}"' for k, v in modules.items()] + args_str = " ".join([f"--{k} {v}" for k, v in args.items()] + modules_args) + command += " " + args_str + command = "ulimit -c unlimited && " + command + logging.info("Starting valkey process with command: %s", command) + + process = subprocess.Popen( + command, shell=True, stdout=stdout_file, stderr=stdout_file + ) + + connected = False + for i in range(10): + logging.info( + "Attempting to connect to Valkey @ port %d (try #%d)", port, i + ) + try: + valkey_conn = valkey.Valkey( + host="localhost", + port=port, + password=password, + socket_timeout=1000, + ) + valkey_conn.ping() + connected = True + break + except ( + valkey.exceptions.ConnectionError, + valkey.exceptions.ResponseError, + valkey.exceptions.TimeoutError, + ): + time.sleep(1) + if not connected: + raise valkey.exceptions.ConnectionError( + f"Failed to connect to valkey server on port {port}" + ) + logging.info("Attempting to connect to Valkey: OK") + + return process + + +def start_valkey_cluster( + valkey_server_path: str, + valkey_cli_path: str, + ports: List[int], + directory: str, + stdout_directory: str, + args: Dict[str, str], + modules: Dict[str, str], + replica_count: int = 0, + password: str | None = None, +) -> Dict[int, subprocess.Popen[Any]]: + """Starts a valkey cluster. + + Starts a valkey cluster with the given ports and arguments, with zero replicas. + + Args: + valkey_server_path: + valkey_cli_path: + ports: + directory: + stdout_directory: + args: + + Returns: + Dictionary of port to valkey process. 
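+
+  Example (a sketch; all paths below are placeholders):
+
+    start_valkey_cluster(
+        "/path/to/valkey-server", "/path/to/valkey-cli",
+        [7000, 7001, 7002], "/tmp/cluster-data", "/tmp/cluster-logs",
+        args={"loglevel": "debug"},
+        modules={"/path/to/libvalkeysearch.so": "--threads 2"},
+    )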
+ """ + cluster_args = dict(args) + processes = {} + + for port in ports: + stdout_path = os.path.join(stdout_directory, f"{port}_stdout.txt") + stdout_file = open(stdout_path, "w") + node_dir = os.path.join(directory, f"nodes{port}") + cluster_args["cluster-enabled"] = "yes" + cluster_args["cluster-config-file"] = os.path.join( + node_dir, "nodes.conf" + ) + cluster_args["cluster-node-timeout"] = "10000" + os.mkdir(node_dir) + processes[port] = start_valkey_process( + valkey_server_path, + port, + node_dir, + stdout_file, + cluster_args, + modules, + password, + ) + + cli_stdout_path = os.path.join(stdout_directory, "valkey_cli_stdout.txt") + cli_stdout_file = open(cli_stdout_path, "w") + valkey_cli_args = [valkey_cli_path, "--cluster-yes", "--cluster", "create"] + for port in ports: + valkey_cli_args.append(f"127.0.0.1:{port}") + valkey_cli_args.extend(["--cluster-replicas", str(replica_count)]) + if password: + valkey_cli_args.extend(["-a", password]) + + logging.info("Creating valkey cluster with command: %s", valkey_cli_args) + + timeout = 60 + now = time.time() + while time.time() - now < timeout: + try: + subprocess.run( + valkey_cli_args, + check=True, + stdout=cli_stdout_file, + stderr=cli_stdout_file, + ) + break + except subprocess.CalledProcessError: + time.sleep(1) + + # This is also ugly, but we need to wait for the cluster to be ready. There + # doesn't seem to be a way to do that with the valkey-server, since it seems to + # be ready immediately, but returns an CLUSTERDOWN error when we try to search + # too early, even after checking with ping. + time.sleep(10) + + return processes + + +def create_hnsw_index( + r: valkey.ValkeyCluster, + index_name: str, + vector_dimensions: int, + vector_attribute_name="embedding", + target_nodes=valkey.ValkeyCluster.DEFAULT_NODE, +): + """Creates a new HNSW index. + + Args: + r: + index_name: + vector_dimensions: + """ + args = [ + "FT.CREATE", + index_name, + "SCHEMA", + vector_attribute_name, + "VECTOR", + "HNSW", + "12", # number of remaining arguments + "M", + 100, + "TYPE", + "FLOAT32", + "DIM", + vector_dimensions, + "DISTANCE_METRIC", + "COSINE", + "EF_CONSTRUCTION", + 5, + "EF_RUNTIME", + 10, + "tag", + "TAG", + "SEPARATOR", + ",", + "numeric", + "NUMERIC", + # "INITIAL_CAP", + # 15000, + ] + return r.execute_command(*args, target_nodes=target_nodes) + + +def create_flat_index( + r: valkey.ValkeyCluster, index_name: str, vector_dimensions: int +): + """Creates a new FLAT index. 
+ + Args: + r: + index_name: + vector_dimensions: + """ + args = [ + "FT.CREATE", + index_name, + "SCHEMA", + "embedding", + "VECTOR", + "FLAT", + "6", # number of remaining arguments + "TYPE", + "FLOAT32", + "DIM", + vector_dimensions, + "DISTANCE_METRIC", + "COSINE", + "tag", + "TAG", + "SEPARATOR", + ",", + "numeric", + "NUMERIC", + ] + r.execute_command(*args) + + +def drop_index(r: valkey.ValkeyCluster, index_name: str): + args = [ + "FT.DROPINDEX", + index_name, + ] + r.execute_command(*args) + + +def fetch_ft_info(r: valkey.ValkeyCluster, index_name: str): + args = [ + "FT.INFO", + index_name, + ] + return r.execute_command(*args, target_nodes=r.ALL_NODES) + + +def flushdb(r: valkey.ValkeyCluster): + args = ["FLUSHDB", "SYNC"] + r.execute_command(*args) + + +def generate_deterministic_data(vector_dimensions: int, seed: int): + # Set a fixed seed value for reproducibility + np.random.seed(seed) + # Generate deterministic random data + data = np.random.rand(vector_dimensions).astype(np.float32).tobytes() + return data + + +def insert_vector( + r: valkey.ValkeyCluster, key: str, vector_dimensions: int, seed: int +): + vector = generate_deterministic_data(vector_dimensions, seed) + return r.hset( + key, + { + "embedding": vector, + "some_hash_key": "some_hash_key_value_" + key, + }, + ) + + +def insert_vectors_thread( + key_prefix: str, + num_vectors: int, + vector_dimensions: int, + host: str, + port: int, + seed: int, +): + r = valkey.Valkey(host=host, port=port) + for i in range(1, num_vectors): + insert_vector( + r=r, + key=(key_prefix + "_" + str(seed) + "_" + str(i)), + vector_dimensions=vector_dimensions, + seed=(i + seed * num_vectors), + ) + + +def insert_vectors( + host: str, + port: int, + num_threads: int, + vector_dimensions: int, + num_vectors: int, +): + """Inserts vectors into the index. + + Args: + host: + port: + num_threads: + vector_dimensions: + num_vectors: + + Returns: + """ + threads = [] + for i in range(1, num_threads): + thread = threading.Thread( + target=insert_vectors_thread, + args=( + "Thread-" + str(i), + num_vectors, + vector_dimensions, + host, + port, + i, + ), + ) + threads.append(thread) + return threads + + +def delete_vector(r: valkey.ValkeyCluster, key: str): + return r.delete(key) + + +def knn_search( + r: valkey.ValkeyCluster, index_name: str, vector_dimensions: int, seed: int +): + """KNN searches the index. + + Args: + r: + index_name: + vector_dimensions: + seed: + + Returns: + """ + vector = generate_deterministic_data(vector_dimensions, seed) + args = [ + "FT.SEARCH", + index_name, + "*=>[KNN 3 @embedding $vec EF_RUNTIME 1 AS score]", + "params", + 2, + "vec", + vector, + "DIALECT", + 2, + ] + return r.execute_command(*args, target_nodes=r.RANDOM) + + +def writer_queue_size(r: valkey.ValkeyCluster, index_name: str): + out = fetch_ft_info(r, index_name) + for index, item in enumerate(out): + if "mutation_queue_size" in str(item): + return int(str(out[index + 1])[2:-1]) + logging.error("Couldn't find mutation_queue_size") + exit(1) + + +def wait_for_empty_writer_queue_size( + r: valkey.ValkeyCluster, index_name: str, timeout=0 +): + """Wait for the writer queue size to hit zero. 
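+
+  Polls FT.INFO roughly once per second until mutation_queue_size reaches
+  zero. If timeout is positive and expires first, an error is logged and the
+  function returns anyway.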
+ + Args: + r: + index_name: + timeout: + """ + start = time.time() + while True: + try: + queue_size = writer_queue_size(r=r, index_name=index_name) + if queue_size == 0: + return + logging.info( + "Waiting for queue size to hit zero, current size: %d", + queue_size, + ) + except ( + valkey.exceptions.ConnectionError, + valkey.exceptions.ResponseError, + ) as e: + logging.error("Error fetching FT.INFO: %s", e) + if timeout > 0 and time.time() - start > timeout: + logging.error("Timed out waiting for queue size to hit zero") + return + time.sleep(1) + + +class RandomIntervalTask: + """Randomly executes a task at a random interval. + + Used to inject (faulty) background operations into the test. + + Attributes: + stopped: + interval: + randomize: + stop_condition: + task: + ops: + failures: + name: + thread: + """ + + def __init__( + self, + name: str, + interval: int, + randomize: bool, + work_func: Callable[[], bool], + ): + stop_condition = threading.Condition() + self.stopped = False + self.interval = interval + self.randomize = randomize + self.stop_condition = stop_condition + self.task = work_func + self.ops = 0 + self.failures = 0 + self.name = name + + def stop(self): + if not self.thread: + logging.error("Thread not running") + return + with self.stop_condition: + self.stopped = True + self.stop_condition.notify() + self.thread.join() + + def run(self): + self.thread = threading.Thread(target=self.loop) + self.thread.start() + + def loop(self): + """ """ + with self.stop_condition: + while True: + modifier = 1 + if self.randomize: + modifier = random.random() + self.stop_condition.wait_for( + lambda: self.stopped, timeout=self.interval * modifier + ) + if self.stopped: + return + if not self.task(): + self.failures += 1 + self.ops += 1 + + +def periodic_bgsave_task( + r: valkey.ValkeyCluster, +) -> bool: + try: + logging.info(" Invoking background save") + r.bgsave(target_nodes=r.ALL_NODES) + except ( + valkey.exceptions.ConnectionError, + valkey.exceptions.ResponseError, + ) as e: + logging.error(" encountered error: %s", e) + return False + return True + + +def periodic_bgsave( + r: valkey.ValkeyCluster, + interval_sec: int, + randomize: bool, +) -> RandomIntervalTask: + thread = RandomIntervalTask( + "BGSAVE", interval_sec, randomize, lambda: periodic_bgsave_task(r) + ) + thread.run() + return thread + + +class IndexState: + + def __init__(self, index_lock: threading.Lock, ft_created: bool): + self.index_lock = index_lock + self.ft_created = ft_created + + +def periodic_ftdrop_task( + r: valkey.ValkeyCluster, + index_name: str, + index_state: IndexState, +) -> bool: + with index_state.index_lock: + logging.info(" Invoking index drop") + try: + drop_index(r, index_name) + index_state.ft_created = False + except ( + valkey.exceptions.ConnectionError, + valkey.exceptions.ResponseError, + ) as e: + if not index_state.ft_created and "not found" in str(e): + logging.debug(" got expected error: %s", e) + else: + logging.error(" got unexpected error: %s", e) + return False + return True + + +def periodic_ftdrop( + r: valkey.ValkeyCluster, + interval_sec: int, + random_interval: bool, + index_name: str, + index_state: IndexState, +) -> RandomIntervalTask: + thread = RandomIntervalTask( + "FT.DROPINDEX", + interval_sec, + random_interval, + lambda: periodic_ftdrop_task(r, index_name, index_state), + ) + thread.run() + return thread + + +def periodic_ftcreate_task( + r: valkey.ValkeyCluster, + index_name: str, + dimensions: int, + hnsw: bool, + index_state: IndexState, +) -> bool: + with 
index_state.index_lock: + try: + logging.info(" Invoking index creation") + if hnsw: + create_hnsw_index(r, index_name, dimensions) + else: + create_flat_index(r, index_name, dimensions) + index_state.ft_created = True + except ( + valkey.exceptions.ConnectionError, + valkey.exceptions.ResponseError, + ) as e: + if index_state.ft_created and "already exists" in str(e): + logging.debug(" got expected error: %s", e) + else: + logging.error(" got unexpected error: %s", e) + return False + return True + + +def periodic_ftcreate( + r: valkey.ValkeyCluster, + interval_sec: int, + random_interval: bool, + index_name: str, + dimensions: int, + hnsw: bool, + index_state: IndexState, +) -> RandomIntervalTask: + thread = RandomIntervalTask( + "FT.CREATE", + interval_sec, + random_interval, + lambda: periodic_ftcreate_task( + r, index_name, dimensions, hnsw, index_state + ), + ) + thread.run() + return thread + + +def periodic_flushdb_task( + r: valkey.ValkeyCluster, + index_state: IndexState, + use_coordinator: bool, +) -> bool: + with index_state.index_lock: + logging.info(" Invoking flush DB") + try: + flushdb(r) + if not use_coordinator: + index_state.ft_created = False + except ( + valkey.exceptions.ConnectionError, + valkey.exceptions.ResponseError, + ) as e: + logging.error( + " got unexpected error during FLUSHDB: %s", e + ) + return False + return True + + +def periodic_flushdb( + r: valkey.ValkeyCluster, + interval_sec: int, + random_interval: bool, + index_state: IndexState, + use_coordinator: bool, +) -> RandomIntervalTask: + thread = RandomIntervalTask( + "FLUSHDB", + interval_sec, + random_interval, + lambda: periodic_flushdb_task(r, index_state, use_coordinator), + ) + thread.run() + return thread + + +def set_non_blocking(fd) -> None: + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK) + + +def spawn_memtier_process(command: str) -> subprocess.Popen[Any]: + memtier_process = subprocess.Popen( + command, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + if memtier_process.stdout is not None: + set_non_blocking(memtier_process.stdout.fileno()) + if memtier_process.stderr is not None: + set_non_blocking(memtier_process.stderr.fileno()) + return memtier_process + + +class MemtierErrorLineInfo(NamedTuple): + run_number: int + percent_complete: float + runtime: float + threads: int + ops: int + ops_sec: float + avg_ops_sec: float + b_sec: int + avg_b_sec: int + latency: float + avg_latency: float + error: str | None + + +class MemtierProcess: + + def __init__( + self, + command: str, + name: str, + trailing_secs: int = 10, + error_predicate: Callable[[str], bool] | None = None, + ): + self.name = name + self.runtime = 0 + self.trailing_ops_sec = [] + self.failures = 0 + self.trailing_secs = trailing_secs + self.halted = False + self.process = spawn_memtier_process(command) + self.done = False + self.error_predicate = error_predicate + self.total_ops = 0 + self.avg_ops_sec = 0 + + def process_logs(self): + for line in self._process_memtier_subprocess_output(): + if ( + line.error is not None + and self.error_predicate is not None + and not self.error_predicate(line.error) + ): + continue + if line.error is not None: + logging.error( + "<%s> encountered error: %s", self.name, line.error + ) + self._add_line_to_stats(line) + + def _add_line_to_stats(self, line: MemtierErrorLineInfo): + if line.error is not None: + self.failures += 1 + else: + self.runtime = line.runtime + self.trailing_ops_sec.insert(0, line.ops_sec) + if 
len(self.trailing_ops_sec) > self.trailing_secs: + self.trailing_ops_sec.pop() + if self.trailing_ops_sec: + trailing_ops_sec = sum(self.trailing_ops_sec) / len( + self.trailing_ops_sec + ) + if ( + trailing_ops_sec == 0 + and len(self.trailing_ops_sec) == self.trailing_secs + ): + self.halted = True + self.total_ops = line.ops + self.avg_ops_sec = line.avg_ops_sec + + def print_status(self): + if self.process.poll() is not None and not self.done: + logging.info( + "<%s> - \tState: Exit Code %d,\tRuntime: %d,\ttotal ops:" + " %d,\tops/s(latest): %d,\tavg ops/s(lifetime): %d", + self.name, + self.process.returncode, + self.runtime, + self.total_ops, + self.trailing_ops_sec[0] if self.trailing_ops_sec else 0, + self.avg_ops_sec, + ) + self.done = True + if self.done: + return + if self.trailing_ops_sec: + trailing_ops_sec = sum(self.trailing_ops_sec) / len( + self.trailing_ops_sec + ) + logging.info( + "<%s> - \tState: Running,\tRuntime: %d,\ttotal ops:" + " %d,\tops/s(latest): %d,\tavg ops/s(lifetime): %d,\tavg" + " ops/s(10s): %d", + self.name, + self.runtime, + self.total_ops, + self.trailing_ops_sec[0], + self.avg_ops_sec, + trailing_ops_sec, + ) + return + logging.info("<%s> - \tState: Waiting for output", self.name) + + def _process_memtier_subprocess_output(self): + try: + parsed_lines = [] + while True: + if self.process.stderr is None: + break + stderr = self.process.stderr.readline() + if stderr: + stderr = stderr.decode("utf-8") + error_line_info = parse_memtier_error_line(stderr) + if error_line_info is not None: + parsed_lines.append(error_line_info) + else: + logging.info( + "<%s> stderr: %s", self.name, stderr.strip() + ) + else: + break + while True: + if self.process.stdout is None: + break + stdout = self.process.stdout.readline() + if stdout: + stdout = stdout.decode("utf-8") + logging.info("<%s> stdout: %s", self.name, stdout.strip()) + else: + break + return parsed_lines + except IOError: + pass + + +def parse_memtier_error_line(line: str): + progress_pattern = ( + r"\[RUN #(\d+)" + r" ([\d\.]+)%?,\s+([\d\.]+)\s+secs\]\s+([\d\.]+)\s+threads:\s+(\d+)\s+ops,\s+([\d\.]+)\s+\(avg:\s+([\d\.]+)\)\s+ops\/sec,\s+([\d\.]+[KMG]B\/sec)\s+\(avg:\s+(\d+\.\d+[KMG]?B\/sec)\),\s+(-nan|[\d\.]+)\s+\(avg:\s+(\d+\.\d+)\)\s+msec\s+latency" + ) + match = re.search(progress_pattern, line) + + if match: + run_number = int(match.group(1)) + percent_complete = float(match.group(2)) + runtime = float(match.group(3)) + threads = int(match.group(4)) + ops = int(match.group(5)) + ops_sec = float(match.group(6)) + avg_ops_sec = float(match.group(7)) + b_sec = match.group(8) + avg_b_sec = match.group(9) + latency = float(match.group(10)) + avg_latency = float(match.group(11)) + return MemtierErrorLineInfo( + run_number=run_number, + percent_complete=percent_complete, + runtime=runtime, + threads=threads, + ops=ops, + ops_sec=ops_sec, + avg_ops_sec=avg_ops_sec, + b_sec=b_sec, + avg_b_sec=avg_b_sec, + latency=latency, + avg_latency=avg_latency, + error=None, + ) + else: + # See if it matches the error pattern + error_pattern = r"server [\d\.]+:\d+ handle error response: (.*)" + match = re.search(error_pattern, line) + if match: + return MemtierErrorLineInfo( + run_number=0, + percent_complete=0, + runtime=0, + threads=0, + ops=0, + ops_sec=0, + avg_ops_sec=0, + b_sec=0, + avg_b_sec=0, + latency=0, + avg_latency=0, + error=match.group(1), + ) + return None + + +def connect_to_valkey_cluster( + startup_nodes: List[valkey.cluster.ClusterNode], + require_full_coverage: bool = True, + password: str | None = 
None, + attempts: int = 10, + connection_class=valkey.connection.Connection, +) -> valkey.ValkeyCluster: + """Connects to a valkey cluster, retrying if necessary. + + Args: + startup_nodes: List of cluster nodes to connect to. + require_full_coverage: Whether to require full coverage of the cluster. + + Returns: + Valkey cluster connection or None if connection failed. + """ + if attempts <= 0: + raise ValueError("attempts must be > 0") + + while attempts > 0: + attempts -= 1 + try: + valkey_conn = valkey.cluster.ValkeyCluster.from_url( + url="valkey://{}:{}".format( + startup_nodes[0].host, startup_nodes[0].port + ), + password=password, + connection_class=connection_class, + startup_nodes=startup_nodes, + require_full_coverage=require_full_coverage, + ) + valkey_conn.ping() + return valkey_conn + except valkey.exceptions.ConnectionError as e: + if attempts == 0: + raise e + logging.info("Failed to connect to valkey cluster, retrying...") + time.sleep(1) + + assert False diff --git a/testing/integration/vector_search_integration_test.py b/testing/integration/vector_search_integration_test.py new file mode 100644 index 0000000..0e034d1 --- /dev/null +++ b/testing/integration/vector_search_integration_test.py @@ -0,0 +1,505 @@ +import difflib +import os +import pprint +import sys +import time +from typing import Any, List + +import numpy as np +import valkey +import valkey.cluster + +from absl import flags +from absl.testing import absltest +from absl.testing import parameterized +from testing.integration import utils + + +FLAGS = flags.FLAGS +flags.DEFINE_string("valkey_server_path", None, "Path to the Valkey server") +flags.DEFINE_string("valkey_cli_path", None, "Path to the Valkey CLI") + + +def generate_test_vector(dimensions: int, data: int): + vector = np.zeros(dimensions).astype(np.float32) + vector[0] = np.float32(1) + vector[1] = np.float32(data) + return vector + + +class VSSOutput: + """Helper class to parse VSS output.""" + + def __init__( + self, output: Any, embedding_attribute_name: str = "embedding" + ): + self.embedding_attribute_name = bytes(embedding_attribute_name, "utf-8") + if not output: + return + self.count = output[0] + self.keys = dict() + if len(output) < 3 or not isinstance(output[2], List): + # NOCONTENT/RETURN 0 + for i in range(1, len(output)): + self.keys[output[i]] = dict() + return + + for i in range(1, len(output), 2): + attrs = output[i + 1] + attrs_map = dict() + for j in range(0, len(attrs), 2): + attrs_map[attrs[j]] = attrs[j + 1] + if self.embedding_attribute_name in attrs_map: + attrs_map[self.embedding_attribute_name] = np.frombuffer( + attrs_map[self.embedding_attribute_name], dtype=np.float32 + ) + self.keys[output[i]] = attrs_map + + def __eq__(self, other): + if self.count != other.count: + return False + if self.keys.keys() != other.keys.keys(): + return False + for k in self.keys: + if self.keys[k].keys() != other.keys[k].keys(): + return False + for attr in self.keys[k]: + if attr == self.embedding_attribute_name: + if not np.allclose( + self.keys[k][attr], + other.keys[k][attr], + ): + return False + else: + if self.keys[k][attr] != other.keys[k][attr]: + return False + return True + + def __str__(self): + return f"count: {self.count}\nkeys: {pprint.pformat(self.keys)}" + + +class VSSTestCase(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.addTypeEqualityFunc(VSSOutput, "assertVSSOutputEqual") + + def assertVSSOutputEqual(self, a, b, msg=None): + if a != b: + diff_msg = "\n" + "\n".join( + difflib.ndiff( + 
str(a).splitlines(), + str(b).splitlines(), + ) + ) + self.fail(f"VSSOutput not equal: {diff_msg}") + + +class VectorSearchIntegrationTest(VSSTestCase): + # Start the valkey cluster once for all tests. + @classmethod + def setUpClass(cls): + super().setUpClass() + valkey_server_stdout_dir = os.environ["TEST_UNDECLARED_OUTPUTS_DIR"] + + if FLAGS.valkey_server_path is None: + raise ValueError( + "--test_arg=--valkey_server_path=/path/to/valkey_server " + "is required" + ) + if FLAGS.valkey_cli_path is None: + raise ValueError( + "--test_arg=--valkey_cli_path=/path/to/valkey_cli is " + "required" + ) + + cls.valkey_server = utils.start_valkey_cluster( + FLAGS.valkey_server_path, + FLAGS.valkey_cli_path, + [6379, 6380, 6381], + os.environ["TEST_TMPDIR"], + valkey_server_stdout_dir, + { + "loglevel": "debug", + "enable-debug-command": "yes", + "repl-diskless-load": "swapdb", + # tripled, to handle slow test environments + "cluster-node-timeout": "45000", + }, + { + f"{os.path.join(os.environ['PWD'], 'src/libvalkeysearch.so')}": "--threads 2 --log-level notice --use-coordinator", + }, + 0, + ) + cls.valkey_conn = utils.connect_to_valkey_cluster( + [ + valkey.cluster.ClusterNode("localhost", port) + for port in [6379, 6380, 6381] + ], + True, + ) + if not cls.valkey_conn: + cls.fail("Failed to connect to valkey cluster") + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + + def tearDown(self): + for index in self.valkey_conn.execute_command("FT._LIST"): + self.valkey_conn.execute_command("FT.DROPINDEX", index) + time.sleep(1) + self.valkey_conn.execute_command( + "FLUSHDB", target_nodes=self.valkey_conn.ALL_NODES + ) + for port, process in self.valkey_server.items(): + if process.poll(): + self.fail("a process died during test, port: %d", port) + try: + valkey.Valkey(port=port).ping() + except Exception as e: # pylint: disable=broad-except + self.fail(f"Failed to ping valkey on port {port}: {e}") + + super().tearDown() + + def test_create_and_drop_index(self): + self.assertEqual( + b"OK", + utils.create_hnsw_index( + self.valkey_conn, + "test_index", + 100, + "embedding", + target_nodes=valkey.ValkeyCluster.RANDOM, + ), + ) + time.sleep(1) + + with self.assertRaises(valkey.exceptions.ResponseError) as e: + utils.create_hnsw_index( + self.valkey_conn, + "test_index", + 100, + "embedding", + target_nodes=valkey.ValkeyCluster.RANDOM, + ) + self.assertEqual( + "Index test_index already exists.", + e.exception.args[0], + ) + + self.assertEqual( + [b"test_index"], + self.valkey_conn.execute_command("FT._LIST"), + ) + + self.assertEqual( + b"OK", + self.valkey_conn.execute_command("FT.DROPINDEX", "test_index"), + ) + time.sleep(1) + + with self.assertRaises(valkey.exceptions.ResponseError) as e: + self.valkey_conn.execute_command("FT.DROPINDEX", "test_index") + self.assertEqual( + "Index with name 'test_index' not found", + e.exception.args[0], + ) + + @parameterized.named_parameters( + dict( + testcase_name="index_not_exixts_no_content", + config=dict( + index_name="test_index", + vector_attribute_name="embedding", + search_index="not_a_real_index", + vector_search_attribute="embedding", + search_vector=generate_test_vector(100, 0), + filter="*", + knn=3, + score_as="score", + returns=None, + expected_error="Index with name 'not_a_real_index' not found", + expected_result=None, + no_content=True, + ), + ), + dict( + testcase_name="attribute_not_exixts_no_content", + config=dict( + index_name="test_index", + vector_attribute_name="embedding", + search_index="test_index", + 
vector_search_attribute="not_a_real_attribute", + search_vector=generate_test_vector(100, 0), + filter="*", + knn=3, + score_as="score", + returns=None, + expected_error="Index field `not_a_real_attribute` not exists", + expected_result=None, + no_content=True, + ), + ), + dict( + testcase_name="index_not_exixts", + config=dict( + index_name="test_index", + vector_attribute_name="embedding", + search_index="not_a_real_index", + vector_search_attribute="embedding", + search_vector=generate_test_vector(100, 0), + filter="*", + knn=3, + score_as="score", + returns=None, + expected_error="Index with name 'not_a_real_index' not found", + expected_result=None, + no_content=False, + ), + ), + dict( + testcase_name="attribute_not_exixts", + config=dict( + index_name="test_index", + vector_attribute_name="embedding", + search_index="test_index", + vector_search_attribute="not_a_real_attribute", + search_vector=generate_test_vector(100, 0), + filter="*", + knn=3, + score_as="score", + returns=None, + expected_error="Index field `not_a_real_attribute` not exists", + expected_result=None, + no_content=False, + ), + ), + dict( + testcase_name="happy_case_no_content", + config=dict( + index_name="test_index", + vector_attribute_name="embedding", + search_index="test_index", + vector_search_attribute="embedding", + search_vector=generate_test_vector(100, 0), + filter="*", + knn=3, + score_as="score", + returns=None, + expected_error=None, + expected_result=[3, b"2", b"1", b"0"], + no_content=True, + ), + ), + dict( + testcase_name="happy_case_returns_0", + config=dict( + index_name="test_index", + vector_attribute_name="embedding", + search_index="test_index", + vector_search_attribute="embedding", + search_vector=generate_test_vector(100, 0), + filter="*", + knn=3, + score_as="score", + returns=[], + expected_error=None, + expected_result=[3, b"2", b"1", b"0"], + no_content=False, + ), + ), + dict( + testcase_name="happy_case_return_score", + config=dict( + index_name="test_index", + vector_attribute_name="embedding", + search_index="test_index", + vector_search_attribute="embedding", + search_vector=generate_test_vector(100, 0), + filter="*", + knn=3, + score_as="score", + returns=["score"], + expected_error=None, + expected_result=[ + 3, + b"2", + [b"score", b"0.552786409855"], + b"1", + [b"score", b"0.292893230915"], + b"0", + [b"score", b"0"], + ], + no_content=False, + ), + ), + dict( + testcase_name="happy_case_all_returns", + config=dict( + index_name="test_index", + vector_attribute_name="embedding", + search_index="test_index", + vector_search_attribute="embedding", + search_vector=generate_test_vector(100, 0), + filter="*", + knn=3, + score_as="score", + returns=None, + expected_error=None, + expected_result=[ + 3, + b"2", + [ + b"score", + b"0.552786409855", + b"embedding", + generate_test_vector(100, 2).tobytes(), + b"numeric", + b"2", + b"tag", + b"2", + b"not_indexed", + b"2", + ], + b"1", + [ + b"score", + b"0.292893230915", + b"embedding", + generate_test_vector(100, 1).tobytes(), + b"numeric", + b"1", + b"tag", + b"1", + b"not_indexed", + b"1", + ], + b"0", + [ + b"score", + b"0", + b"embedding", + generate_test_vector(100, 0).tobytes(), + b"numeric", + b"0", + b"tag", + b"0", + b"not_indexed", + b"0", + ], + ], + no_content=False, + ), + ), + dict( + testcase_name="happy_case_just_embeddings", + config=dict( + index_name="test_index", + vector_attribute_name="embedding", + search_index="test_index", + vector_search_attribute="embedding", + search_vector=generate_test_vector(100, 0), + 
filter="*", + knn=3, + score_as="score", + returns=["embedding"], + expected_error=None, + expected_result=[ + 3, + b"2", + [ + b"embedding", + generate_test_vector(100, 2).tobytes(), + ], + b"1", + [ + b"embedding", + generate_test_vector(100, 1).tobytes(), + ], + b"0", + [ + b"embedding", + generate_test_vector(100, 0).tobytes(), + ], + ], + no_content=False, + ), + ), + ) + def test_vector_search(self, config): + self.maxDiff = None + dimensions = 100 + self.assertEqual( + b"OK", + utils.create_hnsw_index( + self.valkey_conn, + config["index_name"], + dimensions, + config["vector_attribute_name"], + ), + ) + time.sleep(1) + for data in range(100): + vector = generate_test_vector(dimensions, data) + self.assertEqual( + 4, + self.valkey_conn.hset( + str(data), + mapping={ + "embedding": vector.tobytes(), + "tag": str(data), + "numeric": str(data), + "not_indexed": str(data), + }, + ), + ) + + args = [ + "FT.SEARCH", + config["search_index"], + ( + f'{config["filter"]}=>[KNN' + f' {config["knn"]} @{config["vector_search_attribute"]} $vec' + f' EF_RUNTIME 1 AS {config["score_as"]}]' + ), + "params", + 2, + "vec", + config["search_vector"].tobytes(), + "DIALECT", + 2, + ] + if config["no_content"]: + args.append("NOCONTENT") + if config["returns"] is not None: + args.extend( + ["RETURN", str(len(config["returns"]))] + config["returns"] + ) + + if config["expected_error"] is not None: + with self.assertRaises(valkey.exceptions.ResponseError) as e: + self.valkey_conn.execute_command( + *args, target_nodes=self.valkey_conn.RANDOM + ) + self.assertEqual( + config["expected_error"], + e.exception.args[0], + ) + else: + got = VSSOutput( + self.valkey_conn.execute_command( + *args, target_nodes=self.valkey_conn.RANDOM + ), + embedding_attribute_name=config["vector_search_attribute"], + ) + want = VSSOutput( + config["expected_result"], + embedding_attribute_name=config["vector_search_attribute"], + ) + self.assertEqual(want, got) + + +if __name__ == "__main__": + absltest.main() From 00577e1cf5e8fc05bbf9cd9c2b988f2b7bc8b251 Mon Sep 17 00:00:00 2001 From: Jacob Murphy Date: Thu, 23 Jan 2025 17:39:42 +0000 Subject: [PATCH 2/4] Add newline to MODULE.bazel Signed-off-by: Jacob Murphy --- MODULE.bazel | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MODULE.bazel b/MODULE.bazel index ecfae11..840ebca 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -55,4 +55,4 @@ pip.parse( requirements_lock = "//testing/integration:requirements.txt", python_version = "3.12" ) -use_repo(pip, "pip") \ No newline at end of file +use_repo(pip, "pip") From a37952212cc5328e920e28d04b1f4a914127bfed Mon Sep 17 00:00:00 2001 From: Jacob Murphy Date: Thu, 23 Jan 2025 17:52:07 +0000 Subject: [PATCH 3/4] Add integration tests to DEVELOPER.md Signed-off-by: Jacob Murphy --- DEVELOPER.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/DEVELOPER.md b/DEVELOPER.md index 6176080..81eab0e 100644 --- a/DEVELOPER.md +++ b/DEVELOPER.md @@ -160,6 +160,22 @@ you can run the test with `--strategy=TestRunner=local`, e.g.: bazel test //testing:ft_search_test --strategy=TestRunner=local --run_under=/some/path/foobar.sh ``` +### Integration Testing + +To run any integration tests, you will need to have a local build of both valkey-server and valkey-cli. You can retrieve these by downloading the [Valkey source code and building it locally](https://github.com/valkey-io/valkey/). 
+ +Once you have downloaded and built Valkey, you can run the integration tests: + +``` +bazel test //testing/integration:vector_search_integration_test --test_arg=--valkey_server_path=/path/to/valkey-server --test_arg=--valkey_cli_path=/path/to/valkey-cli +``` + +Additionally, it is recommended to run the stability test suite, which requires a local build of [Memtier](https://github.com/RedisLabs/memtier_benchmark): + +``` +bazel test //testing/integration:stability_test --test_arg=--valkey_server_path=/path/to/valkey-server --test_arg=--valkey_cli_path=/path/to/valkey-cli --test_arg=--memtier_path=/path/to/memtier_benchmark --test_output=streamed +``` + ## Loading ValkeySearch is compatible with any version of Valkey and can also be loaded into Redis versions 7.0 and 7.2. To load the module, execute the following command: From d28cf7d64be1f2e4cba794a69dff9f904a9636eb Mon Sep 17 00:00:00 2001 From: Jacob Murphy Date: Wed, 29 Jan 2025 00:40:08 +0000 Subject: [PATCH 4/4] Apply review feedback Signed-off-by: Jacob Murphy --- testing/integration/stability_runner.py | 43 ++- testing/integration/stability_test.py | 15 +- testing/integration/utils.py | 292 +++++++++++------- .../vector_search_integration_test.py | 47 +-- 4 files changed, 239 insertions(+), 158 deletions(-) diff --git a/testing/integration/stability_runner.py b/testing/integration/stability_runner.py index 92c2415..b337f5c 100755 --- a/testing/integration/stability_runner.py +++ b/testing/integration/stability_runner.py @@ -90,7 +90,7 @@ def run(self) -> StabilityRunResult: ValueError: """ try: - r = valkey.ValkeyCluster( + client = valkey.ValkeyCluster( host="localhost", port=self.config.ports[0], startup_nodes=[ @@ -109,22 +109,31 @@ def run(self) -> StabilityRunResult: ) try: - utils.drop_index(r=r, index_name=self.config.index_name) + utils.drop_index(client=client, index_name=self.config.index_name) except valkey.exceptions.ValkeyError: pass + attributes = { + "tag": utils.TagDefinition(), + "numeric": utils.NumericDefinition(), + } if self.config.index_type == "HNSW": - utils.create_hnsw_index( - r=r, - index_name=self.config.index_name, - vector_dimensions=self.config.vector_dimensions, - ) + attributes.update({ + "embedding": utils.HNSWVectorDefinition( + vector_dimensions=self.config.vector_dimensions + ) + }) else: - utils.create_flat_index( - r=r, - index_name=self.config.index_name, - vector_dimensions=self.config.vector_dimensions, - ) + attributes.update({ + "embedding": utils.HNSWVectorDefinition( + vector_dimensions=self.config.vector_dimensions + ), + }) + utils.create_index( + client=client, + index_name=self.config.index_name, + attributes=attributes, + ) threads: list[utils.RandomIntervalTask] = [] index_state = utils.IndexState( @@ -133,7 +142,7 @@ def run(self) -> StabilityRunResult: if self.config.bgsave_interval_sec != 0: threads.append( utils.periodic_bgsave( - r, + client, self.config.bgsave_interval_sec, self.config.randomize_bg_job_intervals, ) @@ -142,12 +151,12 @@ def run(self) -> StabilityRunResult: if self.config.ftcreate_interval_sec != 0: threads.append( utils.periodic_ftcreate( - r, + client, self.config.ftcreate_interval_sec, self.config.randomize_bg_job_intervals, self.config.index_name, self.config.vector_dimensions, - self.config.index_type == "HNSW", + attributes, index_state, ) ) @@ -155,7 +164,7 @@ def run(self) -> StabilityRunResult: if self.config.ftdropindex_interval_sec != 0: threads.append( utils.periodic_ftdrop( - r, + client, self.config.ftdropindex_interval_sec, 
               self.config.randomize_bg_job_intervals,
               self.config.index_name,
@@ -166,7 +175,7 @@ def run(self) -> StabilityRunResult:
     if self.config.flushdb_interval_sec != 0:
       threads.append(
           utils.periodic_flushdb(
-              r,
+              client,
               self.config.flushdb_interval_sec,
               self.config.randomize_bg_job_intervals,
               index_state,
diff --git a/testing/integration/stability_test.py b/testing/integration/stability_test.py
index 55c468b..0635568 100644
--- a/testing/integration/stability_test.py
+++ b/testing/integration/stability_test.py
@@ -25,12 +25,11 @@ def setUp(self):
         format="%(asctime)s - %(levelname)s - %(message)s",
         level=logging.DEBUG,
     )
-    self.valkey_server = None
+    self.valkey_cluster_under_test = None
 
   def tearDown(self):
-    if self.valkey_server:
-      for _, process in self.valkey_server.items():
-        process.terminate()
+    if self.valkey_cluster_under_test:
+      self.valkey_cluster_under_test.terminate()
     super().tearDown()
 
   @parameterized.named_parameters(
@@ -299,7 +298,7 @@ def test_valkeyquery_stability(self, config):
     )
     config = config._replace(memtier_path=FLAGS.memtier_path)
 
-    self.valkey_server = utils.start_valkey_cluster(
+    self.valkey_cluster_under_test = utils.start_valkey_cluster(
         FLAGS.valkey_server_path,
         FLAGS.valkey_cli_path,
         config.ports,
@@ -344,9 +343,9 @@ def test_valkeyquery_stability(self, config):
     if results is None:
       self.fail("Failed to run stability test")
 
-    for port, process in self.valkey_server.items():
-      if process.poll():
-        self.fail("a process died during test, port: %d", port)
+    terminated = self.valkey_cluster_under_test.get_terminated_servers()
+    if terminated:
+      self.fail(f"Valkey servers died during test, ports: {terminated}")
 
     self.assertTrue(
         results.successful_run,
diff --git a/testing/integration/utils.py b/testing/integration/utils.py
index 61a501f..1383025 100644
--- a/testing/integration/utils.py
+++ b/testing/integration/utils.py
@@ -1,5 +1,6 @@
 """Utilities for ValkeySearch testing."""
 
+from abc import ABC, abstractmethod
 import fcntl
 import logging
 import os
@@ -14,6 +15,21 @@
 import valkey.exceptions
 
 
+class ValkeyServerUnderTest:
+  def __init__(self, process_handle: subprocess.Popen[Any], port: int):
+    self.process_handle = process_handle
+    self.port = port
+
+  def terminate(self):
+    self.process_handle.terminate()
+
+  def terminated(self) -> bool:
+    return self.process_handle.poll() is not None
+
+  def ping(self) -> Any:
+    return valkey.Valkey(port=self.port).ping()
+
+
 def start_valkey_process(
     valkey_server_path: str,
     port: int,
@@ -22,7 +38,7 @@
     args: dict[str, str],
     modules: dict[str, str],
     password: str | None = None,
-) -> subprocess.Popen[Any]:
+) -> ValkeyServerUnderTest:
   command = f"{valkey_server_path} --port {port} --dir {directory}"
   modules_args = [f'"--loadmodule {k} {v}"' for k, v in modules.items()]
   args_str = " ".join([f"--{k} {v}" for k, v in args.items()] + modules_args)
@@ -61,7 +77,29 @@
   )
   logging.info("Attempting to connect to Valkey: OK")
 
-  return process
+  return ValkeyServerUnderTest(process, port)
+
+
+class ValkeyClusterUnderTest:
+  def __init__(self, servers: List[ValkeyServerUnderTest]):
+    self.servers = servers
+
+  def terminate(self):
+    for server in self.servers:
+      server.terminate()
+
+  def get_terminated_servers(self) -> List[int]:
+    result = []
+    for server in self.servers:
+      if server.terminated():
+        result.append(server.port)
+    return result
+
+  def ping_all(self):
+    result = []
+    for server in self.servers:
+      result.append(server.ping())
+    return result
 
 
 def start_valkey_cluster(
@@ -91,7 +129,7 @@ def start_valkey_cluster(
-    Dictionary of port to valkey process.
+    The ValkeyClusterUnderTest wrapping the started server processes.
   """
   cluster_args = dict(args)
-  processes = {}
+  processes = []
 
   for port in ports:
     stdout_path = os.path.join(stdout_directory, f"{port}_stdout.txt")
@@ -103,7 +141,7 @@ def start_valkey_cluster(
     )
     cluster_args["cluster-node-timeout"] = "10000"
     os.mkdir(node_dir)
-    processes[port] = start_valkey_process(
+    processes.append(start_valkey_process(
         valkey_server_path,
         port,
         node_dir,
@@ -111,7 +149,7 @@ def start_valkey_cluster(
         cluster_args,
         modules,
         password,
-    )
+    ))
 
   cli_stdout_path = os.path.join(stdout_directory, "valkey_cli_stdout.txt")
   cli_stdout_file = open(cli_stdout_path, "w")
@@ -144,20 +182,106 @@ def start_valkey_cluster(
   # too early, even after checking with ping.
   time.sleep(10)
 
-  return processes
+  return ValkeyClusterUnderTest(processes)
+
+
+class AttributeDefinition(ABC):
+  @abstractmethod
+  def to_arguments(self) -> List[Any]:
+    pass
 
 
-def create_hnsw_index(
-    r: valkey.ValkeyCluster,
+class HNSWVectorDefinition(AttributeDefinition):
+  def __init__(
+      self,
+      vector_dimensions: int,
+      m=10,
+      vector_type="FLOAT32",
+      distance_metric="COSINE",
+      ef_construction=5,
+      ef_runtime=10,
+  ):
+    self.vector_dimensions = vector_dimensions
+    self.m = m
+    self.vector_type = vector_type
+    self.distance_metric = distance_metric
+    self.ef_construction = ef_construction
+    self.ef_runtime = ef_runtime
+
+  def to_arguments(self) -> List[Any]:
+    return [
+        "VECTOR",
+        "HNSW",
+        12,
+        "M",
+        self.m,
+        "TYPE",
+        self.vector_type,
+        "DIM",
+        self.vector_dimensions,
+        "DISTANCE_METRIC",
+        self.distance_metric,
+        "EF_CONSTRUCTION",
+        self.ef_construction,
+        "EF_RUNTIME",
+        self.ef_runtime,
+    ]
+
+
+class FlatVectorDefinition(AttributeDefinition):
+  def __init__(
+      self,
+      vector_dimensions: int,
+      vector_type="FLOAT32",
+      distance_metric="COSINE",
+  ):
+    self.vector_dimensions = vector_dimensions
+    self.vector_type = vector_type
+    self.distance_metric = distance_metric
+
+  def to_arguments(self) -> List[Any]:
+    return [
+        "VECTOR",
+        "FLAT",
+        "6",
+        "TYPE",
+        self.vector_type,
+        "DIM",
+        self.vector_dimensions,
+        "DISTANCE_METRIC",
+        self.distance_metric,
+    ]
+
+
+class TagDefinition(AttributeDefinition):
+  def __init__(self, separator=","):
+    self.separator = separator
+
+  def to_arguments(self) -> List[Any]:
+    return [
+        "TAG",
+        "SEPARATOR",
+        self.separator,
+    ]
+
+
+class NumericDefinition(AttributeDefinition):
+  def to_arguments(self) -> List[Any]:
+    return [
+        "NUMERIC",
+    ]
+
+
+def create_index(
+    client: valkey.ValkeyCluster,
     index_name: str,
-    vector_dimensions: int,
-    vector_attribute_name="embedding",
+    attributes: Dict[str, AttributeDefinition],
     target_nodes=valkey.ValkeyCluster.DEFAULT_NODE,
 ):
-  """Creates a new HNSW index.
+  """Creates a new index with the given attribute schema.
 
   Args:
-    r:
+    client:
     index_name:
-    vector_dimensions:
+    attributes:
   """
   args = [
       "FT.CREATE",
       index_name,
       "SCHEMA",
-      vector_attribute_name,
-      "VECTOR",
-      "HNSW",
-      "12",  # number of remaining arguments
-      "M",
-      100,
-      "TYPE",
-      "FLOAT32",
-      "DIM",
-      vector_dimensions,
-      "DISTANCE_METRIC",
-      "COSINE",
-      "EF_CONSTRUCTION",
-      5,
-      "EF_RUNTIME",
-      10,
-      "tag",
-      "TAG",
-      "SEPARATOR",
-      ",",
-      "numeric",
-      "NUMERIC",
-      # "INITIAL_CAP",
-      # 15000,
   ]
-  return r.execute_command(*args, target_nodes=target_nodes)
+  for name, definition in attributes.items():
+    args.append(name)
+    args.extend(definition.to_arguments())
+  return client.execute_command(*args, target_nodes=target_nodes)
 
 
-def create_flat_index(
-    r: valkey.ValkeyCluster, index_name: str, vector_dimensions: int
-):
-  """Creates a new FLAT index.
-
-  Args:
-    r:
-    index_name:
-    vector_dimensions:
-  """
-  args = [
-      "FT.CREATE",
-      index_name,
-      "SCHEMA",
-      "embedding",
-      "VECTOR",
-      "FLAT",
-      "6",  # number of remaining arguments
-      "TYPE",
-      "FLOAT32",
-      "DIM",
-      vector_dimensions,
-      "DISTANCE_METRIC",
-      "COSINE",
-      "tag",
-      "TAG",
-      "SEPARATOR",
-      ",",
-      "numeric",
-      "NUMERIC",
-  ]
-  r.execute_command(*args)
 
 
-def drop_index(r: valkey.ValkeyCluster, index_name: str):
+def drop_index(client: valkey.ValkeyCluster, index_name: str):
   args = [
       "FT.DROPINDEX",
       index_name,
   ]
-  r.execute_command(*args)
+  client.execute_command(*args)
 
 
-def fetch_ft_info(r: valkey.ValkeyCluster, index_name: str):
+def fetch_ft_info(client: valkey.ValkeyCluster, index_name: str):
   args = [
       "FT.INFO",
       index_name,
   ]
-  return r.execute_command(*args, target_nodes=r.ALL_NODES)
-
-
-def flushdb(r: valkey.ValkeyCluster):
-  args = ["FLUSHDB", "SYNC"]
-  r.execute_command(*args)
+  return client.execute_command(*args, target_nodes=client.ALL_NODES)
 
 
 def generate_deterministic_data(vector_dimensions: int, seed: int):
@@ -257,10 +321,10 @@
 
 
 def insert_vector(
-    r: valkey.ValkeyCluster, key: str, vector_dimensions: int, seed: int
+    client: valkey.ValkeyCluster, key: str, vector_dimensions: int, seed: int
 ):
   vector = generate_deterministic_data(vector_dimensions, seed)
-  return r.hset(
+  return client.hset(
       key,
       {
           "embedding": vector,
@@ -277,10 +341,10 @@ def insert_vectors_thread(
     port: int,
     seed: int,
 ):
-  r = valkey.Valkey(host=host, port=port)
+  client = valkey.Valkey(host=host, port=port)
   for i in range(1, num_vectors):
     insert_vector(
-        r=r,
+        client=client,
         key=(key_prefix + "_" + str(seed) + "_" + str(i)),
         vector_dimensions=vector_dimensions,
         seed=(i + seed * num_vectors),
@@ -322,17 +386,20 @@ def insert_vectors(
   return threads
 
 
-def delete_vector(r: valkey.ValkeyCluster, key: str):
-  return r.delete(key)
+def delete_vector(client: valkey.ValkeyCluster, key: str):
+  return client.delete(key)
 
 
 def knn_search(
-    r: valkey.ValkeyCluster, index_name: str, vector_dimensions: int, seed: int
+    client: valkey.ValkeyCluster,
+    index_name: str,
+    vector_dimensions: int,
+    seed: int,
 ):
   """KNN searches the index.
 
   Args:
-    r:
+    client:
     index_name:
     vector_dimensions:
     seed:
@@ -351,11 +418,11 @@
       "DIALECT",
       2,
   ]
-  return r.execute_command(*args, target_nodes=r.RANDOM)
+  return client.execute_command(*args, target_nodes=client.RANDOM)
 
 
-def writer_queue_size(r: valkey.ValkeyCluster, index_name: str):
-  out = fetch_ft_info(r, index_name)
+def writer_queue_size(client: valkey.ValkeyCluster, index_name: str):
+  out = fetch_ft_info(client, index_name)
   for index, item in enumerate(out):
     if "mutation_queue_size" in str(item):
       return int(str(out[index + 1])[2:-1])
@@ -364,19 +431,19 @@
 
 
 def wait_for_empty_writer_queue_size(
-    r: valkey.ValkeyCluster, index_name: str, timeout=0
+    client: valkey.ValkeyCluster, index_name: str, timeout=0
 ):
   """Wait for the writer queue size to hit zero.
 
   Args:
-    r:
+    client:
     index_name:
     timeout:
   """
   start = time.time()
   while True:
     try:
-      queue_size = writer_queue_size(r=r, index_name=index_name)
+      queue_size = writer_queue_size(client=client, index_name=index_name)
       if queue_size == 0:
         return
       logging.info(
@@ -459,11 +526,11 @@ def loop(self):
 
 
 def periodic_bgsave_task(
-    r: valkey.ValkeyCluster,
+    client: valkey.ValkeyCluster,
 ) -> bool:
   try:
     logging.info(" Invoking background save")
-    r.bgsave(target_nodes=r.ALL_NODES)
+    client.bgsave(target_nodes=client.ALL_NODES)
   except (
       valkey.exceptions.ConnectionError,
       valkey.exceptions.ResponseError,
@@ -474,12 +541,12 @@ def periodic_bgsave_task(
 
 
 def periodic_bgsave(
-    r: valkey.ValkeyCluster,
+    client: valkey.ValkeyCluster,
     interval_sec: int,
     randomize: bool,
 ) -> RandomIntervalTask:
   thread = RandomIntervalTask(
-      "BGSAVE", interval_sec, randomize, lambda: periodic_bgsave_task(r)
+      "BGSAVE", interval_sec, randomize, lambda: periodic_bgsave_task(client)
   )
   thread.run()
   return thread
@@ -493,14 +560,14 @@ def __init__(self, index_lock: threading.Lock, ft_created: bool):
 
 
 def periodic_ftdrop_task(
-    r: valkey.ValkeyCluster,
+    client: valkey.ValkeyCluster,
     index_name: str,
     index_state: IndexState,
 ) -> bool:
   with index_state.index_lock:
     logging.info(" Invoking index drop")
     try:
-      drop_index(r, index_name)
+      drop_index(client, index_name)
       index_state.ft_created = False
     except (
         valkey.exceptions.ConnectionError,
@@ -515,7 +582,7 @@ def periodic_ftdrop_task(
 
 
 def periodic_ftdrop(
-    r: valkey.ValkeyCluster,
+    client: valkey.ValkeyCluster,
     interval_sec: int,
     random_interval: bool,
     index_name: str,
@@ -525,26 +592,25 @@ def periodic_ftdrop(
       "FT.DROPINDEX",
       interval_sec,
       random_interval,
-      lambda: periodic_ftdrop_task(r, index_name, index_state),
+      lambda: periodic_ftdrop_task(client, index_name, index_state),
   )
   thread.run()
   return thread
 
 
 def periodic_ftcreate_task(
-    r: valkey.ValkeyCluster,
+    client: valkey.ValkeyCluster,
     index_name: str,
     dimensions: int,
-    hnsw: bool,
+    attributes: Dict[str, AttributeDefinition],
     index_state: IndexState,
 ) -> bool:
   with index_state.index_lock:
     try:
       logging.info(" Invoking index creation")
-      if hnsw:
-        create_hnsw_index(r, index_name, dimensions)
-      else:
-        create_flat_index(r, index_name, dimensions)
+      create_index(
+          client=client, index_name=index_name, attributes=attributes
+      )
      index_state.ft_created = True
     except (
         valkey.exceptions.ConnectionError,
@@ -559,12 +625,12 @@ def periodic_ftcreate_task(
 
 
 def periodic_ftcreate(
-    r: valkey.ValkeyCluster,
+    client: valkey.ValkeyCluster,
     interval_sec: int,
     random_interval: bool,
     index_name: str,
     dimensions: int,
-    hnsw: bool,
+    attributes: Dict[str, AttributeDefinition],
     index_state: IndexState,
 ) -> RandomIntervalTask:
   thread = RandomIntervalTask(
@@ -572,7 +638,7 @@ def periodic_ftcreate(
       interval_sec,
       random_interval,
       lambda: periodic_ftcreate_task(
-          r, index_name, dimensions, hnsw, index_state
+          client, index_name, dimensions, attributes, index_state
       ),
   )
   thread.run()
@@ -580,14 +646,14 @@ def periodic_ftcreate(
   return thread
 
 
 def periodic_flushdb_task(
-    r: valkey.ValkeyCluster,
+    client: valkey.ValkeyCluster,
     index_state: IndexState,
     use_coordinator: bool,
 ) -> bool:
   with index_state.index_lock:
     logging.info(" Invoking flush DB")
     try:
-      flushdb(r)
+      client.flushdb()
       if not use_coordinator:
         index_state.ft_created = False
     except (
@@ -602,7 +668,7 @@ def periodic_flushdb_task(
 
 
 def periodic_flushdb(
-    r: valkey.ValkeyCluster,
+    client: valkey.ValkeyCluster,
     interval_sec: int,
     random_interval: bool,
     index_state: IndexState,
@@ -612,7 +678,7 @@ def periodic_flushdb(
       "FLUSHDB",
       interval_sec,
       random_interval,
-      lambda: periodic_flushdb_task(r, index_state, use_coordinator),
+      lambda: periodic_flushdb_task(client, index_state, use_coordinator),
   )
   thread.run()
   return thread
diff --git a/testing/integration/vector_search_integration_test.py b/testing/integration/vector_search_integration_test.py
index 0e034d1..49e8d64 100644
--- a/testing/integration/vector_search_integration_test.py
+++ b/testing/integration/vector_search_integration_test.py
@@ -1,7 +1,6 @@
 import difflib
 import os
 import pprint
-import sys
 import time
 from typing import Any, List
 
@@ -114,7 +113,7 @@ def setUpClass(cls):
           "required"
       )
 
-    cls.valkey_server = utils.start_valkey_cluster(
+    cls.valkey_cluster_under_test = utils.start_valkey_cluster(
         FLAGS.valkey_server_path,
         FLAGS.valkey_cli_path,
         [6379, 6380, 6381],
@@ -153,37 +152,42 @@ def tearDown(self):
     self.valkey_conn.execute_command(
         "FLUSHDB", target_nodes=self.valkey_conn.ALL_NODES
     )
-    for port, process in self.valkey_server.items():
-      if process.poll():
-        self.fail("a process died during test, port: %d", port)
-      try:
-        valkey.Valkey(port=port).ping()
-      except Exception as e:  # pylint: disable=broad-except
-        self.fail(f"Failed to ping valkey on port {port}: {e}")
-
+    terminated = self.valkey_cluster_under_test.get_terminated_servers()
+    if terminated:
+      self.fail(f"Valkey servers terminated during test, ports: {terminated}")
+    try:
+      self.valkey_cluster_under_test.ping_all()
+    except Exception as e:  # pylint: disable=broad-except
+      self.fail(f"Failed to ping all servers in cluster: {e}")
     super().tearDown()
 
   def test_create_and_drop_index(self):
     self.assertEqual(
         b"OK",
-        utils.create_hnsw_index(
+        utils.create_index(
             self.valkey_conn,
             "test_index",
-            100,
-            "embedding",
+            attributes={
+                "embedding": utils.HNSWVectorDefinition(
+                    vector_dimensions=100
+                )
+            },
            target_nodes=valkey.ValkeyCluster.RANDOM,
         ),
     )
     time.sleep(1)
     with self.assertRaises(valkey.exceptions.ResponseError) as e:
-      utils.create_hnsw_index(
+      utils.create_index(
           self.valkey_conn,
           "test_index",
-          100,
-          "embedding",
+          attributes={
+              "embedding": utils.HNSWVectorDefinition(
+                  vector_dimensions=100
+              )
+          },
          target_nodes=valkey.ValkeyCluster.RANDOM,
       )
     self.assertEqual(
         "Index test_index already exists.",
         e.exception.args[0],
@@ -433,11 +437,14 @@ def test_vector_search(self, config):
     dimensions = 100
     self.assertEqual(
         b"OK",
-        utils.create_hnsw_index(
+        utils.create_index(
             self.valkey_conn,
             config["index_name"],
-            dimensions,
-            config["vector_attribute_name"],
+            attributes={
+                config["vector_attribute_name"]: utils.HNSWVectorDefinition(
+                    vector_dimensions=dimensions
+                )
+            },
         ),
     )
     time.sleep(1)