diff --git a/.semversioner/next-release/patch-20241218221915558063.json b/.semversioner/next-release/patch-20241218221915558063.json
new file mode 100644
index 0000000000..206754635b
--- /dev/null
+++ b/.semversioner/next-release/patch-20241218221915558063.json
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Manage llm instances inside a cached singleton. Check for empty dfs after entity/relationship extraction"
+}
diff --git a/graphrag/index/flows/extract_graph.py b/graphrag/index/flows/extract_graph.py
index f274d55f64..db73635b85 100644
--- a/graphrag/index/flows/extract_graph.py
+++ b/graphrag/index/flows/extract_graph.py
@@ -52,6 +52,18 @@ async def extract_graph(
         num_threads=extraction_num_threads,
     )
 
+    if not _validate_data(entity_dfs):
+        error_msg = "Entity Extraction failed. No entities detected during extraction."
+        callbacks.error(error_msg)
+        raise ValueError(error_msg)
+
+    if not _validate_data(relationship_dfs):
+        error_msg = (
+            "Entity Extraction failed. No relationships detected during extraction."
+        )
+        callbacks.error(error_msg)
+        raise ValueError(error_msg)
+
     merged_entities = _merge_entities(entity_dfs)
     merged_relationships = _merge_relationships(relationship_dfs)
 
@@ -145,3 +157,10 @@ def _compute_degree(graph: nx.Graph) -> pd.DataFrame:
         {"name": node, "degree": int(degree)}
         for node, degree in graph.degree  # type: ignore
     ])
+
+
+def _validate_data(df_list: list[pd.DataFrame]) -> bool:
+    """Validate that the dataframe list is valid. At least one dataframe must contain data."""
+    return any(
+        len(df) > 0 for df in df_list
+    )  # Check for len, not .empty, as the dfs have schemas in some cases
diff --git a/graphrag/index/llm/load_llm.py b/graphrag/index/llm/load_llm.py
index 07b774c434..687e0f7e06 100644
--- a/graphrag/index/llm/load_llm.py
+++ b/graphrag/index/llm/load_llm.py
@@ -24,6 +24,7 @@
 import graphrag.config.defaults as defs
 from graphrag.config.enums import LLMType
 from graphrag.config.models.llm_parameters import LLMParameters
+from graphrag.index.llm.manager import ChatLLMSingleton, EmbeddingsLLMSingleton
 
 from .mock_llm import MockChatLLM
 
@@ -110,6 +111,10 @@ def load_llm(
     chat_only=False,
 ) -> ChatLLM:
     """Load the LLM for the entity extraction chain."""
+    singleton_llm = ChatLLMSingleton().get_llm(name)
+    if singleton_llm is not None:
+        return singleton_llm
+
     on_error = _create_error_handler(callbacks)
     llm_type = config.type
 
@@ -119,7 +124,9 @@ def load_llm(
             raise ValueError(msg)
 
         loader = loaders[llm_type]
-        return loader["load"](on_error, create_cache(cache, name), config)
+        llm_instance = loader["load"](on_error, create_cache(cache, name), config)
+        ChatLLMSingleton().set_llm(name, llm_instance)
+        return llm_instance
 
     msg = f"Unknown LLM type {llm_type}"
     raise ValueError(msg)
@@ -134,15 +141,21 @@ def load_llm_embeddings(
     chat_only=False,
 ) -> EmbeddingsLLM:
     """Load the LLM for the entity extraction chain."""
+    singleton_llm = EmbeddingsLLMSingleton().get_llm(name)
+    if singleton_llm is not None:
+        return singleton_llm
+
     on_error = _create_error_handler(callbacks)
     llm_type = llm_config.type
     if llm_type in loaders:
         if chat_only and not loaders[llm_type]["chat"]:
             msg = f"LLM type {llm_type} does not support chat"
             raise ValueError(msg)
-        return loaders[llm_type]["load"](
+        llm_instance = loaders[llm_type]["load"](
             on_error, create_cache(cache, name), llm_config or {}
         )
+        EmbeddingsLLMSingleton().set_llm(name, llm_instance)
+        return llm_instance
 
     msg = f"Unknown LLM type {llm_type}"
     raise ValueError(msg)
@@ -198,6 +211,7 @@ def _create_openai_config(config: LLMParameters, azure: bool) -> OpenAIConfig:
         n=config.n,
         temperature=config.temperature,
     )
+
     if azure:
         if config.api_base is None:
             msg = "Azure OpenAI Chat LLM requires an API base"
diff --git a/graphrag/index/llm/manager.py b/graphrag/index/llm/manager.py
new file mode 100644
index 0000000000..2e35cd89c4
--- /dev/null
+++ b/graphrag/index/llm/manager.py
@@ -0,0 +1,40 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""LLM Manager singleton."""
+
+from functools import cache
+
+from fnllm import ChatLLM, EmbeddingsLLM
+
+
+@cache
+class ChatLLMSingleton:
+    """A singleton class for the chat LLM instances."""
+
+    def __init__(self):
+        self.llm_dict = {}
+
+    def set_llm(self, name, llm):
+        """Add an LLM to the dictionary."""
+        self.llm_dict[name] = llm
+
+    def get_llm(self, name) -> ChatLLM | None:
+        """Get an LLM from the dictionary."""
+        return self.llm_dict.get(name)
+
+
+@cache
+class EmbeddingsLLMSingleton:
+    """A singleton class for the embeddings LLM instances."""
+
+    def __init__(self):
+        self.llm_dict = {}
+
+    def set_llm(self, name, llm):
+        """Add an LLM to the dictionary."""
+        self.llm_dict[name] = llm
+
+    def get_llm(self, name) -> EmbeddingsLLM | None:
+        """Get an LLM from the dictionary."""
+        return self.llm_dict.get(name)
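
Note on the singleton mechanics in manager.py above: decorating a class with functools.cache memoizes calls to the class object itself, so every zero-argument instantiation returns the same cached instance and __init__ runs exactly once per process. A minimal, self-contained sketch of that pattern follows, using only the standard library; Registry and its methods are hypothetical stand-ins for ChatLLMSingleton/EmbeddingsLLMSingleton, not part of this patch:

from functools import cache


@cache
class Registry:
    """Toy stand-in for the cached-singleton classes in manager.py."""

    def __init__(self):
        # Runs only on the first Registry() call; later calls hit the cache.
        self.items = {}

    def set_item(self, name, value):
        self.items[name] = value

    def get_item(self, name):
        return self.items.get(name)


a = Registry()
b = Registry()
assert a is b  # both names refer to the one cached instance
a.set_item("default_chat", "llm-instance")
assert b.get_item("default_chat") == "llm-instance"  # state is shared

One caveat of this trick: after decoration, the name Registry is bound to a functools wrapper rather than the class, so isinstance(a, Registry) no longer works. The managers in this patch only construct the singleton and call its methods, so that limitation does not bite here.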