diff --git a/examples/tair_vector.py b/examples/tair_vector.py index 15efd9e..809ed01 100644 --- a/examples/tair_vector.py +++ b/examples/tair_vector.py @@ -1,7 +1,9 @@ #!/usr/bin/env python from conf_examples import get_tair + from tair import ResponseError + # create an index # @param index_name the name of index # @param dims the dimension of vector @@ -13,12 +15,13 @@ def create_index(index_name: str, dims: str): "M": 32, "ef_construct": 200, } - #index_params the params of index - return tair.tvs_create_index(index_name, dims,**index_params) + # index_params the params of index + return tair.tvs_create_index(index_name, dims, **index_params) except ResponseError as e: print(e) return None + # delete an index # @param index_name the name of index # @return success: True, fail: False. @@ -32,5 +35,5 @@ def delete_index(index_name: str): if __name__ == "__main__": - create_index("test",4) - delete_index("test") \ No newline at end of file + create_index("test", 4) + delete_index("test") diff --git a/tair/__init__.py b/tair/__init__.py index 0b1a145..550bc1d 100644 --- a/tair/__init__.py +++ b/tair/__init__.py @@ -22,8 +22,8 @@ from tair.tairsearch import ScandocidResult from tair.tairstring import ExcasResult, ExgetResult from tair.tairts import Aggregation, TairTsSkeyItem +from tair.tairvector import TairVectorIndex, TairVectorScanResult from tair.tairzset import TairZsetItem -from tair.tairvector import TairVectorScanResult, TairVectorIndex __all__ = [ "Aggregation", diff --git a/tair/commands.py b/tair/commands.py index 071f00b..518a01c 100644 --- a/tair/commands.py +++ b/tair/commands.py @@ -26,15 +26,15 @@ parse_exset, ) from tair.tairts import TairTsCommands -from tair.tairzset import TairZsetCommands, parse_tair_zset_items from tair.tairvector import ( TairVectorCommands, parse_tvs_get_index_result, parse_tvs_get_result, - parse_tvs_search_result, - parse_tvs_msearch_result, parse_tvs_hmget_result, + parse_tvs_msearch_result, + parse_tvs_search_result, ) +from tair.tairzset import TairZsetCommands, parse_tair_zset_items class TairCommands( @@ -143,16 +143,18 @@ def bool_ok(resp) -> bool: float(resp[0].decode()), float(resp[1].decode()) ), # TairVector - "TVS.CREATEINDEX":bool_ok, + "TVS.CREATEINDEX": bool_ok, "TVS.GETINDEX": parse_tvs_get_index_result, "TVS.DELINDEX": int_or_none, "TVS.HSET": int_or_none, "TVS.DEL": int_or_none, "TVS.HDEL": int_or_none, "TVS.HGETALL": parse_tvs_get_result, - "TVS.HMGET":parse_tvs_hmget_result, + "TVS.HMGET": parse_tvs_hmget_result, "TVS.KNNSEARCH": parse_tvs_search_result, "TVS.MKNNSEARCH": parse_tvs_msearch_result, + "TVS.MINDEXKNNSEARCH": parse_tvs_search_result, + "TVS.MINDEXMKNNSEARCH": parse_tvs_msearch_result, } diff --git a/tair/tairvector.py b/tair/tairvector.py index fc9d45e..2a23cb3 100644 --- a/tair/tairvector.py +++ b/tair/tairvector.py @@ -1,30 +1,35 @@ -from typing import Sequence, Tuple, Union,Iterable -from tair.typing import ResponseT -from typing import Dict, List, Tuple, Union +from functools import partial, reduce +from typing import Dict, Iterable, List, Sequence, Tuple, Union + from redis.client import pairs_to_dict from redis.utils import str_if_bytes -from typing import Sequence, Union -from functools import partial, reduce + +from tair.typing import ResponseT VectorType = Sequence[Union[int, float]] + class DistanceMetric: Euclidean = "L2" # an alias to L2 L2 = "L2" InnerProduct = "IP" Jaccard = "JACCARD" + class IndexType: HNSW = "HNSW" FLAT = "FLAT" + class Constants: VECTOR_KEY = "VECTOR" + class DataType: Float32 = "FLOAT32" Binary = "BINARY" + class TextVectorEncoder: SEP = bytes(",", "ascii") BITS = ("0", "1") @@ -50,6 +55,7 @@ def decode(buf: bytes) -> Tuple[float]: return tuple(int(x) for x in components) return tuple(float(x) for x in components) + class TairVectorScanResult: """ wrapper for the results of scan commands @@ -152,7 +158,7 @@ def tvs_mknnsearch( self, k: int, vectors: Sequence[VectorType], filter_str: str = None, **kwargs ): """batch approximate nearest neighbors search for a list of vectors""" - return self.client.tvs_knnsearch( + return self.client.tvs_mknnsearch( self.name, k, vectors, self.is_binary, filter_str, **kwargs ) @@ -164,7 +170,6 @@ def __repr__(self): class TairVectorCommands: - encode_vector = TextVectorEncoder.encode decode_vector = TextVectorEncoder.decode @@ -304,6 +309,8 @@ def tvs_scan(self, index: str, pattern: str = None, batch: int = 10): SEARCH_CMD = "TVS.KNNSEARCH" MSEARCH_CMD = "TVS.MKNNSEARCH" + MINDEXKNNSEARCH_CMD = "TVS.MINDEXKNNSEARCH" + MINDEXMKNNSEARCH_CMD = "TVS.MINDEXMKNNSEARCH" def tvs_knnsearch( self, @@ -361,11 +368,73 @@ def tvs_mknnsearch( *params ) + def tvs_mindexknnsearch( + self, + index: Sequence[str], + k: int, + vector: Union[VectorType, str], + is_binary: bool = False, + filter_str: str = None, + **kwargs + ): + """ + search for the top @k approximate nearest neighbors of @vector in indexs + """ + params = reduce(lambda x, y: x + y, kwargs.items(), ()) + if not isinstance(vector, str): + vector = TairVectorCommands.encode_vector(vector, is_binary) + if filter_str is None: + return self.execute_command( + self.MINDEXKNNSEARCH_CMD, len(index), *index, k, vector, *params + ) + return self.execute_command( + self.MINDEXKNNSEARCH_CMD, len(index), *index, k, vector, filter_str, *params + ) + + def tvs_mindexmknnsearch( + self, + index: Sequence[str], + k: int, + vectors: Sequence[VectorType], + is_binary: bool = False, + filter_str: str = None, + **kwargs + ): + """ + batch approximate nearest neighbors search for a list of vectors + """ + params = reduce(lambda x, y: x + y, kwargs.items(), ()) + encoded_vectors = [ + TairVectorCommands.encode_vector(x, is_binary) for x in vectors + ] + if filter_str is None: + return self.execute_command( + self.MINDEXMKNNSEARCH_CMD, + len(index), + *index, + k, + len(encoded_vectors), + *encoded_vectors, + *params + ) + return self.execute_command( + self.MINDEXMKNNSEARCH_CMD, + len(index), + *index, + k, + len(encoded_vectors), + *encoded_vectors, + filter_str, + *params + ) + + def parse_tvs_get_index_result(resp) -> Union[Dict, None]: if len(resp) == 0: return None return pairs_to_dict(resp, decode_keys=True, decode_string_values=True) + def parse_tvs_get_result(resp) -> Dict: result = pairs_to_dict(resp, decode_keys=True, decode_string_values=False) @@ -376,13 +445,16 @@ def parse_tvs_get_result(resp) -> Dict: values = map(str_if_bytes, result.values()) return dict(zip(result.keys(), values)) + def parse_tvs_hmget_result(resp) -> tuple: if len(resp) == 0: return None - return ([resp[i].decode("ascii") if resp[i] else None for i in range(0, len(resp))]) + return [resp[i].decode("ascii") if resp[i] else None for i in range(0, len(resp))] + def parse_tvs_search_result(resp) -> List[Tuple]: return [(resp[i], float(resp[i + 1])) for i in range(0, len(resp), 2)] + def parse_tvs_msearch_result(resp) -> List[List[Tuple]]: - return [parse_tvs_search_result(r) for r in resp] \ No newline at end of file + return [parse_tvs_search_result(r) for r in resp] diff --git a/tests/test_from_url.py b/tests/test_from_url.py index 9950352..3189d61 100644 --- a/tests/test_from_url.py +++ b/tests/test_from_url.py @@ -66,18 +66,18 @@ async def test_from_url_async(): await t.close() -# @pytest.mark.asyncio -# async def test_from_url_async_cluster(): -# url = f"{TAIR_CLUSTER_SCHEME}://{TAIR_CLUSTER_HOST}:{TAIR_CLUSTER_PORT}" -# tc = AsyncTairCluster.from_url( -# url, username=TAIR_CLUSTER_USERNAME, password=TAIR_CLUSTER_PASSWORD -# ) -# key = "key_" + str(uuid.uuid4()) -# value = "value_" + str(uuid.uuid4()) -# -# assert await tc.exset(key, value) -# result: ExgetResult = await tc.exget(key) -# assert result.value == value.encode() -# assert result.version == 1 -# -# await tc.close() +@pytest.mark.asyncio +async def test_from_url_async_cluster(): + url = f"{TAIR_CLUSTER_SCHEME}://{TAIR_CLUSTER_HOST}:{TAIR_CLUSTER_PORT}" + tc = AsyncTairCluster.from_url( + url, username=TAIR_CLUSTER_USERNAME, password=TAIR_CLUSTER_PASSWORD + ) + key = "key_" + str(uuid.uuid4()) + value = "value_" + str(uuid.uuid4()) + + assert await tc.exset(key, value) + result: ExgetResult = await tc.exget(key) + assert result.value == value.encode() + assert result.version == 1 + + await tc.close() diff --git a/tests/test_tairvector.py b/tests/test_tairvector.py index 78aff23..fdef61a 100644 --- a/tests/test_tairvector.py +++ b/tests/test_tairvector.py @@ -1,15 +1,23 @@ # /user/bin/env python3 -import sys import os +import string +import sys +import unittest import uuid -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from random import choice, randint, random -from random import random, randint, choice -import unittest import redis -import string -from tair.tairvector import DataType, DistanceMetric, Constants, TairVectorIndex +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from tair.tairvector import ( + Constants, + DataType, + DistanceMetric, + TairVectorCommands, + TairVectorIndex, +) + from .conftest import get_tair_client client = get_tair_client() @@ -17,6 +25,7 @@ dim = 16 num_vectors = 100 test_vectors = [[random() for _ in range(dim)] for _ in range(num_vectors)] +test2_vectors = [[random() for _ in range(dim)] for _ in range(num_vectors)] num_attrs = 3 attr_keys = ["key-%d" % i for i in range(num_attrs)] attr_values = [ @@ -24,7 +33,7 @@ for _ in range(num_vectors * num_attrs) ] test_attributes = [ - dict(zip(attr_keys, attr_values[i: i + 3])) + dict(zip(attr_keys, attr_values[i : i + 3])) for i in range(0, num_vectors * num_attrs, num_attrs) ] @@ -177,7 +186,9 @@ def test_hmget(self): value2 = "value_" + str(uuid.uuid4()) ret = client.tvs_hset("test", key, vector=vector, field1=value1, field2=value2) self.assertTrue(ret) - obj = client.tvs_hmget("test", key, Constants.VECTOR_KEY, "field1", "field2", "field3") + obj = client.tvs_hmget( + "test", key, Constants.VECTOR_KEY, "field1", "field2", "field3" + ) self.assertEqual(len(obj[0].split(",")), len(vector)) self.assertEqual(obj[1], str(value1)) self.assertEqual(obj[2], str(value2)) @@ -230,17 +241,22 @@ def test_0_init(self): # delete test index try: client.tvs_del_index("test") + client.tvs_del_index("test2") except: pass ret = client.tvs_create_index("test", dim, **self.index_params) + ret = client.tvs_create_index("test2", dim, **self.index_params) if not ret: - raise RuntimeError("create test index failed") + raise RuntimeError("create test/test2 index failed") def test_1_insert_vectors(self): for i, v in enumerate(test_vectors): ret = client.tvs_hset("test", str(i).zfill(6), vector=v) self.assertTrue(ret) + for i, v in enumerate(test2_vectors): + ret = client.tvs_hset("test2", str(i).zfill(6), vector=v) + self.assertTrue(ret) def test_2_knn_search(self): for q in queries: @@ -280,8 +296,30 @@ def test_5_search_with_filters(self): def test_6_msearch_with_filters(self): pass + def test_7_mindexknnsearch(self): + indexs = ["test", "test2"] + for q in queries: + result = client.tvs_mindexknnsearch(indexs, self.top_k, vector=q) + self.assertEqual(self.top_k, len(result)) + d = 0.0 + for k, v in result: + self.assertGreaterEqual(v, d) + d = v + + def test_8_mindexmknnsearch(self): + indexs = ["test", "test2"] + batch = queries[:2] + result = client.tvs_mindexmknnsearch(indexs, self.top_k, batch) + self.assertEqual(len(result), len(batch)) + for r in result: + d = 0.0 + for _, v in r: + self.assertGreaterEqual(v, d) + d = v + def test_9_delete_index(self): client.tvs_del_index("test") + client.tvs_del_index("test2") class IndexApiTest(unittest.TestCase):