From 282c0d4d7fd582ac12566be8f40c61c867078cde Mon Sep 17 00:00:00 2001 From: "caoduanxin.cdx" Date: Mon, 28 Nov 2022 10:22:37 +0800 Subject: [PATCH] feat: add new data structure TairVector --- README.md | 1 + README.zh_CN.md | 1 + examples/tair_vector.py | 36 ++++ tair/__init__.py | 3 + tair/commands.py | 21 ++ tair/tairvector.py | 388 +++++++++++++++++++++++++++++++++ tests/test_tairvector.py | 452 +++++++++++++++++++++++++++++++++++++++ 7 files changed, 902 insertions(+) create mode 100644 examples/tair_vector.py create mode 100644 tair/tairvector.py create mode 100644 tests/test_tairvector.py diff --git a/README.md b/README.md index 5e1f048..03d7b37 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ tair-py is a Python client of [Tair](https://www.alibabacloud.com/help/en/apsara - [TairGis](https://www.alibabacloud.com/help/en/apsaradb-for-redis/latest/tairgis-commands), allowing you to query points, linestrings, and polygons. (Coming soon) - [TairTs](https://www.alibabacloud.com/help/en/apsaradb-for-redis/latest/tairts-commands), is a time series data structure that is developed on top of Redis modules. (Coming soon) - [TairCpc](https://www.alibabacloud.com/help/en/apsaradb-for-redis/latest/taircpc-commands), is a data structure developed based on the compressed probability counting (CPC) sketch. (Coming soon) +- [TairVector](https://www.alibabacloud.com/help/en/apsaradb-for-redis/latest/tairvector), is a self-developed data structure that provides high-performance real-time storage and retrieval of vectors. 
def create_index(index_name: str, dims: int) -> bool:
    """Create a vector index and report whether it succeeded.

    Args:
        index_name: name of the index to create.
        dims: dimension of the vectors the index will hold.

    Returns:
        The value returned by ``tvs_create_index`` on success, ``False`` when
        the server answers with a ``ResponseError``.
    """
    # Index build parameters, forwarded to the server as keyword arguments.
    index_params = {
        "M": 32,
        "ef_construct": 200,
    }
    try:
        tair = get_tair()
        return tair.tvs_create_index(index_name, dims, **index_params)
    except ResponseError as e:
        print(e)
        # Failure is reported as False, as documented for this example and
        # as delete_index() does, rather than as None.
        return False
class TextVectorEncoder:
    """Convert vectors to and from the textual ``[v1,v2,...]`` form."""

    SEP = bytes(",", "ascii")
    BITS = ("0", "1")

    @staticmethod
    def encode(vector: Sequence[Union[float, int]], is_binary=False) -> bytes:
        """Serialise *vector* as ASCII bytes.

        Args:
            vector: the components to encode.
            is_binary: when true every component is used as an index into
                ``BITS`` (so it must be 0 or 1); otherwise each component is
                formatted with ``%f``.

        Returns:
            The components joined by commas and wrapped in square brackets.
        """
        if is_binary:
            parts = [TextVectorEncoder.BITS[x] for x in vector]
        else:
            parts = ["%f" % x for x in vector]
        # The formatted text is plain ASCII.
        return bytes("[" + ",".join(parts) + "]", encoding="ascii")

    @staticmethod
    def decode(buf: Union[bytes, str]) -> Tuple[Union[int, float], ...]:
        """Parse a ``[v1,v2,...]`` buffer back into a tuple of numbers.

        Returns a tuple of ``int`` when every component consists solely of
        digits, otherwise a tuple of ``float``. ``[]`` yields ``()``.

        Raises:
            ValueError: if *buf* is not wrapped in square brackets or a
                component is not a number.
        """
        if isinstance(buf, str):
            # Accept text as well as bytes so already-decoded values parse.
            buf = buf.encode("ascii")
        # Slices (not indexes) so an empty buffer raises ValueError instead
        # of IndexError.
        if buf[:1] != b"[" or buf[-1:] != b"]":
            raise ValueError("invalid text vector value")
        body = buf[1:-1]
        if not body:
            # "[]" is a valid, empty vector.
            return ()
        components = body.split(TextVectorEncoder.SEP)
        if all(x.isdigit() for x in components):
            return tuple(int(x) for x in components)
        return tuple(float(x) for x in components)
class TairVectorScanResult:
    """Iterable wrapper around a cursor-based scan command.

    ``get_batch_func`` is called with the current cursor and must return a
    pair ``(next_cursor, items)``. A next cursor equal to ``0`` (as bytes,
    str or int) marks the end of the scan.
    """

    def __init__(self, client, get_batch_func):
        self.client = client
        self.get_batch = get_batch_func
        # Start in the "nothing fetched yet" state so that next() also works
        # when the caller did not go through iter() first.
        self.cursor = "0"
        self.batch = []
        self.idx = 0

    def __iter__(self):
        # Every iter() call restarts the scan from the beginning.
        self.cursor = "0"
        self.batch = []
        self.idx = 0
        return self

    def __next__(self):
        # Keep fetching until a batch has an item or the scan is finished:
        # an empty batch with a non-zero cursor is not the end of the data.
        while self.idx >= len(self.batch):
            if self.cursor is None:
                # iteration finished
                raise StopIteration

            res = self.get_batch(self.cursor)
            raw_cursor = res[0]
            if isinstance(raw_cursor, bytes):
                cursor_text = raw_cursor.decode("utf-8")
            else:
                cursor_text = str(raw_cursor)
            # Cursor "0" means there is no more data to scan.
            self.cursor = None if cursor_text == "0" else raw_cursor
            self.batch = res[1]
            self.idx = 0

        ret = self.batch[self.idx]
        self.idx += 1
        # Items may already be str; only bytes need decoding.
        return ret.decode("utf-8") if isinstance(ret, bytes) else ret

    def iter(self):
        """
        create an iterator from the result
        """
        return iter(self)
def tvs_mknnsearch(
    self, k: int, vectors: "Sequence[VectorType]", filter_str: str = None, **kwargs
):
    """Run a batch approximate nearest neighbors search on this index.

    Args:
        k: number of neighbors requested per query vector.
        vectors: the query vectors.
        filter_str: optional filter expression forwarded to the client.
        **kwargs: extra search parameters forwarded to the client.

    Returns:
        Whatever ``client.tvs_mknnsearch`` returns for this index.
    """
    # Dispatch to the batch command; the single-vector tvs_knnsearch would
    # treat the list of vectors as one vector.
    return self.client.tvs_mknnsearch(
        self.name, k, vectors, self.is_binary, filter_str, **kwargs
    )
def tvs_hset(
    self,
    index: str,
    key: str,
    vector: "Union[VectorType, str, bytes]" = None,
    is_binary=False,
    **kwargs
):
    """
    add/update a data entry to index
    @index: index name
    @key: key for the data entry
    @vector: optional, vector value of the data entry; a str or bytes value
             is sent as-is, any other sequence is encoded first
    @is_binary: whether @vector is a binary vector
    @kwargs: optional, attribute pairs for the data entry
    """
    # Flatten {"a": 1, "b": 2} into ["a", 1, "b", 2] in a single pass.
    attributes = [item for pair in kwargs.items() for item in pair]
    if vector is None:
        return self.execute_command(self.HSET_CMD, index, key, *attributes)
    # Text and bytes are taken as already encoded; encoding bytes again would
    # turn every byte into a separate vector component.
    if not isinstance(vector, (str, bytes)):
        vector = TairVectorCommands.encode_vector(vector, is_binary)
    return self.execute_command(
        self.HSET_CMD, index, key, Constants.VECTOR_KEY, vector, *attributes
    )
def tvs_knnsearch(
    self,
    index: str,
    k: int,
    vector: "Union[VectorType, str, bytes]",
    is_binary: bool = False,
    filter_str: str = None,
    **kwargs
):
    """
    search for the top @k approximate nearest neighbors of @vector in an index
    @index: index name
    @k: number of neighbors to return
    @vector: query vector; a str or bytes value is sent as-is, any other
             sequence is encoded first
    @is_binary: whether @vector is a binary vector
    @filter_str: optional filter expression, sent right after the vector
    @kwargs: optional search parameters, sent as name/value pairs
    """
    # Flatten {"a": 1, "b": 2} into ["a", 1, "b", 2] in a single pass.
    params = [item for pair in kwargs.items() for item in pair]
    # Text and bytes are taken as already encoded; encoding bytes again would
    # turn every byte into a separate vector component.
    if not isinstance(vector, (str, bytes)):
        vector = TairVectorCommands.encode_vector(vector, is_binary)
    args = [index, k, vector]
    if filter_str is not None:
        args.append(filter_str)
    return self.execute_command(self.SEARCH_CMD, *args, *params)
def parse_tvs_hmget_result(resp) -> Union[List[Union[str, None]], None]:
    """Turn a TVS.HMGET reply into a list of text values.

    Args:
        resp: the reply items, one per requested field.

    Returns:
        ``None`` for an empty reply. Otherwise a list in reply order where
        each falsy item (``None`` or an empty value) becomes ``None`` and
        every other item is returned as ``str``.
    """
    if len(resp) == 0:
        return None
    # UTF-8 rather than ASCII so attribute values containing non-ASCII text
    # can be read back; values that are already str are passed through.
    return [
        (item if isinstance(item, str) else item.decode("utf-8")) if item else None
        for item in resp
    ]
def test_0_delete_all_indices(self):
    """Delete every index on the server and check that none is left."""
    # get all indices
    indices = list(client.tvs_scan_index().iter())

    print("deleting indices:", indices)
    # deleting all indices
    for index in indices:
        try:
            ret = client.tvs_del_index(index)
            self.assertEqual(ret, 1)
        except Exception as exc:
            # Best-effort cleanup: report and continue. The scan below is
            # what fails the test if an index survived. TestCase has no
            # `logger` attribute, so the failure is printed like the
            # message above.
            print("delete index [%s] failed: %s" % (index, exc))

    # scan indices again
    self.assertEqual(list(client.tvs_scan_index().iter()), [])
def floatEqual(v1: float, v2: float, epsilon=1e-6) -> bool:
    """Tell whether *v1* and *v2* differ by no more than *epsilon*."""
    gap = v1 - v2
    return -epsilon <= gap <= epsilon
def test_hmget(self):
    """Store a vector with two attributes and read them back with HMGET."""
    self.assertTrue(client.tvs_create_index("test", dim, **self.index_params))
    # Test methods run in alphabetical order, so this one runs after
    # test_9_delete; it has to remove the index it creates itself, otherwise
    # the index stays on the server and a second run fails on creation.
    self.addCleanup(client.tvs_del_index, "test")

    vector = [randint(1, 100) for _ in range(dim)]
    key = "key_" + str(uuid.uuid4())
    value1 = "value_" + str(uuid.uuid4())
    value2 = "value_" + str(uuid.uuid4())
    ret = client.tvs_hset("test", key, vector=vector, field1=value1, field2=value2)
    self.assertTrue(ret)

    # "field3" was never written; it is requested alongside the real fields.
    obj = client.tvs_hmget(
        "test", key, Constants.VECTOR_KEY, "field1", "field2", "field3"
    )
    # The vector comes back as text, one comma-separated entry per dimension.
    self.assertEqual(len(obj[0].split(",")), len(vector))
    self.assertEqual(obj[1], str(value1))
    self.assertEqual(obj[2], str(value2))
class SearchCommandsTest(unittest.TestCase):
    """Exercise the KNN search commands against the shared "test" index.

    The methods rely on running in alphabetical order: the index is created
    first, filled, searched and finally deleted.
    """

    def __init__(self, methodName="runTest"):
        super().__init__(methodName=methodName)
        self.top_k = 10
        self.index_params = {
            "M": 32,
            "ef_construct": 200,
        }

    def _assert_distances_non_decreasing(self, result):
        """Check that the distances of (key, distance) pairs never decrease."""
        previous = 0.0
        for _, distance in result:
            self.assertGreaterEqual(distance, previous)
            previous = distance

    def test_0_init(self):
        # delete test index
        try:
            client.tvs_del_index("test")
        except redis.exceptions.ResponseError:
            # Only a server-side rejection of the delete is ignored; any
            # other failure (for example no connection) should surface.
            pass

        ret = client.tvs_create_index("test", dim, **self.index_params)
        if not ret:
            raise RuntimeError("create test index failed")

    def test_1_insert_vectors(self):
        for i, v in enumerate(test_vectors):
            ret = client.tvs_hset("test", str(i).zfill(6), vector=v)
            self.assertTrue(ret)

    def test_2_knn_search(self):
        for q in queries:
            result = client.tvs_knnsearch("test", self.top_k, vector=q)
            self.assertEqual(self.top_k, len(result))
            self._assert_distances_non_decreasing(result)

    def test_3_knn_search_with_params(self):
        for ef in range(self.top_k, 100, 10):
            for q in queries:
                result = client.tvs_knnsearch(
                    "test", self.top_k, vector=q, ef_search=ef
                )
                self.assertEqual(self.top_k, len(result))
                self._assert_distances_non_decreasing(result)

    def test_4_knn_msearch(self):
        batch = queries[:2]
        result = client.tvs_mknnsearch("test", self.top_k, batch)
        self.assertEqual(len(result), len(batch))
        for r in result:
            self._assert_distances_non_decreasing(r)

    # todo
    def test_5_search_with_filters(self):
        pass

    def test_6_msearch_with_filters(self):
        pass

    def test_9_delete_index(self):
        # The index was created in test_0_init, so exactly one is removed.
        self.assertEqual(client.tvs_del_index("test"), 1)
def jaccard(x, y):
    """Return the Jaccard distance between two 0/1 vectors.

    The distance is ``1 - |x AND y| / |x OR y|``, counting the positions
    that hold ``1``. *y* is indexed with the positions of *x*.

    Two vectors without any ``1`` have an empty union; they are reported as
    distance ``0.0`` instead of dividing by zero.
    NOTE(review): confirm this matches what the server returns for all-zero
    vectors.
    """
    intersect = 0
    union = 0
    for i, a in enumerate(x):
        b = y[i]
        if a == 1 or b == 1:
            union += 1
        if a == 1 and b == 1:
            intersect += 1
    if union == 0:
        return 0.0
    return 1 - intersect / union
def test_1_hset_hget(self):
    """Write every binary test vector, then read each one back unchanged."""
    for position, bits in enumerate(test_bin_vectors):
        written = client.tvs_hset("test_bin", str(position), bits, True)
        self.assertEqual(written, 1)
    for position, bits in enumerate(test_bin_vectors):
        entry = client.tvs_hgetall("test_bin", str(position))
        self.assertEqual(bits, list(entry[Constants.VECTOR_KEY]))
if __name__ == "__main__":
    # unittest.main() finishes by raising SystemExit, so a statement placed
    # after it never runs. Closing the shared client in a finally block makes
    # sure it is released while the exit status is still propagated.
    try:
        unittest.main()
    finally:
        client.close()