Skip to content

Commit

Permalink
Add graphdb calls directly where relationships are filtered (microsof…
Browse files Browse the repository at this point in the history
…t#31)

* Add graphdb calls directly where relationships are filtered

* Add function to perform graphdb queries for relationships

---------

Co-authored-by: Guillermo Salvador Barrón Sánchez <[email protected]>
Co-authored-by: logomachic <[email protected]>
  • Loading branch information
3 people authored Sep 4, 2024
1 parent ed14b6a commit 71263a4
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 35 deletions.
6 changes: 3 additions & 3 deletions common/graph_db_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,13 @@ def write_vertices(self,data: pd.DataFrame)->None:
"prop_partition_key": "entities",
"prop_description_embedding":json.dumps(row.description_embedding.tolist() if row.description_embedding is not None else []),
"prop_graph_embedding":json.dumps(row.graph_embedding.tolist() if row.graph_embedding is not None else []),
"prop_text_unit_ids":json.dumps(row.text_unit_ids if row.text_unit_ids is not None else []),
"prop_text_unit_ids":json.dumps(row.text_unit_ids.tolist() if row.text_unit_ids is not None else []),
},
)
time.sleep(5)


def write_edges(self,data: pd.DataFrame,context_id:str="00000000-0000-0000-0000-000000000000")->None:
def write_edges(self,data: pd.DataFrame)->None:
for row in data.itertuples():
if self.element_exists("g.E()",row.id):
continue
Expand All @@ -134,7 +134,7 @@ def write_edges(self,data: pd.DataFrame,context_id:str="00000000-0000-0000-0000-
"prop_source_id": row.source,
"prop_target_id": row.target,
"prop_weight": row.weight,
"prop_text_unit_ids":json.dumps(row.text_unit_ids if row.text_unit_ids is not None else []),
"prop_text_unit_ids":json.dumps(row.text_unit_ids.tolist() if row.text_unit_ids is not None else []),
"prop_description": row.description,
"prop_id": row.id,
"prop_human_readable_id": row.human_readable_id,
Expand Down
13 changes: 6 additions & 7 deletions graphrag/index/context_switch/contextSwitcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,13 +225,8 @@ def _read_config_parameters(root: str, config: str | None):
if not optimized_search:
final_covariates = read_paraquet_file(input_storage_client, data_path + "/create_final_covariates.parquet")

if config.graphdb.enabled:
final_relationships = graph_db_client.query_edges(context_id)
final_entities = graph_db_client.query_vertices(context_id)
else:
final_relationships = read_paraquet_file(input_storage_client, data_path + "/create_final_relationships.parquet")
final_entities = read_paraquet_file(input_storage_client, data_path + "/create_final_entities.parquet")

final_relationships = read_paraquet_file(input_storage_client, data_path + "/create_final_relationships.parquet")
final_entities = read_paraquet_file(input_storage_client, data_path + "/create_final_entities.parquet")

vector_store_args = (
config.embeddings.vector_store if config.embeddings.vector_store else {}
Expand All @@ -253,6 +248,10 @@ def _read_config_parameters(root: str, config: str | None):
description_embedding_store.load_entities(entities)
if self.use_kusto_community_reports:
description_embedding_store.load_reports(reports)

if config.graphdb.enabled:
graph_db_client.write_vertices(final_entities)
graph_db_client.write_edges(final_relationships)

def deactivate(self):
"""DeActivate the context."""
Expand Down
8 changes: 4 additions & 4 deletions graphrag/query/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,13 +201,11 @@ def run_local_search(
final_covariates = pd.concat([final_covariates, read_paraquet_file(input_storage_client, data_path + "/create_final_covariates.parquet")])

if config.graphdb.enabled:
final_relationships = pd.concat([final_relationships, graph_db_client.query_edges(context_id)])
final_entities = pd.concat([final_entities, graph_db_client.query_vertices(context_id)])
else:
final_relationships = pd.concat([final_relationships, read_paraquet_file(input_storage_client, data_path + "/create_final_relationships.parquet")])
final_entities = pd.concat([final_entities, read_paraquet_file(input_storage_client, data_path + "/create_final_entities.parquet")])


graph_db_client._client.close()
vector_store_args = (
config.embeddings.vector_store if config.embeddings.vector_store else {}
)
Expand Down Expand Up @@ -242,12 +240,14 @@ def run_local_search(
reports=reports,
text_units=read_indexer_text_units(final_text_units),
entities=entities,
relationships=read_indexer_relationships(final_relationships),
relationships=[],
covariates={"claims": covariates},
description_embedding_store=description_embedding_store,
response_type=response_type,
context_id=context_id,
is_optimized_search=optimized_search,
use_kusto_community_reports=use_kusto_community_reports,
graphdb_config=config.graphdb,
)

if optimized_search:
Expand Down
8 changes: 7 additions & 1 deletion graphrag/query/context_builder/local_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any, cast

import pandas as pd
from common.graph_db_client import GraphDBClient
import tiktoken

from graphrag.model import Covariate, Entity, Relationship
Expand Down Expand Up @@ -164,14 +165,16 @@ def build_relationship_context(
relationship_ranking_attribute: str = "rank",
column_delimiter: str = "|",
context_name: str = "Relationships",
is_optimized_search: bool = False
is_optimized_search: bool = False,
graphdb_client: GraphDBClient|None=None,
) -> tuple[str, pd.DataFrame]:
"""Prepare relationship data tables as context data for system prompt."""
selected_relationships = _filter_relationships(
selected_entities=selected_entities,
relationships=relationships,
top_k_relationships=top_k_relationships,
relationship_ranking_attribute=relationship_ranking_attribute,
graphdb_client=graphdb_client,
)

if len(selected_entities) == 0 or len(selected_relationships) == 0:
Expand Down Expand Up @@ -236,13 +239,15 @@ def _filter_relationships(
relationships: list[Relationship],
top_k_relationships: int = 10,
relationship_ranking_attribute: str = "rank",
graphdb_client: GraphDBClient|None=None,
) -> list[Relationship]:
"""Filter and sort relationships based on a set of selected entities and a ranking attribute."""
# First priority: in-network relationships (i.e. relationships between selected entities)
in_network_relationships = get_in_network_relationships(
selected_entities=selected_entities,
relationships=relationships,
ranking_attribute=relationship_ranking_attribute,
graphdb_client=graphdb_client,
)

# Second priority - out-of-network relationships
Expand All @@ -251,6 +256,7 @@ def _filter_relationships(
selected_entities=selected_entities,
relationships=relationships,
ranking_attribute=relationship_ranking_attribute,
graphdb_client=graphdb_client,
)
if len(out_network_relationships) <= 1:
return in_network_relationships + out_network_relationships
Expand Down
5 changes: 5 additions & 0 deletions graphrag/query/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

"""Query Factory methods to support CLI."""

from graphrag.config.models.graphdb_config import GraphDBConfig
import tiktoken
from azure.identity import DefaultAzureCredential, get_bearer_token_provider

Expand Down Expand Up @@ -108,8 +109,10 @@ def get_local_search_engine(
covariates: dict[str, list[Covariate]],
response_type: str,
description_embedding_store: BaseVectorStore,
context_id: str,
is_optimized_search: bool = False,
use_kusto_community_reports: bool = False,
graphdb_config: GraphDBConfig|None = None,
) -> LocalSearch:
"""Create a local search engine based on data + configuration."""
llm = get_llm(config)
Expand All @@ -132,6 +135,8 @@ def get_local_search_engine(
token_encoder=token_encoder,
is_optimized_search= is_optimized_search,
use_kusto_community_reports=use_kusto_community_reports,
graphdb_config=graphdb_config,
context_id=context_id,
),
token_encoder=token_encoder,
llm_params={
Expand Down
79 changes: 60 additions & 19 deletions graphrag/query/input/retrieval/relationships.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,48 @@

import pandas as pd

from common.graph_db_client import GraphDBClient
from graphrag.model import Entity, Relationship

from graphrag.query.input.loaders.dfs import read_relationships

def get_relationships_from_graphdb(query:str,selected_entity_names:list[str],graphdb_client: GraphDBClient):
relationships_result=graphdb_client._client.submit(
message=query,
bindings={
"prop_selected_entity_names": selected_entity_names,
}
)
return read_relationships(
graphdb_client.result_to_df(relationships_result),
short_id_col="human_readable_id"
)

def get_in_network_relationships(
selected_entities: list[Entity],
relationships: list[Relationship],
ranking_attribute: str = "rank",
graphdb_client: GraphDBClient|None=None,
) -> list[Relationship]:
"""Get all directed relationships between selected entities, sorted by ranking_attribute."""
selected_entity_names = [entity.title for entity in selected_entities]
selected_relationships = [
relationship
for relationship in relationships
if relationship.source in selected_entity_names
and relationship.target in selected_entity_names
]
if not graphdb_client:
selected_relationships = [
relationship
for relationship in relationships
if relationship.source in selected_entity_names
and relationship.target in selected_entity_names
]
else:
selected_relationships = get_relationships_from_graphdb(
query=(
"g.E()"
".where(inV().has('name',within(prop_selected_entity_names)))"
".where(outV().has('name',within(prop_selected_entity_names)))"
),
selected_entity_names=selected_entity_names,
graphdb_client=graphdb_client
)
if len(selected_relationships) <= 1:
return selected_relationships

Expand All @@ -36,22 +62,37 @@ def get_out_network_relationships(
selected_entities: list[Entity],
relationships: list[Relationship],
ranking_attribute: str = "rank",
graphdb_client: GraphDBClient|None=None,
) -> list[Relationship]:
"""Get relationships from selected entities to other entities that are not within the selected entities, sorted by ranking_attribute."""
selected_entity_names = [entity.title for entity in selected_entities]
source_relationships = [
relationship
for relationship in relationships
if relationship.source in selected_entity_names
and relationship.target not in selected_entity_names
]
target_relationships = [
relationship
for relationship in relationships
if relationship.target in selected_entity_names
and relationship.source not in selected_entity_names
]
selected_relationships = source_relationships + target_relationships
if not graphdb_client:
source_relationships = [
relationship
for relationship in relationships
if relationship.source in selected_entity_names
and relationship.target not in selected_entity_names
]
target_relationships = [
relationship
for relationship in relationships
if relationship.target in selected_entity_names
and relationship.source not in selected_entity_names
]
selected_relationships = source_relationships + target_relationships
else:
selected_relationships = get_relationships_from_graphdb(
query=(
"g.E().union("
"__.where(outV().has('name',without(prop_selected_entity_names)))"
".where(inV().has('name',within(prop_selected_entity_names))),"
"__.where(inV().has('name',without(prop_selected_entity_names)))"
".where(outV().has('name',within(prop_selected_entity_names)))"
")"
),
selected_entity_names= selected_entity_names,
graphdb_client=graphdb_client
)
return sort_relationships_by_ranking_attribute(
selected_relationships, selected_entities, ranking_attribute
)
Expand Down
12 changes: 11 additions & 1 deletion graphrag/query/structured_search/local_search/mixed_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from typing import Any

import pandas as pd
from common.graph_db_client import GraphDBClient
from graphrag.config.models.graphdb_config import GraphDBConfig
import tiktoken

from graphrag.model import (
Expand Down Expand Up @@ -64,6 +66,8 @@ def __init__(
embedding_vectorstore_key: str = EntityVectorStoreKey.ID,
is_optimized_search: bool = False,
use_kusto_community_reports: bool = False,
graphdb_config: GraphDBConfig|None = None,
context_id:str = None,
):
if community_reports is None:
community_reports = []
Expand All @@ -88,6 +92,8 @@ def __init__(
self.embedding_vectorstore_key = embedding_vectorstore_key
self.is_optimized_search = is_optimized_search
self.use_kusto_community_reports = use_kusto_community_reports
self.graphdb_config = graphdb_config
self.context_id = context_id

def filter_by_entity_keys(self, entity_keys: list[int] | list[str]):
"""Filter entity text embeddings by entity keys."""
Expand Down Expand Up @@ -433,6 +439,7 @@ def _build_local_context(
final_context_data = {}

# gradually add entities and associated metadata to the context until we reach limit
graphdb_client=GraphDBClient(self.graphdb_config,self.context_id) if (self.graphdb_config and self.graphdb_config.enabled) else None
for entity in selected_entities:
current_context = []
current_context_data = {}
Expand All @@ -452,7 +459,8 @@ def _build_local_context(
include_relationship_weight=include_relationship_weight,
relationship_ranking_attribute=relationship_ranking_attribute,
context_name="Relationships",
is_optimized_search=is_optimized_search
is_optimized_search=is_optimized_search,
graphdb_client=graphdb_client,
)
current_context.append(relationship_context)
current_context_data["relationships"] = relationship_context_data
Expand Down Expand Up @@ -484,6 +492,8 @@ def _build_local_context(
final_context_data = current_context_data

# attach entity context to final context
if graphdb_client:
graphdb_client._client.close()
final_context_text = entity_context + "\n\n" + "\n\n".join(final_context)
final_context_data["entities"] = entity_context_data

Expand Down

0 comments on commit 71263a4

Please sign in to comment.