diff --git a/.semversioner/next-release/patch-20241014040518441266.json b/.semversioner/next-release/patch-20241014040518441266.json new file mode 100644 index 0000000000..c5831a0c30 --- /dev/null +++ b/.semversioner/next-release/patch-20241014040518441266.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "Perf optimizations in map_query_to_entities()" +} diff --git a/graphrag/query/context_builder/entity_extraction.py b/graphrag/query/context_builder/entity_extraction.py index 82a0699cd8..4b8767b87d 100644 --- a/graphrag/query/context_builder/entity_extraction.py +++ b/graphrag/query/context_builder/entity_extraction.py @@ -7,6 +7,7 @@ from graphrag.model import Entity, Relationship from graphrag.query.input.retrieval.entities import ( + get_entity_by_id, get_entity_by_key, get_entity_by_name, ) @@ -36,7 +37,7 @@ def map_query_to_entities( query: str, text_embedding_vectorstore: BaseVectorStore, text_embedder: BaseTextEmbedding, - all_entities: list[Entity], + all_entities_dict: dict[str, Entity], embedding_vectorstore_key: str = EntityVectorStoreKey.ID, include_entity_names: list[str] | None = None, exclude_entity_names: list[str] | None = None, @@ -48,6 +49,7 @@ def map_query_to_entities( include_entity_names = [] if exclude_entity_names is None: exclude_entity_names = [] + all_entities = list(all_entities_dict.values()) matched_entities = [] if query != "": # get entities with highest semantic similarity to query @@ -58,11 +60,16 @@ def map_query_to_entities( k=k * oversample_scaler, ) for result in search_results: - matched = get_entity_by_key( - entities=all_entities, - key=embedding_vectorstore_key, - value=result.document.id, - ) + if embedding_vectorstore_key == EntityVectorStoreKey.ID and isinstance( + result.document.id, str + ): + matched = get_entity_by_id(all_entities_dict, result.document.id) + else: + matched = get_entity_by_key( + entities=all_entities, + key=embedding_vectorstore_key, + value=result.document.id, + ) if matched: matched_entities.append(matched) else: diff --git a/graphrag/query/input/retrieval/entities.py b/graphrag/query/input/retrieval/entities.py index 5465f9f59e..41c92fab31 100644 --- a/graphrag/query/input/retrieval/entities.py +++ b/graphrag/query/input/retrieval/entities.py @@ -12,17 +12,26 @@ from graphrag.model import Entity +def get_entity_by_id(entities: dict[str, Entity], value: str) -> Entity | None: + """Get entity by id.""" + entity = entities.get(value) + if entity is None and is_valid_uuid(value): + entity = entities.get(value.replace("-", "")) + return entity + + def get_entity_by_key( entities: Iterable[Entity], key: str, value: str | int ) -> Entity | None: """Get entity by key.""" - for entity in entities: - if isinstance(value, str) and is_valid_uuid(value): - if getattr(entity, key) == value or getattr(entity, key) == value.replace( - "-", "" - ): + if isinstance(value, str) and is_valid_uuid(value): + value_no_dashes = value.replace("-", "") + for entity in entities: + entity_value = getattr(entity, key) + if entity_value in (value, value_no_dashes): return entity - else: + else: + for entity in entities: if getattr(entity, key) == value: return entity return None diff --git a/graphrag/query/structured_search/local_search/mixed_context.py b/graphrag/query/structured_search/local_search/mixed_context.py index e0608e4bd2..d160fe8190 100644 --- a/graphrag/query/structured_search/local_search/mixed_context.py +++ b/graphrag/query/structured_search/local_search/mixed_context.py @@ -141,7 +141,7 @@ def build_context( query=query, text_embedding_vectorstore=self.entity_text_embeddings, text_embedder=self.text_embedder, - all_entities=list(self.entities.values()), + all_entities_dict=self.entities, embedding_vectorstore_key=self.embedding_vectorstore_key, include_entity_names=include_entity_names, exclude_entity_names=exclude_entity_names, diff --git a/tests/unit/query/context_builder/__init__.py b/tests/unit/query/context_builder/__init__.py new file mode 100644 index 0000000000..0a3e38adfb --- /dev/null +++ b/tests/unit/query/context_builder/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License diff --git a/tests/unit/query/context_builder/test_entity_extraction.py b/tests/unit/query/context_builder/test_entity_extraction.py new file mode 100644 index 0000000000..de71b8806d --- /dev/null +++ b/tests/unit/query/context_builder/test_entity_extraction.py @@ -0,0 +1,182 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +from typing import Any + +from graphrag.model import Entity +from graphrag.model.types import TextEmbedder +from graphrag.query.context_builder.entity_extraction import ( + EntityVectorStoreKey, + map_query_to_entities, +) +from graphrag.query.llm.base import BaseTextEmbedding +from graphrag.vector_stores import ( + BaseVectorStore, + VectorStoreDocument, + VectorStoreSearchResult, +) + + +class MockBaseVectorStore(BaseVectorStore): + def __init__(self, documents: list[VectorStoreDocument]) -> None: + super().__init__("mock") + self.documents = documents + + def connect(self, **kwargs: Any) -> None: + raise NotImplementedError + + def load_documents( + self, documents: list[VectorStoreDocument], overwrite: bool = True + ) -> None: + raise NotImplementedError + + def similarity_search_by_vector( + self, query_embedding: list[float], k: int = 10, **kwargs: Any + ) -> list[VectorStoreSearchResult]: + return [ + VectorStoreSearchResult(document=document, score=1) + for document in self.documents[:k] + ] + + def similarity_search_by_text( + self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any + ) -> list[VectorStoreSearchResult]: + return sorted( + [ + VectorStoreSearchResult( + document=document, score=abs(len(text) - len(document.text or "")) + ) + for document in self.documents + ], + key=lambda x: x.score, + )[:k] + + def filter_by_id(self, include_ids: list[str] | list[int]) -> Any: + return [document for document in self.documents if document.id in include_ids] + + +class MockBaseTextEmbedding(BaseTextEmbedding): + def embed(self, text: str, **kwargs: Any) -> list[float]: + return [len(text)] + + async def aembed(self, text: str, **kwargs: Any) -> list[float]: + return [len(text)] + + +def test_map_query_to_entities(): + entities = [ + Entity( + id="2da37c7a-50a8-44d4-aa2c-fd401e19976c", + short_id="sid1", + title="t1", + rank=2, + ), + Entity( + id="c4f93564-4507-4ee4-b102-98add401a965", + short_id="sid2", + title="t22", + rank=4, + ), + Entity( + id="7c6f2bc9-47c9-4453-93a3-d2e174a02cd9", + short_id="sid3", + title="t333", + rank=1, + ), + Entity( + id="8fd6d72a-8e9d-4183-8a97-c38bcc971c83", + short_id="sid4", + title="t4444", + rank=3, + ), + ] + + assert map_query_to_entities( + query="t22", + text_embedding_vectorstore=MockBaseVectorStore([ + VectorStoreDocument(id=entity.id, text=entity.title, vector=None) + for entity in entities + ]), + text_embedder=MockBaseTextEmbedding(), + all_entities_dict={entity.id: entity for entity in entities}, + embedding_vectorstore_key=EntityVectorStoreKey.ID, + k=1, + oversample_scaler=1, + ) == [ + Entity( + id="c4f93564-4507-4ee4-b102-98add401a965", + short_id="sid2", + title="t22", + rank=4, + ) + ] + + assert map_query_to_entities( + query="t22", + text_embedding_vectorstore=MockBaseVectorStore([ + VectorStoreDocument(id=entity.title, text=entity.title, vector=None) + for entity in entities + ]), + text_embedder=MockBaseTextEmbedding(), + all_entities_dict={entity.id: entity for entity in entities}, + embedding_vectorstore_key=EntityVectorStoreKey.TITLE, + k=1, + oversample_scaler=1, + ) == [ + Entity( + id="c4f93564-4507-4ee4-b102-98add401a965", + short_id="sid2", + title="t22", + rank=4, + ) + ] + + assert map_query_to_entities( + query="", + text_embedding_vectorstore=MockBaseVectorStore([ + VectorStoreDocument(id=entity.id, text=entity.title, vector=None) + for entity in entities + ]), + text_embedder=MockBaseTextEmbedding(), + all_entities_dict={entity.id: entity for entity in entities}, + embedding_vectorstore_key=EntityVectorStoreKey.ID, + k=2, + ) == [ + Entity( + id="c4f93564-4507-4ee4-b102-98add401a965", + short_id="sid2", + title="t22", + rank=4, + ), + Entity( + id="8fd6d72a-8e9d-4183-8a97-c38bcc971c83", + short_id="sid4", + title="t4444", + rank=3, + ), + ] + + assert map_query_to_entities( + query="", + text_embedding_vectorstore=MockBaseVectorStore([ + VectorStoreDocument(id=entity.id, text=entity.title, vector=None) + for entity in entities + ]), + text_embedder=MockBaseTextEmbedding(), + all_entities_dict={entity.id: entity for entity in entities}, + embedding_vectorstore_key=EntityVectorStoreKey.TITLE, + k=2, + ) == [ + Entity( + id="c4f93564-4507-4ee4-b102-98add401a965", + short_id="sid2", + title="t22", + rank=4, + ), + Entity( + id="8fd6d72a-8e9d-4183-8a97-c38bcc971c83", + short_id="sid4", + title="t4444", + rank=3, + ), + ] diff --git a/tests/unit/query/input/__init__.py b/tests/unit/query/input/__init__.py new file mode 100644 index 0000000000..0a3e38adfb --- /dev/null +++ b/tests/unit/query/input/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License diff --git a/tests/unit/query/input/retrieval/__init__.py b/tests/unit/query/input/retrieval/__init__.py new file mode 100644 index 0000000000..0a3e38adfb --- /dev/null +++ b/tests/unit/query/input/retrieval/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License diff --git a/tests/unit/query/input/retrieval/test_entities.py b/tests/unit/query/input/retrieval/test_entities.py new file mode 100644 index 0000000000..a66e3432b9 --- /dev/null +++ b/tests/unit/query/input/retrieval/test_entities.py @@ -0,0 +1,167 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +from graphrag.model import Entity +from graphrag.query.input.retrieval.entities import ( + get_entity_by_id, + get_entity_by_key, +) + + +def test_get_entity_by_id(): + assert ( + get_entity_by_id( + { + entity.id: entity + for entity in [ + Entity( + id="2da37c7a-50a8-44d4-aa2c-fd401e19976c", + short_id="sid1", + title="title1", + ), + ] + }, + "00000000-0000-0000-0000-000000000000", + ) + is None + ) + + assert get_entity_by_id( + { + entity.id: entity + for entity in [ + Entity( + id="2da37c7a-50a8-44d4-aa2c-fd401e19976c", + short_id="sid1", + title="title1", + ), + Entity( + id="c4f93564-4507-4ee4-b102-98add401a965", + short_id="sid2", + title="title2", + ), + Entity( + id="7c6f2bc9-47c9-4453-93a3-d2e174a02cd9", + short_id="sid3", + title="title3", + ), + ] + }, + "7c6f2bc9-47c9-4453-93a3-d2e174a02cd9", + ) == Entity( + id="7c6f2bc9-47c9-4453-93a3-d2e174a02cd9", short_id="sid3", title="title3" + ) + + assert get_entity_by_id( + { + entity.id: entity + for entity in [ + Entity( + id="2da37c7a50a844d4aa2cfd401e19976c", + short_id="sid1", + title="title1", + ), + Entity( + id="c4f9356445074ee4b10298add401a965", + short_id="sid2", + title="title2", + ), + Entity( + id="7c6f2bc947c9445393a3d2e174a02cd9", + short_id="sid3", + title="title3", + ), + ] + }, + "7c6f2bc9-47c9-4453-93a3-d2e174a02cd9", + ) == Entity(id="7c6f2bc947c9445393a3d2e174a02cd9", short_id="sid3", title="title3") + + assert get_entity_by_id( + { + entity.id: entity + for entity in [ + Entity(id="id1", short_id="sid1", title="title1"), + Entity(id="id2", short_id="sid2", title="title2"), + Entity(id="id3", short_id="sid3", title="title3"), + ] + }, + "id3", + ) == Entity(id="id3", short_id="sid3", title="title3") + + +def test_get_entity_by_key(): + assert ( + get_entity_by_key( + [ + Entity( + id="2da37c7a-50a8-44d4-aa2c-fd401e19976c", + short_id="sid1", + title="title1", + ), + ], + "id", + "00000000-0000-0000-0000-000000000000", + ) + is None + ) + + assert get_entity_by_key( + [ + Entity( + id="2da37c7a-50a8-44d4-aa2c-fd401e19976c", + short_id="sid1", + title="title1", + ), + Entity( + id="c4f93564-4507-4ee4-b102-98add401a965", + short_id="sid2", + title="title2", + ), + Entity( + id="7c6f2bc9-47c9-4453-93a3-d2e174a02cd9", + short_id="sid3", + title="title3", + ), + ], + "id", + "7c6f2bc9-47c9-4453-93a3-d2e174a02cd9", + ) == Entity( + id="7c6f2bc9-47c9-4453-93a3-d2e174a02cd9", short_id="sid3", title="title3" + ) + + assert get_entity_by_key( + [ + Entity( + id="2da37c7a50a844d4aa2cfd401e19976c", short_id="sid1", title="title1" + ), + Entity( + id="c4f9356445074ee4b10298add401a965", short_id="sid2", title="title2" + ), + Entity( + id="7c6f2bc947c9445393a3d2e174a02cd9", short_id="sid3", title="title3" + ), + ], + "id", + "7c6f2bc9-47c9-4453-93a3-d2e174a02cd9", + ) == Entity(id="7c6f2bc947c9445393a3d2e174a02cd9", short_id="sid3", title="title3") + + assert get_entity_by_key( + [ + Entity(id="id1", short_id="sid1", title="title1"), + Entity(id="id2", short_id="sid2", title="title2"), + Entity(id="id3", short_id="sid3", title="title3"), + ], + "id", + "id3", + ) == Entity(id="id3", short_id="sid3", title="title3") + + assert get_entity_by_key( + [ + Entity(id="id1", short_id="sid1", title="title1", rank=1), + Entity(id="id2", short_id="sid2", title="title2a", rank=2), + Entity(id="id3", short_id="sid3", title="title3", rank=3), + Entity(id="id2", short_id="sid2", title="title2b", rank=2), + ], + "rank", + 2, + ) == Entity(id="id2", short_id="sid2", title="title2a", rank=2)