Skip to content

Commit

Permalink
delete episodes
Browse files Browse the repository at this point in the history
  • Loading branch information
prasmussen15 committed Feb 5, 2025
1 parent 0102c05 commit a577efe
Showing 1 changed file with 22 additions and 2 deletions.
24 changes: 22 additions & 2 deletions graphiti_core/graphiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import logging
from datetime import datetime
from time import time
from typing import LiteralString

Check failure on line 20 in graphiti_core/graphiti.py

View workflow job for this annotation

GitHub Actions / mypy

attr-defined

Module "typing" has no attribute "LiteralString"

Check notice on line 20 in graphiti_core/graphiti.py

View workflow job for this annotation

GitHub Actions / mypy

Note

Use `from typing_extensions import LiteralString` instead

Check notice on line 20 in graphiti_core/graphiti.py

View workflow job for this annotation

GitHub Actions / mypy

Note

See https://mypy.readthedocs.io/en/stable/runtime_troubles.html#using-new-additions-to-the-typing-module

from dotenv import load_dotenv
from neo4j import AsyncGraphDatabase
Expand Down Expand Up @@ -749,13 +750,32 @@ async def add_triplet(self, source_node: EntityNode, edge: EntityEdge, target_no
)

async def remove_episode(self, episode_uuid: str):
# find the episode to be deleted
# Find the episode to be deleted
episode = await EpisodicNode.get_by_uuid(self.driver, episode_uuid)

# find the edges mentioned by the episode
# Find edges mentioned by the episode
edges = await EntityEdge.get_by_uuids(self.driver, episode.edge_uuids)

# We should only delete edges created by the episode
edges_to_delete: list[EntityEdge] = []
for edge in edges:
if edge.episodes[0] == episode.uuid:
edges_to_delete.append(edge)

# Find nodes mentioned by the episode
nodes = await get_mentioned_nodes(self.driver, episode)
# We should delete all node that are only mentioned in the deleted episode
nodes_to_delete: list[EntityNode] = []
for node in nodes:
query: LiteralString = 'MATCH (e:Episodic)-[:MENTIONS]->(n:Entity {uuid: $uuid}) RETURN count(*) AS episode_count'
records, _, _ = await self.driver.execute_query(
query, uuid=node.uuid, database_=DEFAULT_DATABASE, routing_='r'
)

for record in records:
if record['episode_count'] == 1:
nodes_to_delete.append(node)

await semaphore_gather(*[node.delete(self.driver) for node in nodes_to_delete])
await semaphore_gather(*[edge.delete(self.driver) for edge in edges_to_delete])
await episode.delete(self.driver)

0 comments on commit a577efe

Please sign in to comment.