From 4b1b1962511329c838a589fc8fd41f97083ebd65 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Wed, 13 Nov 2024 15:30:17 -0800 Subject: [PATCH 1/6] add model update and delete --- docs/resources/llama-stack-spec.html | 172 ++++++++++++++++-- docs/resources/llama-stack-spec.yaml | 108 +++++++++-- llama_stack/apis/models/client.py | 32 +++- llama_stack/apis/models/models.py | 12 ++ .../distribution/routers/routing_tables.py | 31 ++++ llama_stack/distribution/store/registry.py | 12 ++ 6 files changed, 337 insertions(+), 30 deletions(-) diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index 7ef9e29af8..7fb46a7246 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -21,7 +21,7 @@ "info": { "title": "[DRAFT] Llama Stack Specification", "version": "0.0.1", - "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-13 11:02:50.081698" + "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-13 15:29:27.077633" }, "servers": [ { @@ -429,6 +429,39 @@ } } }, + "/models/delete": { + "post": { + "responses": { + "200": { + "description": "OK" + } + }, + "tags": [ + "Models" + ], + "parameters": [ + { + "name": "X-LlamaStack-ProviderData", + "in": "header", + "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", + "required": false, + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/DeleteModelRequest" + } + } + }, + "required": true + } + } + }, "/inference/embeddings": { "post": { "responses": { @@ -2225,6 +2258,46 @@ "required": true } } + }, + "/models/update": { + "post": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Model" + } + } + } + } + }, + "tags": [ + "Models" + ], + "parameters": [ + { + "name": "X-LlamaStack-ProviderData", + "in": "header", + "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", + "required": false, + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UpdateModelRequest" + } + } + }, + "required": true + } + } } }, "jsonSchemaDialect": "https://json-schema.org/draft/2020-12/schema", @@ -4549,6 +4622,18 @@ "session_id" ] }, + "DeleteModelRequest": { + "type": "object", + "properties": { + "model_id": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "model_id" + ] + }, "EmbeddingsRequest": { "type": "object", "properties": { @@ -7826,6 +7911,49 @@ "synthetic_data" ], "title": "Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold." + }, + "UpdateModelRequest": { + "type": "object", + "properties": { + "model_id": { + "type": "string" + }, + "provider_model_id": { + "type": "string" + }, + "provider_id": { + "type": "string" + }, + "metadata": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + }, + "additionalProperties": false, + "required": [ + "model_id" + ] } }, "responses": {} @@ -7837,52 +7965,52 @@ ], "tags": [ { - "name": "Inspect" + "name": "Datasets" }, { - "name": "Models" + "name": "Inference" }, { - "name": "Eval" + "name": "ScoringFunctions" }, { - "name": "EvalTasks" + "name": "MemoryBanks" }, { - "name": "Scoring" + "name": "Telemetry" }, { - "name": "Inference" + "name": "PostTraining" }, { - "name": "Memory" + "name": "Models" }, { - "name": "Safety" + "name": "Inspect" }, { - "name": "PostTraining" + "name": "Safety" }, { - "name": "ScoringFunctions" + "name": "Scoring" }, { - "name": "Telemetry" + "name": "BatchInference" }, { - "name": "Shields" + "name": "Eval" }, { - "name": "BatchInference" + "name": "SyntheticDataGeneration" }, { - "name": "MemoryBanks" + "name": "EvalTasks" }, { - "name": "Datasets" + "name": "Shields" }, { - "name": "SyntheticDataGeneration" + "name": "Memory" }, { "name": "DatasetIO" @@ -8142,6 +8270,10 @@ "name": "DeleteAgentsSessionRequest", "description": "" }, + { + "name": "DeleteModelRequest", + "description": "" + }, { "name": "EmbeddingsRequest", "description": "" @@ -8453,6 +8585,10 @@ { "name": "SyntheticDataGenerationResponse", "description": "Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold.\n\n" + }, + { + "name": "UpdateModelRequest", + "description": "" } ], "x-tagGroups": [ @@ -8521,6 +8657,7 @@ "Dataset", "DeleteAgentsRequest", "DeleteAgentsSessionRequest", + "DeleteModelRequest", "DoraFinetuningConfig", "EmbeddingsRequest", "EmbeddingsResponse", @@ -8618,6 +8755,7 @@ "Turn", "URL", "UnstructuredLogEvent", + "UpdateModelRequest", "UserMessage", "VectorMemoryBank", "VectorMemoryBankParams", diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index 14f87cf547..06a4afa855 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -867,6 +867,14 @@ components: - agent_id - session_id type: object + DeleteModelRequest: + additionalProperties: false + properties: + model_id: + type: string + required: + - model_id + type: object DoraFinetuningConfig: additionalProperties: false properties: @@ -3272,6 +3280,28 @@ components: - message - severity type: object + UpdateModelRequest: + additionalProperties: false + properties: + metadata: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + model_id: + type: string + provider_id: + type: string + provider_model_id: + type: string + required: + - model_id + type: object UserMessage: additionalProperties: false properties: @@ -3384,7 +3414,7 @@ info: description: "This is the specification of the llama stack that provides\n \ \ a set of endpoints and their corresponding interfaces that are tailored\ \ to\n best leverage Llama Models. The specification is still in\ - \ draft and subject to change.\n Generated at 2024-11-13 11:02:50.081698" + \ draft and subject to change.\n Generated at 2024-11-13 15:29:27.077633" title: '[DRAFT] Llama Stack Specification' version: 0.0.1 jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema @@ -4186,6 +4216,27 @@ paths: responses: {} tags: - MemoryBanks + /models/delete: + post: + parameters: + - description: JSON-encoded provider data which will be made available to the + adapter servicing the API + in: header + name: X-LlamaStack-ProviderData + required: false + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/DeleteModelRequest' + required: true + responses: + '200': + description: OK + tags: + - Models /models/get: get: parameters: @@ -4256,6 +4307,31 @@ paths: description: OK tags: - Models + /models/update: + post: + parameters: + - description: JSON-encoded provider data which will be made available to the + adapter servicing the API + in: header + name: X-LlamaStack-ProviderData + required: false + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/UpdateModelRequest' + required: true + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/Model' + description: OK + tags: + - Models /post_training/job/artifacts: get: parameters: @@ -4748,22 +4824,22 @@ security: servers: - url: http://any-hosted-llama-stack.com tags: -- name: Inspect -- name: Models -- name: Eval -- name: EvalTasks -- name: Scoring +- name: Datasets - name: Inference -- name: Memory -- name: Safety -- name: PostTraining - name: ScoringFunctions +- name: MemoryBanks - name: Telemetry -- name: Shields +- name: PostTraining +- name: Models +- name: Inspect +- name: Safety +- name: Scoring - name: BatchInference -- name: MemoryBanks -- name: Datasets +- name: Eval - name: SyntheticDataGeneration +- name: EvalTasks +- name: Shields +- name: Memory - name: DatasetIO - name: Agents - description: @@ -4964,6 +5040,9 @@ tags: - description: name: DeleteAgentsSessionRequest +- description: + name: DeleteModelRequest - description: name: EmbeddingsRequest @@ -5194,6 +5273,9 @@ tags: ' name: SyntheticDataGenerationResponse +- description: + name: UpdateModelRequest x-tagGroups: - name: Operations tags: @@ -5256,6 +5338,7 @@ x-tagGroups: - Dataset - DeleteAgentsRequest - DeleteAgentsSessionRequest + - DeleteModelRequest - DoraFinetuningConfig - EmbeddingsRequest - EmbeddingsResponse @@ -5353,6 +5436,7 @@ x-tagGroups: - Turn - URL - UnstructuredLogEvent + - UpdateModelRequest - UserMessage - VectorMemoryBank - VectorMemoryBankParams diff --git a/llama_stack/apis/models/client.py b/llama_stack/apis/models/client.py index d986828ee1..aa63ca5412 100644 --- a/llama_stack/apis/models/client.py +++ b/llama_stack/apis/models/client.py @@ -7,7 +7,7 @@ import asyncio import json -from typing import List, Optional +from typing import Any, Dict, List, Optional import fire import httpx @@ -61,6 +61,36 @@ async def get_model(self, identifier: str) -> Optional[Model]: return None return Model(**j) + async def update_model( + self, + model_id: str, + provider_model_id: Optional[str] = None, + provider_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> Model: + async with httpx.AsyncClient() as client: + response = await client.put( + f"{self.base_url}/models/update", + json={ + "model_id": model_id, + "provider_model_id": provider_model_id, + "provider_id": provider_id, + "metadata": metadata, + }, + headers={"Content-Type": "application/json"}, + ) + response.raise_for_status() + return Model(**response.json()) + + async def delete_model(self, model_id: str) -> None: + async with httpx.AsyncClient() as client: + response = await client.delete( + f"{self.base_url}/models/delete", + params={"model_id": model_id}, + headers={"Content-Type": "application/json"}, + ) + response.raise_for_status() + async def run_main(host: str, port: int, stream: bool): client = ModelsClient(f"http://{host}:{port}") diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py index 2cd12b4bc8..7eebe5b9f8 100644 --- a/llama_stack/apis/models/models.py +++ b/llama_stack/apis/models/models.py @@ -54,3 +54,15 @@ async def register_model( provider_id: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, ) -> Model: ... + + @webmethod(route="/models/update", method="PUT") + async def update_model( + self, + model_id: str, + provider_model_id: Optional[str] = None, + provider_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> Model: ... + + @webmethod(route="/models/delete", method="DELETE") + async def delete_model(self, model_id: str) -> None: ... diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 8c1b0c1e71..32a341278c 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -152,6 +152,10 @@ async def get_object_by_identifier( assert len(objects) == 1 return objects[0] + async def delete_object(self, obj: RoutableObjectWithProvider) -> None: + await self.dist_registry.delete(obj.type, obj.identifier) + # TODO: delete from provider + async def register_object( self, obj: RoutableObjectWithProvider ) -> RoutableObjectWithProvider: @@ -225,6 +229,33 @@ async def register_model( registered_model = await self.register_object(model) return registered_model + async def update_model( + self, + model_id: str, + provider_model_id: Optional[str] = None, + provider_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> Model: + existing_model = await self.get_model(model_id) + if existing_model is None: + raise ValueError(f"Model {model_id} not found") + + updated_model = Model( + identifier=model_id, + provider_resource_id=provider_model_id + or existing_model.provider_resource_id, + provider_id=provider_id or existing_model.provider_id, + metadata=metadata or existing_model.metadata, + ) + registered_model = await self.register_object(updated_model) + return registered_model + + async def delete_model(self, model_id: str) -> None: + existing_model = await self.get_model(model_id) + if existing_model is None: + raise ValueError(f"Model {model_id} not found") + await self.delete_object(existing_model) + class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): async def list_shields(self) -> List[Shield]: diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index bb87c81fa8..35276b4395 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -36,6 +36,8 @@ def get_cached(self, identifier: str) -> List[RoutableObjectWithProvider]: ... # The current approach could lead to inconsistencies if the same logical object has different data across providers. async def register(self, obj: RoutableObjectWithProvider) -> bool: ... + async def delete(self, type: str, identifier: str) -> None: ... + REGISTER_PREFIX = "distributions:registry" KEY_VERSION = "v1" @@ -120,6 +122,9 @@ async def register(self, obj: RoutableObjectWithProvider) -> bool: ) return True + async def delete(self, type: str, identifier: str) -> None: + await self.kvstore.delete(KEY_FORMAT.format(type=type, identifier=identifier)) + class CachedDiskDistributionRegistry(DiskDistributionRegistry): def __init__(self, kvstore: KVStore): @@ -206,6 +211,13 @@ async def register(self, obj: RoutableObjectWithProvider) -> bool: return success + async def delete(self, type: str, identifier: str) -> None: + await super().delete(type, identifier) + cache_key = (type, identifier) + async with self._locked_cache() as cache: + if cache_key in cache: + del cache[cache_key] + async def create_dist_registry( metadata_store: Optional[KVStoreConfig], From 9e68ed3f36aafa3fcb40378d10a6ccb7a5a99a4e Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Wed, 13 Nov 2024 20:50:26 -0800 Subject: [PATCH 2/6] registery to handle updates and deletes --- docs/resources/llama-stack-spec.html | 32 ++--- docs/resources/llama-stack-spec.yaml | 24 ++-- llama_stack/apis/models/models.py | 2 +- .../distribution/routers/routing_tables.py | 39 +++--- llama_stack/distribution/store/registry.py | 114 ++++++++---------- 5 files changed, 102 insertions(+), 109 deletions(-) diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index 7fb46a7246..3cac939672 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -21,7 +21,7 @@ "info": { "title": "[DRAFT] Llama Stack Specification", "version": "0.0.1", - "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-13 15:29:27.077633" + "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-13 18:16:59.065989" }, "servers": [ { @@ -7965,10 +7965,7 @@ ], "tags": [ { - "name": "Datasets" - }, - { - "name": "Inference" + "name": "Shields" }, { "name": "ScoringFunctions" @@ -7977,28 +7974,31 @@ "name": "MemoryBanks" }, { - "name": "Telemetry" + "name": "Datasets" }, { - "name": "PostTraining" + "name": "Agents" }, { - "name": "Models" + "name": "DatasetIO" + }, + { + "name": "Inference" }, { "name": "Inspect" }, { - "name": "Safety" + "name": "Memory" }, { - "name": "Scoring" + "name": "Models" }, { - "name": "BatchInference" + "name": "PostTraining" }, { - "name": "Eval" + "name": "Safety" }, { "name": "SyntheticDataGeneration" @@ -8007,16 +8007,16 @@ "name": "EvalTasks" }, { - "name": "Shields" + "name": "Scoring" }, { - "name": "Memory" + "name": "BatchInference" }, { - "name": "DatasetIO" + "name": "Eval" }, { - "name": "Agents" + "name": "Telemetry" }, { "name": "BuiltinTool", diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index 06a4afa855..5d2b91d842 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -3414,7 +3414,7 @@ info: description: "This is the specification of the llama stack that provides\n \ \ a set of endpoints and their corresponding interfaces that are tailored\ \ to\n best leverage Llama Models. The specification is still in\ - \ draft and subject to change.\n Generated at 2024-11-13 15:29:27.077633" + \ draft and subject to change.\n Generated at 2024-11-13 18:16:59.065989" title: '[DRAFT] Llama Stack Specification' version: 0.0.1 jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema @@ -4824,24 +4824,24 @@ security: servers: - url: http://any-hosted-llama-stack.com tags: -- name: Datasets -- name: Inference +- name: Shields - name: ScoringFunctions - name: MemoryBanks -- name: Telemetry -- name: PostTraining -- name: Models +- name: Datasets +- name: Agents +- name: DatasetIO +- name: Inference - name: Inspect +- name: Memory +- name: Models +- name: PostTraining - name: Safety +- name: SyntheticDataGeneration +- name: EvalTasks - name: Scoring - name: BatchInference - name: Eval -- name: SyntheticDataGeneration -- name: EvalTasks -- name: Shields -- name: Memory -- name: DatasetIO -- name: Agents +- name: Telemetry - description: name: BuiltinTool - description: Model: ... - @webmethod(route="/models/delete", method="DELETE") + @webmethod(route="/models/delete", method="POST") async def delete_model(self, model_id: str) -> None: ... diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 32a341278c..861c830be5 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -124,8 +124,8 @@ def apiname_object(): apiname, objtype = apiname_object() # Get objects from disk registry - objects = self.dist_registry.get_cached(objtype, routing_key) - if not objects: + obj = self.dist_registry.get_cached(objtype, routing_key) + if not obj: provider_ids = list(self.impls_by_provider_id.keys()) if len(provider_ids) > 1: provider_ids_str = f"any of the providers: {', '.join(provider_ids)}" @@ -135,9 +135,8 @@ def apiname_object(): f"{objtype.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objtype}." ) - for obj in objects: - if not provider_id or provider_id == obj.provider_id: - return self.impls_by_provider_id[obj.provider_id] + if not provider_id or provider_id == obj.provider_id: + return self.impls_by_provider_id[obj.provider_id] raise ValueError(f"Provider not found for `{routing_key}`") @@ -145,30 +144,36 @@ async def get_object_by_identifier( self, type: str, identifier: str ) -> Optional[RoutableObjectWithProvider]: # Get from disk registry - objects = await self.dist_registry.get(type, identifier) - if not objects: + obj = await self.dist_registry.get(type, identifier) + if not obj: return None - assert len(objects) == 1 - return objects[0] + return obj async def delete_object(self, obj: RoutableObjectWithProvider) -> None: await self.dist_registry.delete(obj.type, obj.identifier) # TODO: delete from provider + async def update_object( + self, obj: RoutableObjectWithProvider + ) -> RoutableObjectWithProvider: + registered_obj = await register_object_with_provider( + obj, self.impls_by_provider_id[obj.provider_id] + ) + return await self.dist_registry.update(registered_obj) + async def register_object( self, obj: RoutableObjectWithProvider ) -> RoutableObjectWithProvider: # Get existing objects from registry - existing_objects = await self.dist_registry.get(obj.type, obj.identifier) + existing_obj = await self.dist_registry.get(obj.type, obj.identifier) # Check for existing registration - for existing_obj in existing_objects: - if existing_obj.provider_id == obj.provider_id or not obj.provider_id: - print( - f"`{obj.identifier}` already registered with `{existing_obj.provider_id}`" - ) - return existing_obj + if existing_obj and existing_obj.provider_id == obj.provider_id: + print( + f"`{obj.identifier}` already registered with `{existing_obj.provider_id}`" + ) + return existing_obj # if provider_id is not specified, pick an arbitrary one from existing entries if not obj.provider_id and len(self.impls_by_provider_id) > 0: @@ -247,7 +252,7 @@ async def update_model( provider_id=provider_id or existing_model.provider_id, metadata=metadata or existing_model.metadata, ) - registered_model = await self.register_object(updated_model) + registered_model = await self.update_object(updated_model) return registered_model async def delete_model(self, model_id: str) -> None: diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index 35276b4395..d8a1a04e3c 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -26,9 +26,13 @@ async def get_all(self) -> List[RoutableObjectWithProvider]: ... async def initialize(self) -> None: ... - async def get(self, identifier: str) -> List[RoutableObjectWithProvider]: ... + async def get(self, identifier: str) -> Optional[RoutableObjectWithProvider]: ... - def get_cached(self, identifier: str) -> List[RoutableObjectWithProvider]: ... + def get_cached(self, identifier: str) -> Optional[RoutableObjectWithProvider]: ... + + async def update( + self, obj: RoutableObjectWithProvider + ) -> RoutableObjectWithProvider: ... # The current data structure allows multiple objects with the same identifier but different providers. # This is not ideal - we should have a single object that can be served by multiple providers, @@ -40,7 +44,7 @@ async def delete(self, type: str, identifier: str) -> None: ... REGISTER_PREFIX = "distributions:registry" -KEY_VERSION = "v1" +KEY_VERSION = "v2" KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}" @@ -54,19 +58,11 @@ def _parse_registry_values(values: List[str]) -> List[RoutableObjectWithProvider """Utility function to parse registry values into RoutableObjectWithProvider objects.""" all_objects = [] for value in values: - try: - objects_data = json.loads(value) - objects = [ - pydantic.parse_obj_as( - RoutableObjectWithProvider, - json.loads(obj_str), - ) - for obj_str in objects_data - ] - all_objects.extend(objects) - except Exception as e: - print(f"Error parsing value: {e}") - traceback.print_exc() + obj = pydantic.parse_obj_as( + RoutableObjectWithProvider, + json.loads(value), + ) + all_objects.append(obj) return all_objects @@ -79,46 +75,49 @@ async def initialize(self) -> None: def get_cached( self, type: str, identifier: str - ) -> List[RoutableObjectWithProvider]: + ) -> Optional[RoutableObjectWithProvider]: # Disk registry does not have a cache - return [] + raise NotImplementedError("Disk registry does not have a cache") async def get_all(self) -> List[RoutableObjectWithProvider]: start_key, end_key = _get_registry_key_range() values = await self.kvstore.range(start_key, end_key) return _parse_registry_values(values) - async def get(self, type: str, identifier: str) -> List[RoutableObjectWithProvider]: + async def get( + self, type: str, identifier: str + ) -> Optional[RoutableObjectWithProvider]: json_str = await self.kvstore.get( KEY_FORMAT.format(type=type, identifier=identifier) ) if not json_str: - return [] + return None objects_data = json.loads(json_str) - return [ - pydantic.parse_obj_as( + # Return only the first object if any exist + if objects_data: + return pydantic.parse_obj_as( RoutableObjectWithProvider, - json.loads(obj_str), + json.loads(objects_data), ) - for obj_str in objects_data - ] + return None + + async def update(self, obj: RoutableObjectWithProvider) -> None: + await self.kvstore.set( + KEY_FORMAT.format(type=obj.type, identifier=obj.identifier), + obj.model_dump_json(), + ) + return obj async def register(self, obj: RoutableObjectWithProvider) -> bool: - existing_objects = await self.get(obj.type, obj.identifier) + existing_obj = await self.get(obj.type, obj.identifier) # dont register if the object's providerid already exists - for eobj in existing_objects: - if eobj.provider_id == obj.provider_id: - return False - - existing_objects.append(obj) + if existing_obj and existing_obj.provider_id == obj.provider_id: + return False - objects_json = [ - obj.model_dump_json() for obj in existing_objects - ] # Fixed variable name await self.kvstore.set( KEY_FORMAT.format(type=obj.type, identifier=obj.identifier), - json.dumps(objects_json), + obj.model_dump_json(), ) return True @@ -129,7 +128,7 @@ async def delete(self, type: str, identifier: str) -> None: class CachedDiskDistributionRegistry(DiskDistributionRegistry): def __init__(self, kvstore: KVStore): super().__init__(kvstore) - self.cache: Dict[Tuple[str, str], List[RoutableObjectWithProvider]] = {} + self.cache: Dict[Tuple[str, str], RoutableObjectWithProvider] = {} self._initialized = False self._initialize_lock = asyncio.Lock() self._cache_lock = asyncio.Lock() @@ -156,13 +155,7 @@ async def _ensure_initialized(self): async with self._locked_cache() as cache: for obj in objects: cache_key = (obj.type, obj.identifier) - if cache_key not in cache: - cache[cache_key] = [] - if not any( - cached_obj.provider_id == obj.provider_id - for cached_obj in cache[cache_key] - ): - cache[cache_key].append(obj) + cache[cache_key] = obj self._initialized = True @@ -171,28 +164,22 @@ async def initialize(self) -> None: def get_cached( self, type: str, identifier: str - ) -> List[RoutableObjectWithProvider]: - return self.cache.get((type, identifier), [])[:] # Return a copy + ) -> Optional[RoutableObjectWithProvider]: + return self.cache.get((type, identifier), None) async def get_all(self) -> List[RoutableObjectWithProvider]: await self._ensure_initialized() async with self._locked_cache() as cache: - return [item for sublist in cache.values() for item in sublist] + return list(cache.values()) - async def get(self, type: str, identifier: str) -> List[RoutableObjectWithProvider]: + async def get( + self, type: str, identifier: str + ) -> Optional[RoutableObjectWithProvider]: await self._ensure_initialized() cache_key = (type, identifier) async with self._locked_cache() as cache: - if cache_key in cache: - return cache[cache_key][:] - - objects = await super().get(type, identifier) - if objects: - async with self._locked_cache() as cache: - cache[cache_key] = objects - - return objects + return cache.get(cache_key, None) async def register(self, obj: RoutableObjectWithProvider) -> bool: await self._ensure_initialized() @@ -201,16 +188,17 @@ async def register(self, obj: RoutableObjectWithProvider) -> bool: if success: cache_key = (obj.type, obj.identifier) async with self._locked_cache() as cache: - if cache_key not in cache: - cache[cache_key] = [] - if not any( - cached_obj.provider_id == obj.provider_id - for cached_obj in cache[cache_key] - ): - cache[cache_key].append(obj) + cache[cache_key] = obj return success + async def update(self, obj: RoutableObjectWithProvider) -> None: + await super().update(obj) + cache_key = (obj.type, obj.identifier) + async with self._locked_cache() as cache: + cache[cache_key] = obj + return obj + async def delete(self, type: str, identifier: str) -> None: await super().delete(type, identifier) cache_key = (type, identifier) From 0ba11b82bed954ab59cec6e72ef0f4f1812d5fa5 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Wed, 13 Nov 2024 21:06:16 -0800 Subject: [PATCH 3/6] make update a POST --- docs/resources/llama-stack-spec.html | 36 ++++++++++++++-------------- docs/resources/llama-stack-spec.yaml | 20 ++++++++-------- llama_stack/apis/models/models.py | 2 +- 3 files changed, 29 insertions(+), 29 deletions(-) diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index 3cac939672..44554f2ff4 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -21,7 +21,7 @@ "info": { "title": "[DRAFT] Llama Stack Specification", "version": "0.0.1", - "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-13 18:16:59.065989" + "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-13 21:05:58.323310" }, "servers": [ { @@ -7965,58 +7965,58 @@ ], "tags": [ { - "name": "Shields" + "name": "Agents" }, { - "name": "ScoringFunctions" + "name": "DatasetIO" }, { - "name": "MemoryBanks" + "name": "Models" }, { - "name": "Datasets" + "name": "Inference" }, { - "name": "Agents" + "name": "BatchInference" }, { - "name": "DatasetIO" + "name": "Memory" }, { - "name": "Inference" + "name": "Safety" }, { "name": "Inspect" }, { - "name": "Memory" + "name": "EvalTasks" }, { - "name": "Models" + "name": "Scoring" }, { - "name": "PostTraining" + "name": "Datasets" }, { - "name": "Safety" + "name": "PostTraining" }, { - "name": "SyntheticDataGeneration" + "name": "Eval" }, { - "name": "EvalTasks" + "name": "Shields" }, { - "name": "Scoring" + "name": "Telemetry" }, { - "name": "BatchInference" + "name": "ScoringFunctions" }, { - "name": "Eval" + "name": "MemoryBanks" }, { - "name": "Telemetry" + "name": "SyntheticDataGeneration" }, { "name": "BuiltinTool", diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index 5d2b91d842..fc28405d76 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -3414,7 +3414,7 @@ info: description: "This is the specification of the llama stack that provides\n \ \ a set of endpoints and their corresponding interfaces that are tailored\ \ to\n best leverage Llama Models. The specification is still in\ - \ draft and subject to change.\n Generated at 2024-11-13 18:16:59.065989" + \ draft and subject to change.\n Generated at 2024-11-13 21:05:58.323310" title: '[DRAFT] Llama Stack Specification' version: 0.0.1 jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema @@ -4824,24 +4824,24 @@ security: servers: - url: http://any-hosted-llama-stack.com tags: -- name: Shields -- name: ScoringFunctions -- name: MemoryBanks -- name: Datasets - name: Agents - name: DatasetIO +- name: Models - name: Inference -- name: Inspect +- name: BatchInference - name: Memory -- name: Models -- name: PostTraining - name: Safety -- name: SyntheticDataGeneration +- name: Inspect - name: EvalTasks - name: Scoring -- name: BatchInference +- name: Datasets +- name: PostTraining - name: Eval +- name: Shields - name: Telemetry +- name: ScoringFunctions +- name: MemoryBanks +- name: SyntheticDataGeneration - description: name: BuiltinTool - description: Model: ... - @webmethod(route="/models/update", method="PUT") + @webmethod(route="/models/update", method="POST") async def update_model( self, model_id: str, From 89342d352c47a7a27f9bdbbd53610b513a71c9e3 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Wed, 13 Nov 2024 21:09:18 -0800 Subject: [PATCH 4/6] remove comment --- llama_stack/distribution/store/registry.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index d8a1a04e3c..b876ee7565 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -34,10 +34,6 @@ async def update( self, obj: RoutableObjectWithProvider ) -> RoutableObjectWithProvider: ... - # The current data structure allows multiple objects with the same identifier but different providers. - # This is not ideal - we should have a single object that can be served by multiple providers, - # suggesting a data structure like (obj: Obj, providers: List[str]) rather than List[RoutableObjectWithProvider]. - # The current approach could lead to inconsistencies if the same logical object has different data across providers. async def register(self, obj: RoutableObjectWithProvider) -> bool: ... async def delete(self, type: str, identifier: str) -> None: ... From 43af05d851f36d08ee4ae3b8162ef37e062c7a02 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Wed, 13 Nov 2024 21:36:24 -0800 Subject: [PATCH 5/6] add tests --- .../inference/test_model_registration.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/llama_stack/providers/tests/inference/test_model_registration.py b/llama_stack/providers/tests/inference/test_model_registration.py index 4b20e519c4..f2fa4bb9a5 100644 --- a/llama_stack/providers/tests/inference/test_model_registration.py +++ b/llama_stack/providers/tests/inference/test_model_registration.py @@ -33,3 +33,22 @@ async def test_register_nonexistent_model(self, inference_stack): await models_impl.register_model( model_id="Llama3-NonExistent-Model", ) + + @pytest.mark.asyncio + async def test_update_model(self, inference_stack): + _, models_impl = inference_stack + + # Register a model to update + model_id = "Llama3.1-8B-Instruct" + await models_impl.register_model(model_id=model_id) + + # Update the model + new_provider_id = "updated_provider" + await models_impl.update_model(model_id=model_id, provider_id=new_provider_id) + + # Retrieve the updated model to verify changes + updated_model = await models_impl.get_model(model_id) + assert updated_model.provider_id == new_provider_id + + # Cleanup + await models_impl.delete_model(model_id=model_id) From 05535698e2eedd704a32a140f62611779ba8c3a8 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Wed, 13 Nov 2024 21:54:43 -0800 Subject: [PATCH 6/6] fix tests --- .../tests/inference/test_model_registration.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/llama_stack/providers/tests/inference/test_model_registration.py b/llama_stack/providers/tests/inference/test_model_registration.py index f2fa4bb9a5..97f0ac5761 100644 --- a/llama_stack/providers/tests/inference/test_model_registration.py +++ b/llama_stack/providers/tests/inference/test_model_registration.py @@ -6,6 +6,8 @@ import pytest +from llama_models.datatypes import CoreModelId + # How to run this test: # # pytest -v -s llama_stack/providers/tests/inference/test_model_registration.py @@ -39,16 +41,17 @@ async def test_update_model(self, inference_stack): _, models_impl = inference_stack # Register a model to update - model_id = "Llama3.1-8B-Instruct" - await models_impl.register_model(model_id=model_id) + model_id = CoreModelId.llama3_1_8b_instruct.value + old_model = await models_impl.register_model(model_id=model_id) # Update the model - new_provider_id = "updated_provider" - await models_impl.update_model(model_id=model_id, provider_id=new_provider_id) + new_model_id = CoreModelId.llama3_2_3b_instruct.value + updated_model = await models_impl.update_model( + model_id=model_id, provider_model_id=new_model_id + ) # Retrieve the updated model to verify changes - updated_model = await models_impl.get_model(model_id) - assert updated_model.provider_id == new_provider_id + assert updated_model.provider_resource_id != old_model.provider_resource_id # Cleanup await models_impl.delete_model(model_id=model_id)