diff --git a/core/cat/auth/connection.py b/core/cat/auth/connection.py index 702a15ae1..9217abe99 100644 --- a/core/cat/auth/connection.py +++ b/core/cat/auth/connection.py @@ -41,7 +41,7 @@ async def __call__( # get protocol from Starlette request protocol = connection.scope.get('type') # extract credentials (user_id, token_or_key) from connection - user_id, credential = await self.extract_credentials(connection) + user_id, credential = self.extract_credentials(connection) auth_handlers = [ # try to get user from local idp connection.app.state.ccat.core_auth_handler, @@ -49,7 +49,7 @@ async def __call__( connection.app.state.ccat.custom_auth_handler, ] for ah in auth_handlers: - user: AuthUserInfo = await ah.authorize_user_from_credential( + user: AuthUserInfo = ah.authorize_user_from_credential( protocol, credential, self.resource, self.permission, user_id=user_id ) if user: @@ -59,7 +59,7 @@ async def __call__( self.not_allowed(connection) @abstractmethod - async def extract_credentials(self, connection: Request | WebSocket) -> Tuple[str] | None: + def extract_credentials(self, connection: Request | WebSocket) -> Tuple[str] | None: pass @abstractmethod @@ -73,7 +73,7 @@ def not_allowed(self, connection: Request | WebSocket): class HTTPAuth(ConnectionAuth): - async def extract_credentials(self, connection: Request) -> Tuple[str, str] | None: + def extract_credentials(self, connection: Request) -> Tuple[str, str] | None: """ Extract user_id and token/key from headers """ @@ -121,7 +121,7 @@ def not_allowed(self, connection: Request): class WebSocketAuth(ConnectionAuth): - async def extract_credentials(self, connection: WebSocket) -> Tuple[str, str] | None: + def extract_credentials(self, connection: WebSocket) -> Tuple[str, str] | None: """ Extract user_id from WebSocket path params Extract token from WebSocket query string @@ -166,7 +166,7 @@ def not_allowed(self, connection: WebSocket): class CoreFrontendAuth(HTTPAuth): - async def extract_credentials(self, 
connection: Request) -> Tuple[str, str] | None: + def extract_credentials(self, connection: Request) -> Tuple[str, str] | None: """ Extract user_id from cookie """ diff --git a/core/cat/factory/custom_auth_handler.py b/core/cat/factory/custom_auth_handler.py index ee0ed7570..6d3106379 100644 --- a/core/cat/factory/custom_auth_handler.py +++ b/core/cat/factory/custom_auth_handler.py @@ -20,7 +20,7 @@ class BaseAuthHandler(ABC): # TODOAUTH: pydantic model? MUST be implemented by subclasses. """ - async def authorize_user_from_credential( + def authorize_user_from_credential( self, protocol: Literal["http", "websocket"], credential: str, @@ -32,17 +32,17 @@ async def authorize_user_from_credential( ) -> AuthUserInfo | None: if is_jwt(credential): # JSON Web Token auth - return await self.authorize_user_from_jwt( + return self.authorize_user_from_jwt( credential, auth_resource, auth_permission ) else: # API_KEY auth - return await self.authorize_user_from_key( + return self.authorize_user_from_key( protocol, user_id, credential, auth_resource, auth_permission ) @abstractmethod - async def authorize_user_from_jwt( + def authorize_user_from_jwt( self, token: str, auth_resource: AuthResource, @@ -52,7 +52,7 @@ async def authorize_user_from_jwt( pass @abstractmethod - async def authorize_user_from_key( + def authorize_user_from_key( self, protocol: Literal["http", "websocket"], user_id: str, @@ -67,7 +67,7 @@ async def authorize_user_from_key( # Core auth handler, verify token on local idp class CoreAuthHandler(BaseAuthHandler): - async def authorize_user_from_jwt( + def authorize_user_from_jwt( self, token: str, auth_resource: AuthResource, auth_permission: AuthPermission ) -> AuthUserInfo | None: try: @@ -98,7 +98,7 @@ async def authorize_user_from_jwt( # do not pass return None - async def authorize_user_from_key( + def authorize_user_from_key( self, protocol: Literal["http", "websocket"], user_id: str, @@ -147,7 +147,7 @@ def _authorize_websocket_key(self, user_id: 
str, api_key: str, ws_key: str) -> A # No match -> deny access return None - async def issue_jwt(self, username: str, password: str) -> str | None: + def issue_jwt(self, username: str, password: str) -> str | None: # authenticate local user credentials and return a JWT token # brutal search over users, which are stored in a simple dictionary. @@ -178,10 +178,10 @@ async def issue_jwt(self, username: str, password: str) -> str | None: # Default Auth, always deny auth by default (only core auth decides). class CoreOnlyAuthHandler(BaseAuthHandler): - async def authorize_user_from_jwt(*args, **kwargs) -> AuthUserInfo | None: + def authorize_user_from_jwt(*args, **kwargs) -> AuthUserInfo | None: return None - async def authorize_user_from_key(*args, **kwargs) -> AuthUserInfo | None: + def authorize_user_from_key(*args, **kwargs) -> AuthUserInfo | None: return None diff --git a/core/cat/routes/auth.py b/core/cat/routes/auth.py index b6b32e417..8c8c4f32a 100644 --- a/core/cat/routes/auth.py +++ b/core/cat/routes/auth.py @@ -29,7 +29,7 @@ async def core_login_token(request: Request, response: Response): # use username and password to authenticate user from local identity provider and get token auth_handler = request.app.state.ccat.core_auth_handler - access_token = await auth_handler.issue_jwt( + access_token = auth_handler.issue_jwt( form_data["username"], form_data["password"] ) @@ -95,7 +95,7 @@ async def auth_token(request: Request, credentials: UserCredentials): # use username and password to authenticate user from local identity provider and get token auth_handler = request.app.state.ccat.core_auth_handler - access_token = await auth_handler.issue_jwt( + access_token = auth_handler.issue_jwt( credentials.username, credentials.password ) diff --git a/core/cat/routes/memory/points.py b/core/cat/routes/memory/points.py index 8ad65f563..6986cbe2d 100644 --- a/core/cat/routes/memory/points.py +++ b/core/cat/routes/memory/points.py @@ -1,13 +1,13 @@ from typing import 
# POST memories from recall
@router.post("/recall")
async def recall_memory_points(
    request: Request,
    text: str = Body(description="Find memories similar to this text."),
    k: int = Body(default=100, description="How many memories to return."),
    metadata: Dict = Body(
        default={},
        # NOTE: the trailing space keeps the two sentences apart once the
        # adjacent literals are concatenated.
        description="Flat dictionary where each key-value pair represents a filter. "
        "The memory points returned will match the specified metadata criteria.",
    ),
    stray: StrayCat = Depends(HTTPAuth(AuthResource.MEMORY, AuthPermission.READ)),
) -> Dict:
    """Search k memories similar to given text with specified metadata criteria.

    The text is embedded once and every vector collection is queried with the
    resulting vector. The `metadata` filter is applied to all collections,
    except for its `source` key, which is handled per collection:

    - for the `episodic` collection `source` is always set to the `user_id`
      of the requesting stray;
    - for every other collection any `source` key is dropped.

    The response contains the query (text and vector), the embedder class name
    and, for each collection, the recalled points with their `id`, `score`
    and `vector`.

    Example
    ----------
    ```
    collection = "episodic"
    content = "MIAO!"
    metadata = {"custom_key": "custom_value"}
    req_json = {
        "content": content,
        "metadata": metadata,
    }
    # create a point
    res = requests.post(
        f"http://localhost:1865/memory/collections/{collection}/points", json=req_json
    )

    # recall with metadata
    req_json = {
        "text": "CAT",
        "metadata": {"custom_key": "custom_value"}
    }
    res = requests.post(
        "http://localhost:1865/memory/recall", json=req_json
    )
    json = res.json()
    print(json)
    ```

    """

    # Embed the query to plot it in the Memory page
    query_embedding = stray.embedder.embed_query(text)
    query = {
        "text": text,
        "vector": query_embedding,
    }

    # Loop-invariant: the same user applies to every collection.
    user_id = stray.user_id

    # Caller-supplied filters without "source"; that key is decided per
    # collection below. Working on copies leaves the `metadata` argument (and
    # the shared `{}` default object) untouched.
    base_filter = {key: value for key, value in metadata.items() if key != "source"}

    # Loop over collections and retrieve nearby memories
    recalled = {}
    for collection_name, collection in stray.memory.vectors.collections.items():
        # only episodic collection has users
        if collection_name == "episodic":
            collection_filter = {**base_filter, "source": user_id}
        else:
            collection_filter = dict(base_filter)

        memories = collection.recall_memories_from_embedding(
            query_embedding, k=k, metadata=collection_filter
        )

        points = []
        for document, score, vector, point_id in memories:
            memory_dict = dict(document)
            memory_dict.pop("lc_kwargs", None)  # langchain stuff, not needed
            memory_dict["id"] = point_id
            memory_dict["score"] = float(score)
            memory_dict["vector"] = vector
            points.append(memory_dict)
        recalled[collection_name] = points

    return {
        "query": query,
        "vectors": {
            "embedder": str(
                stray.embedder.__class__.__name__
            ),  # TODO: should be the config class name
            "collections": recalled,
        },
    }
json["detail"]["error"] == "Invalid Credentials" -@pytest.mark.asyncio # to test async functions -async def test_issue_jwt(client): +def test_issue_jwt(client): creds = { "username": "admin", "password": "admin" @@ -49,7 +48,7 @@ async def test_issue_jwt(client): # is the JWT correct for core auth handler? auth_handler = client.app.state.ccat.core_auth_handler - user_info = await auth_handler.authorize_user_from_jwt( + user_info = auth_handler.authorize_user_from_jwt( received_token, AuthResource.LLM, AuthPermission.WRITE ) assert len(user_info.id) == 36 and len(user_info.id.split("-")) == 5 # uuid4 @@ -70,8 +69,7 @@ async def test_issue_jwt(client): assert False -@pytest.mark.asyncio -async def test_issue_jwt_for_new_user(client): +def test_issue_jwt_for_new_user(client): # create new user creds = { diff --git a/core/tests/routes/memory/test_memory_recall.py b/core/tests/routes/memory/test_memory_recall.py index f37df52e3..dd3860e5b 100644 --- a/core/tests/routes/memory/test_memory_recall.py +++ b/core/tests/routes/memory/test_memory_recall.py @@ -4,7 +4,7 @@ # search on default startup memory def test_memory_recall_default_success(client): params = {"text": "Red Queen"} - response = client.get("/memory/recall/", params=params) + response = client.post("/memory/recall/", json=params) json = response.json() assert response.status_code == 200 @@ -30,7 +30,7 @@ def test_memory_recall_default_success(client): # search without query should throw error def test_memory_recall_without_query_error(client): - response = client.get("/memory/recall") + response = client.post("/memory/recall") assert response.status_code == 400 @@ -42,7 +42,7 @@ def test_memory_recall_success(client): # recall params = {"text": "Red Queen"} - response = client.get("/memory/recall/", params=params) + response = client.post("/memory/recall/", json=params) json = response.json() assert response.status_code == 200 episodic_memories = json["vectors"]["collections"]["episodic"] @@ -58,8 +58,51 @@ 
# search with query and metadata
def test_memory_recall_with_metadata(client):
    """Check that POST /memory/recall/ filters episodic memories by metadata.

    Three points are inserted into the episodic collection, then recalled
    first with a single-key filter and then with a two-key filter.
    """

    def recall_episodic(metadata_filter):
        # Issue one recall request and hand back the episodic memories.
        response = client.post(
            "/memory/recall/",
            json={"text": "MIAO", "metadata": metadata_filter},
        )
        payload = response.json()
        assert response.status_code == 200
        return payload["vectors"]["collections"]["episodic"]

    # (content, metadata) for each point to store
    points = (
        ("MIAO_1", {"key_1": "v1", "key_2": "v2"}),
        ("MIAO_2", {"key_1": "v1", "key_2": "v3"}),
        ("MIAO_3", {}),
    )

    # insert the new points with their metadata
    for content, point_metadata in points:
        client.post(
            "/memory/collections/episodic/points",
            json={"content": content, "metadata": point_metadata},
        )

    # recall with a single metadata key: MIAO_1 and MIAO_2 share key_1 == v1
    matches = recall_episodic({"key_1": "v1"})
    assert len(matches) == 2

    # recall with multiple metadata keys: only MIAO_1 matches both
    matches = recall_episodic({"key_1": "v1", "key_2": "v2"})
    assert len(matches) == 1
    assert matches[0]["page_content"] == "MIAO_1"