Add group ids (#89)
* set and retrieve group ids

* update add episode with group id support

* add episode and search functional

* update bulk

* mypy updates

* remove unused imports

* update unit tests

* unit tests

* add optional uuid field

* format

* mypy

* ellipsis
prasmussen15 authored Sep 6, 2024
1 parent c7fc057 commit 42fb590
Showing 15 changed files with 329 additions and 356 deletions.
1 change: 1 addition & 0 deletions examples/podcast/podcast_runner.py
@@ -69,6 +69,7 @@ async def main(use_bulk: bool = True):
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
reference_time=message.actual_timestamp,
source_description='Podcast Transcript',
group_id='1',
)
return

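For context, a minimal sketch (not part of the commit) of what the updated runner does with the new parameter: every episode ingested with group_id='1' lands in the same graph partition, so later queries can be scoped to it. The connection details and transcript line here are assumptions.

import asyncio
from datetime import datetime, timezone

from graphiti_core import Graphiti
from graphiti_core.nodes import EpisodeType

async def main() -> None:
    # Assumed local Neo4j credentials; adjust for your environment.
    client = Graphiti('bolt://localhost:7687', 'neo4j', 'password')
    await client.add_episode(
        name='podcast-message-0',
        episode_body='Host (speaker): Welcome back to the show.',
        reference_time=datetime.now(timezone.utc),
        source_description='Podcast Transcript',
        source=EpisodeType.message,
        group_id='1',  # partitions this episode's nodes and edges
    )

asyncio.run(main())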
71 changes: 39 additions & 32 deletions graphiti_core/edges.py
@@ -18,6 +18,7 @@
from abc import ABC, abstractmethod
from datetime import datetime
from time import time
from typing import Any
from uuid import uuid4

from neo4j import AsyncDriver
@@ -32,6 +33,7 @@

class Edge(BaseModel, ABC):
uuid: str = Field(default_factory=lambda: uuid4().hex)
group_id: str | None = Field(description='partition of the graph')
source_node_uuid: str
target_node_uuid: str
created_at: datetime
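Note that the new group_id field is declared without a default, so Pydantic treats it as required: callers must pass it explicitly, even if only as None. A hypothetical construction:

from datetime import datetime

from graphiti_core.edges import EpisodicEdge

# group_id has no default, so it must be supplied; None opts out of
# partitioning, any string selects a partition.
edge = EpisodicEdge(
    source_node_uuid='episode-uuid',
    target_node_uuid='entity-uuid',
    created_at=datetime.now(),
    group_id=None,
)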
@@ -61,11 +63,12 @@ async def save(self, driver: AsyncDriver):
MATCH (episode:Episodic {uuid: $episode_uuid})
MATCH (node:Entity {uuid: $entity_uuid})
MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node)
SET r = {uuid: $uuid, created_at: $created_at}
SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
RETURN r.uuid AS uuid""",
episode_uuid=self.source_node_uuid,
entity_uuid=self.target_node_uuid,
uuid=self.uuid,
group_id=self.group_id,
created_at=self.created_at,
)

@@ -92,25 +95,16 @@ async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
"""
MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
RETURN
e.uuid AS uuid,
e.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
e.created_at AS created_at
""",
uuid=uuid,
)

edges: list[EpisodicEdge] = []

for record in records:
edges.append(
EpisodicEdge(
uuid=record['uuid'],
source_node_uuid=record['source_node_uuid'],
target_node_uuid=record['target_node_uuid'],
created_at=record['created_at'].to_native(),
)
)
edges = [get_episodic_edge_from_record(record) for record in records]

logger.info(f'Found Edge: {uuid}')

@@ -153,14 +147,15 @@ async def save(self, driver: AsyncDriver):
MATCH (source:Entity {uuid: $source_uuid})
MATCH (target:Entity {uuid: $target_uuid})
MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
SET r = {uuid: $uuid, name: $name, fact: $fact, fact_embedding: $fact_embedding,
SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, fact_embedding: $fact_embedding,
episodes: $episodes, created_at: $created_at, expired_at: $expired_at,
valid_at: $valid_at, invalid_at: $invalid_at}
RETURN r.uuid AS uuid""",
source_uuid=self.source_node_uuid,
target_uuid=self.target_node_uuid,
uuid=self.uuid,
name=self.name,
group_id=self.group_id,
fact=self.fact,
fact_embedding=self.fact_embedding,
episodes=self.episodes,
Expand Down Expand Up @@ -198,6 +193,7 @@ async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
m.uuid AS target_node_uuid,
e.created_at AS created_at,
e.name AS name,
e.group_id AS group_id,
e.fact AS fact,
e.fact_embedding AS fact_embedding,
e.episodes AS episodes,
@@ -208,25 +204,36 @@ async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
uuid=uuid,
)

edges: list[EntityEdge] = []

for record in records:
edges.append(
EntityEdge(
uuid=record['uuid'],
source_node_uuid=record['source_node_uuid'],
target_node_uuid=record['target_node_uuid'],
fact=record['fact'],
name=record['name'],
episodes=record['episodes'],
fact_embedding=record['fact_embedding'],
created_at=record['created_at'].to_native(),
expired_at=parse_db_date(record['expired_at']),
valid_at=parse_db_date(record['valid_at']),
invalid_at=parse_db_date(record['invalid_at']),
)
)
edges = [get_entity_edge_from_record(record) for record in records]

logger.info(f'Found Edge: {uuid}')

return edges[0]


# Edge helpers
def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
return EpisodicEdge(
uuid=record['uuid'],
group_id=record['group_id'],
source_node_uuid=record['source_node_uuid'],
target_node_uuid=record['target_node_uuid'],
created_at=record['created_at'].to_native(),
)


def get_entity_edge_from_record(record: Any) -> EntityEdge:
return EntityEdge(
uuid=record['uuid'],
source_node_uuid=record['source_node_uuid'],
target_node_uuid=record['target_node_uuid'],
fact=record['fact'],
name=record['name'],
group_id=record['group_id'],
episodes=record['episodes'],
fact_embedding=record['fact_embedding'],
created_at=record['created_at'].to_native(),
expired_at=parse_db_date(record['expired_at']),
valid_at=parse_db_date(record['valid_at']),
invalid_at=parse_db_date(record['invalid_at']),
)
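The two helpers above replace the hand-rolled loops in the get_by_uuid methods. A rough sketch of what they expect; the record dict below is hypothetical and mirrors the RETURN clauses of the Cypher queries in this file:

from neo4j.time import DateTime

from graphiti_core.edges import get_episodic_edge_from_record

record = {
    'uuid': 'edge-uuid',
    'group_id': '1',
    'source_node_uuid': 'episode-uuid',
    'target_node_uuid': 'entity-uuid',
    'created_at': DateTime(2024, 9, 6, 12, 0, 0),  # driver type; to_native() converts it
}
edge = get_episodic_edge_from_record(record)
assert edge.group_id == '1'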
75 changes: 45 additions & 30 deletions graphiti_core/graphiti.py
@@ -18,7 +18,6 @@
import logging
from datetime import datetime
from time import time
from typing import Callable

from dotenv import load_dotenv
from neo4j import AsyncGraphDatabase
@@ -120,7 +119,7 @@ def close(self):
Parameters
----------
None
self
Returns
-------
@@ -151,7 +150,7 @@ async def build_indices_and_constraints(self):
Parameters
----------
None
self
Returns
-------
@@ -178,6 +177,7 @@ async def retrieve_episodes(
self,
reference_time: datetime,
last_n: int = EPISODE_WINDOW_LEN,
group_ids: list[str | None] | None = None,
) -> list[EpisodicNode]:
"""
Retrieve the last n episodic nodes from the graph.
@@ -191,6 +191,8 @@
The reference time to retrieve episodes before.
last_n : int, optional
The number of episodes to retrieve. Defaults to EPISODE_WINDOW_LEN.
group_ids : list[str | None], optional
The group ids to return data from.
Returns
-------
@@ -202,7 +204,7 @@
The actual retrieval is performed by the `retrieve_episodes` function
from the `graphiti_core.utils` module.
"""
return await retrieve_episodes(self.driver, reference_time, last_n)
return await retrieve_episodes(self.driver, reference_time, last_n, group_ids)
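A hedged usage sketch of the widened signature; note that group_ids is a list even when scoping to a single partition:

from datetime import datetime, timezone

from graphiti_core import Graphiti

async def recent_episodes(client: Graphiti) -> None:
    episodes = await client.retrieve_episodes(
        reference_time=datetime.now(timezone.utc),
        last_n=3,
        group_ids=['1'],  # a list, even for one partition
    )
    for episode in episodes:
        print(episode.name, episode.valid_at)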

async def add_episode(
self,
@@ -211,8 +213,8 @@ async def add_episode(
source_description: str,
reference_time: datetime,
source: EpisodeType = EpisodeType.message,
success_callback: Callable | None = None,
error_callback: Callable | None = None,
group_id: str | None = None,
uuid: str | None = None,
):
"""
Process an episode and update the graph.
@@ -232,10 +234,10 @@
The reference time for the episode.
source : EpisodeType, optional
The type of the episode. Defaults to EpisodeType.message.
success_callback : Callable | None, optional
A callback function to be called upon successful processing.
error_callback : Callable | None, optional
A callback function to be called if an error occurs during processing.
group_id : str | None
An id for the graph partition the episode is a part of.
uuid : str | None
Optional uuid of the episode.
Returns
-------
@@ -266,16 +268,20 @@ async def add_episode_endpoint(episode_data: EpisodeData):
embedder = self.llm_client.get_embedder()
now = datetime.now()

previous_episodes = await self.retrieve_episodes(reference_time, last_n=3)
previous_episodes = await self.retrieve_episodes(
reference_time, last_n=3, group_ids=[group_id]
)
episode = EpisodicNode(
name=name,
group_id=group_id,
labels=[],
source=source,
content=episode_body,
source_description=source_description,
created_at=now,
valid_at=reference_time,
)
episode.uuid = uuid if uuid is not None else episode.uuid

# Extract entities as nodes

@@ -299,7 +305,9 @@

(mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather(
resolve_extracted_nodes(self.llm_client, extracted_nodes, existing_nodes_lists),
extract_edges(self.llm_client, episode, extracted_nodes, previous_episodes),
extract_edges(
self.llm_client, episode, extracted_nodes, previous_episodes, group_id
),
)
logger.info(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
nodes.extend(mentioned_nodes)
@@ -388,11 +396,7 @@ async def add_episode_endpoint(episode_data: EpisodeData):

logger.info(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')

episodic_edges: list[EpisodicEdge] = build_episodic_edges(
mentioned_nodes,
episode,
now,
)
episodic_edges: list[EpisodicEdge] = build_episodic_edges(mentioned_nodes, episode, now)

logger.info(f'Built episodic edges: {episodic_edges}')

@@ -405,18 +409,10 @@ async def add_episode_endpoint(episode_data: EpisodeData):
end = time()
logger.info(f'Completed add_episode in {(end - start) * 1000} ms')

if success_callback:
await success_callback(episode)
except Exception as e:
if error_callback:
await error_callback(episode, e)
else:
raise e
raise e
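The new optional uuid parameter lets a caller pin an episode's identity up front. One plausible use, deriving a stable uuid from an external message id so that re-ingestion targets the same episode, is an assumption of this sketch rather than something the commit prescribes:

from datetime import datetime, timezone
from uuid import NAMESPACE_URL, uuid5

from graphiti_core import Graphiti

async def ingest_message(client: Graphiti) -> None:
    # Hypothetical external id; uuid5 makes the episode uuid deterministic.
    stable_uuid = uuid5(NAMESPACE_URL, 'podcast/episode-12/message-3').hex
    await client.add_episode(
        name='podcast-12-3',
        episode_body='Speaker (host): ...',
        source_description='Podcast Transcript',
        reference_time=datetime.now(timezone.utc),
        group_id='1',
        uuid=stable_uuid,
    )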

async def add_episode_bulk(
self,
bulk_episodes: list[RawEpisode],
):
async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str | None):
"""
Process multiple episodes in bulk and update the graph.
@@ -427,6 +423,8 @@
----------
bulk_episodes : list[RawEpisode]
A list of RawEpisode objects to be processed and added to the graph.
group_id : str | None
An id for the graph partition the episode is a part of.
Returns
-------
@@ -463,6 +461,7 @@ async def add_episode_bulk(
source=episode.source,
content=episode.content,
source_description=episode.source_description,
group_id=group_id,
created_at=now,
valid_at=episode.reference_time,
)
@@ -527,7 +526,13 @@ async def add_episode_bulk(
except Exception as e:
raise e
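A sketch of the bulk path with the new group_id argument, which has no default and so must always be passed. The RawEpisode import location and field names are inferred from the surrounding diff, not confirmed by it:

from datetime import datetime, timezone

from graphiti_core import Graphiti
from graphiti_core.nodes import EpisodeType
from graphiti_core.utils.bulk_utils import RawEpisode  # location assumed

async def ingest_bulk(client: Graphiti) -> None:
    episodes = [
        RawEpisode(
            name=f'podcast-message-{i}',
            content=f'transcript line {i}',
            source_description='Podcast Transcript',
            source=EpisodeType.message,
            reference_time=datetime.now(timezone.utc),
        )
        for i in range(3)
    ]
    await client.add_episode_bulk(episodes, group_id='1')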

async def search(self, query: str, center_node_uuid: str | None = None, num_results=10):
async def search(
self,
query: str,
center_node_uuid: str | None = None,
group_ids: list[str | None] | None = None,
num_results=10,
):
"""
Perform a hybrid search on the knowledge graph.
@@ -540,6 +545,8 @@ async def search(self, query: str, center_node_uuid: str | None = None, num_results=10):
The search query string.
center_node_uuid: str, optional
Facts will be reranked based on proximity to this node
group_ids : list[str | None] | None, optional
The graph partitions to return data from.
num_results : int, optional
The maximum number of results to return. Defaults to 10.
@@ -562,6 +569,7 @@
num_episodes=0,
num_edges=num_results,
num_nodes=0,
group_ids=group_ids,
search_methods=[SearchMethod.bm25, SearchMethod.cosine_similarity],
reranker=reranker,
)
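A usage sketch of the scoped hybrid search; it assumes search returns the matching entity edges (facts), as the edge-oriented SearchConfig above suggests, and the query string is hypothetical:

from graphiti_core import Graphiti

async def scoped_search(client: Graphiti) -> None:
    results = await client.search(
        'What did the guest say about housing?',
        group_ids=['1'],  # only facts from partition '1'
        num_results=10,
    )
    for edge in results:
        print(edge.fact)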
@@ -590,7 +598,10 @@ async def _search(
)

async def get_nodes_by_query(
self, query: str, limit: int = RELEVANT_SCHEMA_LIMIT
self,
query: str,
group_ids: list[str | None] | None = None,
limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
"""
Retrieve nodes from the graph database based on a text query.
@@ -602,6 +613,8 @@ async def get_nodes_by_query(
----------
query : str
The text query to search for in the graph.
group_ids : list[str | None] | None, optional
The graph partitions to return data from.
limit : int | None, optional
The maximum number of results to return per search method.
If None, a default limit will be applied.
@@ -626,5 +639,7 @@
"""
embedder = self.llm_client.get_embedder()
query_embedding = await generate_embedding(embedder, query)
relevant_nodes = await hybrid_node_search([query], [query_embedding], self.driver, limit)
relevant_nodes = await hybrid_node_search(
[query], [query_embedding], self.driver, group_ids, limit
)
return relevant_nodes
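And a matching sketch for node retrieval, now scoped the same way as search(); the query string is again hypothetical:

from graphiti_core import Graphiti

async def find_nodes(client: Graphiti) -> None:
    nodes = await client.get_nodes_by_query(
        'podcast guests',
        group_ids=['1'],
    )
    for node in nodes:
        print(node.name, node.uuid)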
