From b537cf56e5d647a15b96ca3ec441757fed5bbbf0 Mon Sep 17 00:00:00 2001 From: Pavlo Paliychuk Date: Tue, 24 Sep 2024 20:08:09 -0400 Subject: [PATCH] chore: Make deleting groups safer (#155) * chore: Make deleting groups safer * chore: Use appropriate errors in delete group checks * chore: Add GroupsEdgesNotFound error type --- graphiti_core/edges.py | 8 +++---- graphiti_core/errors.py | 8 +++++++ pyproject.toml | 2 +- server/graph_service/zep_graphiti.py | 33 +++++++++++++++++----------- 4 files changed, 32 insertions(+), 19 deletions(-) diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py index 142c0381..2b326a1c 100644 --- a/graphiti_core/edges.py +++ b/graphiti_core/edges.py @@ -24,7 +24,7 @@ from neo4j import AsyncDriver from pydantic import BaseModel, Field -from graphiti_core.errors import EdgeNotFoundError +from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError from graphiti_core.helpers import parse_db_date from graphiti_core.llm_client.config import EMBEDDING_DIM from graphiti_core.nodes import Node @@ -147,10 +147,9 @@ async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]): ) edges = [get_episodic_edge_from_record(record) for record in records] - uuids = [edge.uuid for edge in edges] if len(edges) == 0: - raise EdgeNotFoundError(uuids[0]) + raise GroupsEdgesNotFoundError(group_ids) return edges @@ -293,10 +292,9 @@ async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]): ) edges = [get_entity_edge_from_record(record) for record in records] - uuids = [edge.uuid for edge in edges] if len(edges) == 0: - raise EdgeNotFoundError(uuids[0]) + raise GroupsEdgesNotFoundError(group_ids) return edges diff --git a/graphiti_core/errors.py b/graphiti_core/errors.py index 84737419..58a8ee3b 100644 --- a/graphiti_core/errors.py +++ b/graphiti_core/errors.py @@ -27,6 +27,14 @@ def __init__(self, uuid: str): super().__init__(self.message) +class GroupsEdgesNotFoundError(GraphitiError): + """Raised when no edges are found for a list of group ids.""" + + def __init__(self, group_ids: list[str]): + self.message = f'no edges found for group ids {group_ids}' + super().__init__(self.message) + + class NodeNotFoundError(GraphitiError): """Raised when a node is not found.""" diff --git a/pyproject.toml b/pyproject.toml index 304b397a..c8ccfcaf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "graphiti-core" -version = "0.3.5" +version = "0.3.6" description = "A temporal graph building library" authors = [ "Paul Paliychuk ", diff --git a/server/graph_service/zep_graphiti.py b/server/graph_service/zep_graphiti.py index 66457130..4a901d61 100644 --- a/server/graph_service/zep_graphiti.py +++ b/server/graph_service/zep_graphiti.py @@ -1,15 +1,18 @@ +import logging from typing import Annotated from fastapi import Depends, HTTPException from graphiti_core import Graphiti # type: ignore from graphiti_core.edges import EntityEdge # type: ignore -from graphiti_core.errors import EdgeNotFoundError, NodeNotFoundError # type: ignore +from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError, NodeNotFoundError from graphiti_core.llm_client import LLMClient # type: ignore from graphiti_core.nodes import EntityNode, EpisodicNode # type: ignore from graph_service.config import ZepEnvDep from graph_service.dto import FactResult +logger = logging.getLogger(__name__) + class ZepGraphiti(Graphiti): def __init__(self, uri: str, user: str, password: str, llm_client: LLMClient | None = None): @@ -36,18 +39,22 @@ async def get_entity_edge(self, uuid: str): async def delete_group(self, group_id: str): try: edges = await EntityEdge.get_by_group_ids(self.driver, [group_id]) - nodes = await EntityNode.get_by_group_ids(self.driver, [group_id]) - episodes = await EpisodicNode.get_by_group_ids(self.driver, [group_id]) - for edge in edges: - await edge.delete(self.driver) - for node in nodes: - await node.delete(self.driver) - for episode in episodes: - await episode.delete(self.driver) - except EdgeNotFoundError as e: - raise HTTPException(status_code=404, detail=e.message) from e - except NodeNotFoundError as e: - raise HTTPException(status_code=404, detail=e.message) from e + except GroupsEdgesNotFoundError: + logger.warning(f'No edges found for group {group_id}') + edges = [] + + nodes = await EntityNode.get_by_group_ids(self.driver, [group_id]) + + episodes = await EpisodicNode.get_by_group_ids(self.driver, [group_id]) + + for edge in edges: + await edge.delete(self.driver) + + for node in nodes: + await node.delete(self.driver) + + for episode in episodes: + await episode.delete(self.driver) async def delete_entity_edge(self, uuid: str): try: