diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index a0b4bcccab..ce6226f98c 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -21,7 +21,7 @@ "info": { "title": "[DRAFT] Llama Stack Specification", "version": "0.0.1", - "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-14 12:51:12.176325" + "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-14 17:04:24.301559" }, "servers": [ { @@ -429,39 +429,6 @@ } } }, - "/models/delete": { - "post": { - "responses": { - "200": { - "description": "OK" - } - }, - "tags": [ - "Models" - ], - "parameters": [ - { - "name": "X-LlamaStack-ProviderData", - "in": "header", - "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", - "required": false, - "schema": { - "type": "string" - } - } - ], - "requestBody": { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/DeleteModelRequest" - } - } - }, - "required": true - } - } - }, "/inference/embeddings": { "post": { "responses": { @@ -2259,18 +2226,44 @@ } } }, - "/models/update": { + "/memory_banks/unregister": { "post": { "responses": { "200": { - "description": "OK", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/Model" - } + "description": "OK" + } + }, + "tags": [ + "MemoryBanks" + ], + "parameters": [ + { + "name": "X-LlamaStack-ProviderData", + "in": "header", + "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", + "required": false, + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UnregisterMemoryBankRequest" } } + }, + "required": true + } + } + }, + "/models/unregister": { + "post": { + "responses": { + "200": { + "description": "OK" } }, "tags": [ @@ -2291,7 +2284,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/UpdateModelRequest" + "$ref": "#/components/schemas/UnregisterModelRequest" } } }, @@ -4622,18 +4615,6 @@ "session_id" ] }, - "DeleteModelRequest": { - "type": "object", - "properties": { - "model_id": { - "type": "string" - } - }, - "additionalProperties": false, - "required": [ - "model_id" - ] - }, "EmbeddingsRequest": { "type": "object", "properties": { @@ -7912,42 +7893,23 @@ ], "title": "Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold." }, - "UpdateModelRequest": { + "UnregisterMemoryBankRequest": { "type": "object", "properties": { - "model_id": { - "type": "string" - }, - "provider_model_id": { + "memory_bank_id": { "type": "string" - }, - "provider_id": { + } + }, + "additionalProperties": false, + "required": [ + "memory_bank_id" + ] + }, + "UnregisterModelRequest": { + "type": "object", + "properties": { + "model_id": { "type": "string" - }, - "metadata": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } } }, "additionalProperties": false, @@ -8132,10 +8094,6 @@ "name": "DeleteAgentsSessionRequest", "description": "" }, - { - "name": "DeleteModelRequest", - "description": "" - }, { "name": "DoraFinetuningConfig", "description": "" @@ -8563,12 +8521,16 @@ "description": "" }, { - "name": "UnstructuredLogEvent", - "description": "" + "name": "UnregisterMemoryBankRequest", + "description": "" }, { - "name": "UpdateModelRequest", - "description": "" + "name": "UnregisterModelRequest", + "description": "" + }, + { + "name": "UnstructuredLogEvent", + "description": "" }, { "name": "UserMessage", @@ -8657,7 +8619,6 @@ "Dataset", "DeleteAgentsRequest", "DeleteAgentsSessionRequest", - "DeleteModelRequest", "DoraFinetuningConfig", "EmbeddingsRequest", "EmbeddingsResponse", @@ -8754,8 +8715,9 @@ "TrainingConfig", "Turn", "URL", + "UnregisterMemoryBankRequest", + "UnregisterModelRequest", "UnstructuredLogEvent", - "UpdateModelRequest", "UserMessage", "VectorMemoryBank", "VectorMemoryBankParams", diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index 2ca26f759b..a0b3d6c5eb 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -867,14 +867,6 @@ components: - agent_id - session_id type: object - DeleteModelRequest: - additionalProperties: false - properties: - model_id: - type: string - required: - - model_id - type: object DoraFinetuningConfig: additionalProperties: false properties: @@ -3244,6 +3236,22 @@ components: format: uri pattern: ^(https?://|file://|data:) type: string + UnregisterMemoryBankRequest: + additionalProperties: false + properties: + memory_bank_id: + type: string + required: + - memory_bank_id + type: object + UnregisterModelRequest: + additionalProperties: false + properties: + model_id: + type: string + required: + - model_id + type: object UnstructuredLogEvent: additionalProperties: false properties: @@ -3280,28 +3288,6 @@ components: - message - severity type: object - UpdateModelRequest: - additionalProperties: false - properties: - metadata: - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - type: object - model_id: - type: string - provider_id: - type: string - provider_model_id: - type: string - required: - - model_id - type: object UserMessage: additionalProperties: false properties: @@ -3414,7 +3400,7 @@ info: description: "This is the specification of the llama stack that provides\n \ \ a set of endpoints and their corresponding interfaces that are tailored\ \ to\n best leverage Llama Models. The specification is still in\ - \ draft and subject to change.\n Generated at 2024-11-14 12:51:12.176325" + \ draft and subject to change.\n Generated at 2024-11-14 17:04:24.301559" title: '[DRAFT] Llama Stack Specification' version: 0.0.1 jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema @@ -4216,7 +4202,7 @@ paths: responses: {} tags: - MemoryBanks - /models/delete: + /memory_banks/unregister: post: parameters: - description: JSON-encoded provider data which will be made available to the @@ -4230,13 +4216,13 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/DeleteModelRequest' + $ref: '#/components/schemas/UnregisterMemoryBankRequest' required: true responses: '200': description: OK tags: - - Models + - MemoryBanks /models/get: get: parameters: @@ -4307,7 +4293,7 @@ paths: description: OK tags: - Models - /models/update: + /models/unregister: post: parameters: - description: JSON-encoded provider data which will be made available to the @@ -4321,14 +4307,10 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/UpdateModelRequest' + $ref: '#/components/schemas/UnregisterModelRequest' required: true responses: '200': - content: - application/json: - schema: - $ref: '#/components/schemas/Model' description: OK tags: - Models @@ -4960,9 +4942,6 @@ tags: - description: name: DeleteAgentsSessionRequest -- description: - name: DeleteModelRequest - description: name: DoraFinetuningConfig @@ -5257,12 +5236,15 @@ tags: name: Turn - description: name: URL +- description: + name: UnregisterMemoryBankRequest +- description: + name: UnregisterModelRequest - description: name: UnstructuredLogEvent -- description: - name: UpdateModelRequest - description: name: UserMessage - description: MemoryBank: ... + + @webmethod(route="/memory_banks/unregister", method="POST") + async def unregister_memory_bank(self, memory_bank_id: str) -> None: ... diff --git a/llama_stack/apis/models/client.py b/llama_stack/apis/models/client.py index aa63ca5412..34541b96ec 100644 --- a/llama_stack/apis/models/client.py +++ b/llama_stack/apis/models/client.py @@ -7,7 +7,7 @@ import asyncio import json -from typing import Any, Dict, List, Optional +from typing import List, Optional import fire import httpx @@ -61,28 +61,7 @@ async def get_model(self, identifier: str) -> Optional[Model]: return None return Model(**j) - async def update_model( - self, - model_id: str, - provider_model_id: Optional[str] = None, - provider_id: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None, - ) -> Model: - async with httpx.AsyncClient() as client: - response = await client.put( - f"{self.base_url}/models/update", - json={ - "model_id": model_id, - "provider_model_id": provider_model_id, - "provider_id": provider_id, - "metadata": metadata, - }, - headers={"Content-Type": "application/json"}, - ) - response.raise_for_status() - return Model(**response.json()) - - async def delete_model(self, model_id: str) -> None: + async def unregister_model(self, model_id: str) -> None: async with httpx.AsyncClient() as client: response = await client.delete( f"{self.base_url}/models/delete", diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py index 5ffcde52f7..a1bfcac002 100644 --- a/llama_stack/apis/models/models.py +++ b/llama_stack/apis/models/models.py @@ -55,14 +55,5 @@ async def register_model( metadata: Optional[Dict[str, Any]] = None, ) -> Model: ... - @webmethod(route="/models/update", method="POST") - async def update_model( - self, - model_id: str, - provider_model_id: Optional[str] = None, - provider_id: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None, - ) -> Model: ... - - @webmethod(route="/models/delete", method="POST") - async def delete_model(self, model_id: str) -> None: ... + @webmethod(route="/models/unregister", method="POST") + async def unregister_model(self, model_id: str) -> None: ... diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index a940dbae6f..76078e6529 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -51,6 +51,16 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable raise ValueError(f"Unknown API {api} for registering object with provider") +async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None: + api = get_impl_api(p) + if api == Api.memory: + return await p.unregister_memory_bank(obj.identifier) + elif api == Api.inference: + return await p.unregister_model(obj.identifier) + else: + raise ValueError(f"Unregister not supported for {api}") + + Registry = Dict[str, List[RoutableObjectWithProvider]] @@ -148,17 +158,11 @@ async def get_object_by_identifier( return obj - async def delete_object(self, obj: RoutableObjectWithProvider) -> None: + async def unregister_object(self, obj: RoutableObjectWithProvider) -> None: await self.dist_registry.delete(obj.type, obj.identifier) - # TODO: delete from provider - - async def update_object( - self, obj: RoutableObjectWithProvider - ) -> RoutableObjectWithProvider: - registered_obj = await register_object_with_provider( + await unregister_object_from_provider( obj, self.impls_by_provider_id[obj.provider_id] ) - return await self.dist_registry.update(registered_obj) async def register_object( self, obj: RoutableObjectWithProvider @@ -232,32 +236,11 @@ async def register_model( registered_model = await self.register_object(model) return registered_model - async def update_model( - self, - model_id: str, - provider_model_id: Optional[str] = None, - provider_id: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None, - ) -> Model: - existing_model = await self.get_model(model_id) - if existing_model is None: - raise ValueError(f"Model {model_id} not found") - - updated_model = Model( - identifier=model_id, - provider_resource_id=provider_model_id - or existing_model.provider_resource_id, - provider_id=provider_id or existing_model.provider_id, - metadata=metadata or existing_model.metadata, - ) - registered_model = await self.update_object(updated_model) - return registered_model - - async def delete_model(self, model_id: str) -> None: + async def unregister_model(self, model_id: str) -> None: existing_model = await self.get_model(model_id) if existing_model is None: raise ValueError(f"Model {model_id} not found") - await self.delete_object(existing_model) + await self.unregister_object(existing_model) class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): @@ -333,6 +316,12 @@ async def register_memory_bank( await self.register_object(memory_bank) return memory_bank + async def unregister_memory_bank(self, memory_bank_id: str) -> None: + existing_bank = await self.get_memory_bank(memory_bank_id) + if existing_bank is None: + raise ValueError(f"Memory bank {memory_bank_id} not found") + await self.unregister_object(existing_bank) + class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): async def list_datasets(self) -> List[Dataset]: diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 51ff163abe..080204e450 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -45,6 +45,8 @@ class Api(Enum): class ModelsProtocolPrivate(Protocol): async def register_model(self, model: Model) -> None: ... + async def unregister_model(self, model_id: str) -> None: ... + class ShieldsProtocolPrivate(Protocol): async def register_shield(self, shield: Shield) -> None: ... @@ -55,6 +57,8 @@ async def list_memory_banks(self) -> List[MemoryBank]: ... async def register_memory_bank(self, memory_bank: MemoryBank) -> None: ... + async def unregister_memory_bank(self, memory_bank_id: str) -> None: ... + class DatasetsProtocolPrivate(Protocol): async def register_dataset(self, dataset: Dataset) -> None: ... diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index 4f5c0c8c21..e6bcd6730d 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -71,6 +71,9 @@ def check_model(self, request) -> None: f"Model mismatch: {request.model} != {self.model.descriptor()}" ) + async def unregister_model(self, model_id: str) -> None: + pass + async def completion( self, model_id: str, diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py index 8869cc07ff..0e7ba872c9 100644 --- a/llama_stack/providers/inline/inference/vllm/vllm.py +++ b/llama_stack/providers/inline/inference/vllm/vllm.py @@ -108,6 +108,9 @@ def _sampling_params(self, sampling_params: SamplingParams) -> VLLMSamplingParam return VLLMSamplingParams(**kwargs) + async def unregister_model(self, model_id: str) -> None: + pass + async def completion( self, model_id: str, diff --git a/llama_stack/providers/inline/memory/faiss/faiss.py b/llama_stack/providers/inline/memory/faiss/faiss.py index 0790eb67df..92235ea893 100644 --- a/llama_stack/providers/inline/memory/faiss/faiss.py +++ b/llama_stack/providers/inline/memory/faiss/faiss.py @@ -4,6 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import base64 +import json import logging from typing import Any, Dict, List, Optional @@ -37,10 +39,52 @@ class FaissIndex(EmbeddingIndex): id_by_index: Dict[int, str] chunk_by_index: Dict[int, str] - def __init__(self, dimension: int): + def __init__(self, dimension: int, kvstore=None, bank_id: str = None): self.index = faiss.IndexFlatL2(dimension) self.id_by_index = {} self.chunk_by_index = {} + self.kvstore = kvstore + self.bank_id = bank_id + self.initialize() + + async def initialize(self) -> None: + if not self.kvstore: + return + + index_key = f"faiss_index:v1::{self.bank_id}" + stored_data = await self.kvstore.get(index_key) + + if stored_data: + data = json.loads(stored_data) + self.id_by_index = {int(k): v for k, v in data["id_by_index"].items()} + self.chunk_by_index = { + int(k): Chunk.model_validate_json(v) + for k, v in data["chunk_by_index"].items() + } + + index_bytes = base64.b64decode(data["faiss_index"]) + self.index = faiss.deserialize_index(index_bytes) + + async def _save_index(self): + if not self.kvstore or not self.bank_id: + return + + index_bytes = faiss.serialize_index(self.index) + + data = { + "id_by_index": self.id_by_index, + "chunk_by_index": {k: v.json() for k, v in self.chunk_by_index.items()}, + "faiss_index": base64.b64encode(index_bytes).decode(), + } + + index_key = f"faiss_index:v1::{self.bank_id}" + await self.kvstore.set(key=index_key, value=json.dumps(data)) + + async def delete(self): + if not self.kvstore or not self.bank_id: + return + + await self.kvstore.delete(f"faiss_index:v1::{self.bank_id}") @tracing.span(name="add_chunks") async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray): @@ -51,6 +95,9 @@ async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray): self.index.add(np.array(embeddings).astype(np.float32)) + # Save updated index + await self._save_index() + async def query( self, embedding: NDArray, k: int, score_threshold: float ) -> QueryDocumentsResponse: @@ -85,7 +132,7 @@ async def initialize(self) -> None: for bank_data in stored_banks: bank = VectorMemoryBank.model_validate_json(bank_data) index = BankWithIndex( - bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION) + bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION, self.kvstore) ) self.cache[bank.identifier] = index @@ -110,13 +157,19 @@ async def register_memory_bank( # Store in cache index = BankWithIndex( - bank=memory_bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION) + bank=memory_bank, + index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION, self.kvstore), ) self.cache[memory_bank.identifier] = index async def list_memory_banks(self) -> List[MemoryBank]: return [i.bank for i in self.cache.values()] + async def unregister_memory_bank(self, memory_bank_id: str) -> None: + await self.cache[memory_bank_id].index.delete() + del self.cache[memory_bank_id] + await self.kvstore.delete(f"{MEMORY_BANKS_PREFIX}{memory_bank_id}") + async def insert_documents( self, bank_id: str, diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 297eecbdcc..3b3f3868b4 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -93,6 +93,9 @@ async def initialize(self) -> None: async def shutdown(self) -> None: pass + async def unregister_model(self, model_id: str) -> None: + pass + async def completion( self, model_id: str, diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py index 8d3d1f86dd..30745cb109 100644 --- a/llama_stack/providers/remote/inference/tgi/tgi.py +++ b/llama_stack/providers/remote/inference/tgi/tgi.py @@ -69,6 +69,9 @@ async def list_models(self) -> List[Model]: async def shutdown(self) -> None: pass + async def unregister_model(self, model_id: str) -> None: + pass + async def completion( self, model: str, diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 696cfb15d2..788f6cac40 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -58,6 +58,9 @@ async def initialize(self) -> None: async def shutdown(self) -> None: pass + async def unregister_model(self, model_id: str) -> None: + pass + async def completion( self, model_id: str, diff --git a/llama_stack/providers/remote/memory/chroma/chroma.py b/llama_stack/providers/remote/memory/chroma/chroma.py index 0611d9aa20..ac00fc7490 100644 --- a/llama_stack/providers/remote/memory/chroma/chroma.py +++ b/llama_stack/providers/remote/memory/chroma/chroma.py @@ -67,6 +67,9 @@ async def query( return QueryDocumentsResponse(chunks=chunks, scores=scores) + async def delete(self): + await self.client.delete_collection(self.collection.name) + class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): def __init__(self, url: str) -> None: @@ -134,6 +137,10 @@ async def list_memory_banks(self) -> List[MemoryBank]: return [i.bank for i in self.cache.values()] + async def unregister_memory_bank(self, memory_bank_id: str) -> None: + await self.cache[memory_bank_id].index.delete() + del self.cache[memory_bank_id] + async def insert_documents( self, bank_id: str, diff --git a/llama_stack/providers/remote/memory/pgvector/pgvector.py b/llama_stack/providers/remote/memory/pgvector/pgvector.py index 9acfef2dcb..44c2a8fe1f 100644 --- a/llama_stack/providers/remote/memory/pgvector/pgvector.py +++ b/llama_stack/providers/remote/memory/pgvector/pgvector.py @@ -112,6 +112,9 @@ async def query( return QueryDocumentsResponse(chunks=chunks, scores=scores) + async def delete(self): + self.cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}") + class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): def __init__(self, config: PGVectorConfig) -> None: @@ -177,6 +180,10 @@ async def register_memory_bank( ) self.cache[memory_bank.identifier] = index + async def unregister_memory_bank(self, memory_bank_id: str) -> None: + await self.cache[memory_bank_id].index.delete() + del self.cache[memory_bank_id] + async def list_memory_banks(self) -> List[MemoryBank]: banks = load_models(self.cursor, VectorMemoryBank) for bank in banks: diff --git a/llama_stack/providers/tests/inference/test_model_registration.py b/llama_stack/providers/tests/inference/test_model_registration.py index 97f0ac5761..0f07badfa4 100644 --- a/llama_stack/providers/tests/inference/test_model_registration.py +++ b/llama_stack/providers/tests/inference/test_model_registration.py @@ -54,4 +54,4 @@ async def test_update_model(self, inference_stack): assert updated_model.provider_resource_id != old_model.provider_resource_id # Cleanup - await models_impl.delete_model(model_id=model_id) + await models_impl.unregister_model(model_id=model_id) diff --git a/llama_stack/providers/tests/memory/test_memory.py b/llama_stack/providers/tests/memory/test_memory.py index 24cef8a243..b6e2e0a76e 100644 --- a/llama_stack/providers/tests/memory/test_memory.py +++ b/llama_stack/providers/tests/memory/test_memory.py @@ -4,6 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import uuid + import pytest from llama_stack.apis.memory import * # noqa: F403 @@ -43,9 +45,10 @@ def sample_documents(): ] -async def register_memory_bank(banks_impl: MemoryBanks): +async def register_memory_bank(banks_impl: MemoryBanks) -> MemoryBank: + bank_id = f"test_bank_{uuid.uuid4().hex}" return await banks_impl.register_memory_bank( - memory_bank_id="test_bank", + memory_bank_id=bank_id, params=VectorMemoryBankParams( embedding_model="all-MiniLM-L6-v2", chunk_size_in_tokens=512, @@ -57,43 +60,70 @@ async def register_memory_bank(banks_impl: MemoryBanks): class TestMemory: @pytest.mark.asyncio async def test_banks_list(self, memory_stack): - # NOTE: this needs you to ensure that you are starting from a clean state - # but so far we don't have an unregister API unfortunately, so be careful _, banks_impl = memory_stack + + # Register a test bank + registered_bank = await register_memory_bank(banks_impl) + + try: + # Verify our bank shows up in list + response = await banks_impl.list_memory_banks() + assert isinstance(response, list) + assert any( + bank.memory_bank_id == registered_bank.memory_bank_id + for bank in response + ) + finally: + # Clean up + await banks_impl.unregister_memory_bank(registered_bank.memory_bank_id) + + # Verify our bank was removed response = await banks_impl.list_memory_banks() - assert isinstance(response, list) - assert len(response) == 0 + assert all( + bank.memory_bank_id != registered_bank.memory_bank_id for bank in response + ) @pytest.mark.asyncio async def test_banks_register(self, memory_stack): - # NOTE: this needs you to ensure that you are starting from a clean state - # but so far we don't have an unregister API unfortunately, so be careful _, banks_impl = memory_stack - await banks_impl.register_memory_bank( - memory_bank_id="test_bank_no_provider", - params=VectorMemoryBankParams( - embedding_model="all-MiniLM-L6-v2", - chunk_size_in_tokens=512, - overlap_size_in_tokens=64, - ), - ) - response = await banks_impl.list_memory_banks() - assert isinstance(response, list) - assert len(response) == 1 - - # register same memory bank with same id again will fail - await banks_impl.register_memory_bank( - memory_bank_id="test_bank_no_provider", - params=VectorMemoryBankParams( - embedding_model="all-MiniLM-L6-v2", - chunk_size_in_tokens=512, - overlap_size_in_tokens=64, - ), - ) - response = await banks_impl.list_memory_banks() - assert isinstance(response, list) - assert len(response) == 1 + bank_id = f"test_bank_{uuid.uuid4().hex}" + + try: + # Register initial bank + await banks_impl.register_memory_bank( + memory_bank_id=bank_id, + params=VectorMemoryBankParams( + embedding_model="all-MiniLM-L6-v2", + chunk_size_in_tokens=512, + overlap_size_in_tokens=64, + ), + ) + + # Verify our bank exists + response = await banks_impl.list_memory_banks() + assert isinstance(response, list) + assert any(bank.memory_bank_id == bank_id for bank in response) + + # Try registering same bank again + await banks_impl.register_memory_bank( + memory_bank_id=bank_id, + params=VectorMemoryBankParams( + embedding_model="all-MiniLM-L6-v2", + chunk_size_in_tokens=512, + overlap_size_in_tokens=64, + ), + ) + + # Verify still only one instance of our bank + response = await banks_impl.list_memory_banks() + assert isinstance(response, list) + assert ( + len([bank for bank in response if bank.memory_bank_id == bank_id]) == 1 + ) + finally: + # Clean up + await banks_impl.unregister_memory_bank(bank_id) @pytest.mark.asyncio async def test_query_documents(self, memory_stack, sample_documents): @@ -102,17 +132,23 @@ async def test_query_documents(self, memory_stack, sample_documents): with pytest.raises(ValueError): await memory_impl.insert_documents("test_bank", sample_documents) - await register_memory_bank(banks_impl) - await memory_impl.insert_documents("test_bank", sample_documents) + registered_bank = await register_memory_bank(banks_impl) + await memory_impl.insert_documents( + registered_bank.memory_bank_id, sample_documents + ) query1 = "programming language" - response1 = await memory_impl.query_documents("test_bank", query1) + response1 = await memory_impl.query_documents( + registered_bank.memory_bank_id, query1 + ) assert_valid_response(response1) assert any("Python" in chunk.content for chunk in response1.chunks) # Test case 3: Query with semantic similarity query3 = "AI and brain-inspired computing" - response3 = await memory_impl.query_documents("test_bank", query3) + response3 = await memory_impl.query_documents( + registered_bank.memory_bank_id, query3 + ) assert_valid_response(response3) assert any( "neural networks" in chunk.content.lower() for chunk in response3.chunks @@ -121,14 +157,18 @@ async def test_query_documents(self, memory_stack, sample_documents): # Test case 4: Query with limit on number of results query4 = "computer" params4 = {"max_chunks": 2} - response4 = await memory_impl.query_documents("test_bank", query4, params4) + response4 = await memory_impl.query_documents( + registered_bank.memory_bank_id, query4, params4 + ) assert_valid_response(response4) assert len(response4.chunks) <= 2 # Test case 5: Query with threshold on similarity score query5 = "quantum computing" # Not directly related to any document params5 = {"score_threshold": 0.2} - response5 = await memory_impl.query_documents("test_bank", query5, params5) + response5 = await memory_impl.query_documents( + registered_bank.memory_bank_id, query5, params5 + ) assert_valid_response(response5) print("The scores are:", response5.scores) assert all(score >= 0.2 for score in response5.scores) diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index ba7ed231ee..2bbf6cdd2b 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -145,6 +145,10 @@ async def query( ) -> QueryDocumentsResponse: raise NotImplementedError() + @abstractmethod + async def delete(self): + raise NotImplementedError() + @dataclass class BankWithIndex: