Skip to content

Commit

Permalink
Make default DB explicit (#195)
Browse files Browse the repository at this point in the history
* add default database

* update

* init tests

* update test

* bump version

* removed unused imports
  • Loading branch information
prasmussen15 authored Oct 21, 2024
1 parent 8b72250 commit b217d1e
Show file tree
Hide file tree
Showing 13 changed files with 142 additions and 58 deletions.
43 changes: 22 additions & 21 deletions graphiti_core/edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@

from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
from graphiti_core.helpers import parse_db_date
from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date
from graphiti_core.models.edges.edge_db_queries import (
COMMUNITY_EDGE_SAVE,
ENTITY_EDGE_SAVE,
EPISODIC_EDGE_SAVE,
)
from graphiti_core.nodes import Node

logger = logging.getLogger(__name__)
Expand All @@ -49,6 +54,7 @@ async def delete(self, driver: AsyncDriver):
DELETE e
""",
uuid=self.uuid,
_database=DEFAULT_DATABASE,
)

logger.debug(f'Deleted Edge: {self.uuid}')
Expand All @@ -70,17 +76,13 @@ async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): ...
class EpisodicEdge(Edge):
async def save(self, driver: AsyncDriver):
result = await driver.execute_query(
"""
MATCH (episode:Episodic {uuid: $episode_uuid})
MATCH (node:Entity {uuid: $entity_uuid})
MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node)
SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
RETURN r.uuid AS uuid""",
EPISODIC_EDGE_SAVE,
episode_uuid=self.source_node_uuid,
entity_uuid=self.target_node_uuid,
uuid=self.uuid,
group_id=self.group_id,
created_at=self.created_at,
_database=DEFAULT_DATABASE,
)

logger.debug(f'Saved edge to neo4j: {self.uuid}')
Expand All @@ -100,6 +102,7 @@ async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
e.created_at AS created_at
""",
uuid=uuid,
_database=DEFAULT_DATABASE,
)

edges = [get_episodic_edge_from_record(record) for record in records]
Expand All @@ -122,6 +125,7 @@ async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
e.created_at AS created_at
""",
uuids=uuids,
_database=DEFAULT_DATABASE,
)

edges = [get_episodic_edge_from_record(record) for record in records]
Expand All @@ -144,6 +148,7 @@ async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
e.created_at AS created_at
""",
group_ids=group_ids,
_database=DEFAULT_DATABASE,
)

edges = [get_episodic_edge_from_record(record) for record in records]
Expand Down Expand Up @@ -184,14 +189,7 @@ async def generate_embedding(self, embedder: EmbedderClient):

async def save(self, driver: AsyncDriver):
result = await driver.execute_query(
"""
MATCH (source:Entity {uuid: $source_uuid})
MATCH (target:Entity {uuid: $target_uuid})
MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, episodes: $episodes,
created_at: $created_at, expired_at: $expired_at, valid_at: $valid_at, invalid_at: $invalid_at}
WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $fact_embedding)
RETURN r.uuid AS uuid""",
ENTITY_EDGE_SAVE,
source_uuid=self.source_node_uuid,
target_uuid=self.target_node_uuid,
uuid=self.uuid,
Expand All @@ -204,6 +202,7 @@ async def save(self, driver: AsyncDriver):
expired_at=self.expired_at,
valid_at=self.valid_at,
invalid_at=self.invalid_at,
_database=DEFAULT_DATABASE,
)

logger.debug(f'Saved edge to neo4j: {self.uuid}')
Expand All @@ -230,6 +229,7 @@ async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
e.invalid_at AS invalid_at
""",
uuid=uuid,
_database=DEFAULT_DATABASE,
)

edges = [get_entity_edge_from_record(record) for record in records]
Expand Down Expand Up @@ -259,6 +259,7 @@ async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
e.invalid_at AS invalid_at
""",
uuids=uuids,
_database=DEFAULT_DATABASE,
)

edges = [get_entity_edge_from_record(record) for record in records]
Expand Down Expand Up @@ -288,6 +289,7 @@ async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
e.invalid_at AS invalid_at
""",
group_ids=group_ids,
_database=DEFAULT_DATABASE,
)

edges = [get_entity_edge_from_record(record) for record in records]
Expand All @@ -300,17 +302,13 @@ async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
class CommunityEdge(Edge):
async def save(self, driver: AsyncDriver):
result = await driver.execute_query(
"""
MATCH (community:Community {uuid: $community_uuid})
MATCH (node:Entity | Community {uuid: $entity_uuid})
MERGE (community)-[r:HAS_MEMBER {uuid: $uuid}]->(node)
SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
RETURN r.uuid AS uuid""",
COMMUNITY_EDGE_SAVE,
community_uuid=self.source_node_uuid,
entity_uuid=self.target_node_uuid,
uuid=self.uuid,
group_id=self.group_id,
created_at=self.created_at,
_database=DEFAULT_DATABASE,
)

logger.debug(f'Saved edge to neo4j: {self.uuid}')
Expand All @@ -330,6 +328,7 @@ async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
e.created_at AS created_at
""",
uuid=uuid,
_database=DEFAULT_DATABASE,
)

edges = [get_community_edge_from_record(record) for record in records]
Expand All @@ -350,6 +349,7 @@ async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
e.created_at AS created_at
""",
uuids=uuids,
_database=DEFAULT_DATABASE,
)

edges = [get_community_edge_from_record(record) for record in records]
Expand All @@ -370,6 +370,7 @@ async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
e.created_at AS created_at
""",
group_ids=group_ids,
_database=DEFAULT_DATABASE,
)

edges = [get_community_edge_from_record(record) for record in records]
Expand Down
3 changes: 3 additions & 0 deletions graphiti_core/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@
limitations under the License.
"""

import os
from datetime import datetime

import numpy as np
from neo4j import time as neo4j_time

DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)


def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
return neo_date.to_native() if neo_date else None
Expand Down
Empty file.
Empty file.
22 changes: 22 additions & 0 deletions graphiti_core/models/edges/edge_db_queries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
EPISODIC_EDGE_SAVE = """
MATCH (episode:Episodic {uuid: $episode_uuid})
MATCH (node:Entity {uuid: $entity_uuid})
MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node)
SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
RETURN r.uuid AS uuid"""

ENTITY_EDGE_SAVE = """
MATCH (source:Entity {uuid: $source_uuid})
MATCH (target:Entity {uuid: $target_uuid})
MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, episodes: $episodes,
created_at: $created_at, expired_at: $expired_at, valid_at: $valid_at, invalid_at: $invalid_at}
WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $fact_embedding)
RETURN r.uuid AS uuid"""

COMMUNITY_EDGE_SAVE = """
MATCH (community:Community {uuid: $community_uuid})
MATCH (node:Entity | Community {uuid: $entity_uuid})
MERGE (community)-[r:HAS_MEMBER {uuid: $uuid}]->(node)
SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
RETURN r.uuid AS uuid"""
Empty file.
17 changes: 17 additions & 0 deletions graphiti_core/models/nodes/node_db_queries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
EPISODIC_NODE_SAVE = """
MERGE (n:Episodic {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
RETURN n.uuid AS uuid"""

ENTITY_NODE_SAVE = """
MERGE (n:Entity {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
RETURN n.uuid AS uuid"""

COMMUNITY_NODE_SAVE = """
MERGE (n:Community {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
RETURN n.uuid AS uuid"""
37 changes: 22 additions & 15 deletions graphiti_core/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@

from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import NodeNotFoundError
from graphiti_core.helpers import DEFAULT_DATABASE
from graphiti_core.models.nodes.node_db_queries import (
COMMUNITY_NODE_SAVE,
ENTITY_NODE_SAVE,
EPISODIC_NODE_SAVE,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -84,6 +90,7 @@ async def delete(self, driver: AsyncDriver):
DETACH DELETE n
""",
uuid=self.uuid,
_database=DEFAULT_DATABASE,
)

logger.debug(f'Deleted Node: {self.uuid}')
Expand Down Expand Up @@ -119,11 +126,7 @@ class EpisodicNode(Node):

async def save(self, driver: AsyncDriver):
result = await driver.execute_query(
"""
MERGE (n:Episodic {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
RETURN n.uuid AS uuid""",
EPISODIC_NODE_SAVE,
uuid=self.uuid,
name=self.name,
group_id=self.group_id,
Expand All @@ -133,6 +136,7 @@ async def save(self, driver: AsyncDriver):
created_at=self.created_at,
valid_at=self.valid_at,
source=self.source.value,
_database=DEFAULT_DATABASE,
)

logger.debug(f'Saved Node to neo4j: {self.uuid}')
Expand All @@ -154,6 +158,7 @@ async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
e.source AS source
""",
uuid=uuid,
_database=DEFAULT_DATABASE,
)

episodes = [get_episodic_node_from_record(record) for record in records]
Expand All @@ -179,6 +184,7 @@ async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
e.source AS source
""",
uuids=uuids,
_database=DEFAULT_DATABASE,
)

episodes = [get_episodic_node_from_record(record) for record in records]
Expand All @@ -201,6 +207,7 @@ async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
e.source AS source
""",
group_ids=group_ids,
_database=DEFAULT_DATABASE,
)

episodes = [get_episodic_node_from_record(record) for record in records]
Expand All @@ -223,17 +230,14 @@ async def generate_name_embedding(self, embedder: EmbedderClient):

async def save(self, driver: AsyncDriver):
result = await driver.execute_query(
"""
MERGE (n:Entity {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
RETURN n.uuid AS uuid""",
ENTITY_NODE_SAVE,
uuid=self.uuid,
name=self.name,
group_id=self.group_id,
summary=self.summary,
name_embedding=self.name_embedding,
created_at=self.created_at,
_database=DEFAULT_DATABASE,
)

logger.debug(f'Saved Node to neo4j: {self.uuid}')
Expand All @@ -254,6 +258,7 @@ async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
n.summary AS summary
""",
uuid=uuid,
_database=DEFAULT_DATABASE,
)

nodes = [get_entity_node_from_record(record) for record in records]
Expand All @@ -277,6 +282,7 @@ async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
n.summary AS summary
""",
uuids=uuids,
_database=DEFAULT_DATABASE,
)

nodes = [get_entity_node_from_record(record) for record in records]
Expand All @@ -297,6 +303,7 @@ async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
n.summary AS summary
""",
group_ids=group_ids,
_database=DEFAULT_DATABASE,
)

nodes = [get_entity_node_from_record(record) for record in records]
Expand All @@ -310,17 +317,14 @@ class CommunityNode(Node):

async def save(self, driver: AsyncDriver):
result = await driver.execute_query(
"""
MERGE (n:Community {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
RETURN n.uuid AS uuid""",
COMMUNITY_NODE_SAVE,
uuid=self.uuid,
name=self.name,
group_id=self.group_id,
summary=self.summary,
name_embedding=self.name_embedding,
created_at=self.created_at,
_database=DEFAULT_DATABASE,
)

logger.debug(f'Saved Node to neo4j: {self.uuid}')
Expand Down Expand Up @@ -350,6 +354,7 @@ async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
n.summary AS summary
""",
uuid=uuid,
_database=DEFAULT_DATABASE,
)

nodes = [get_community_node_from_record(record) for record in records]
Expand All @@ -373,6 +378,7 @@ async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
n.summary AS summary
""",
uuids=uuids,
_database=DEFAULT_DATABASE,
)

communities = [get_community_node_from_record(record) for record in records]
Expand All @@ -393,6 +399,7 @@ async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
n.summary AS summary
""",
group_ids=group_ids,
_database=DEFAULT_DATABASE,
)

communities = [get_community_node_from_record(record) for record in records]
Expand Down
Loading

0 comments on commit b217d1e

Please sign in to comment.