From 5ff2d3c76d2d08797d3337bbe5278c165324f441 Mon Sep 17 00:00:00 2001
From: Chris Trevino
Date: Thu, 5 Dec 2024 16:07:47 -0800
Subject: [PATCH] Remove graphrag.llm, replace with fnllm (#1315)

* add fnllm; remove llm folder
* remove llm unit tests
* update imports
* update imports
* formatting
* enable autosave
* update mockllm
* update community reports extractor
* move most llm usage to fnllm
* update type issues
* fix unit tests
* type updates
* update dictionary
* semver
* update llm construction, get integration tests working
* load from llmparameters model
* move ruff settings to ruff.toml
* add gitattributes file
* ignore ruff.toml spelling
* update .gitattributes
* update gitignore
* update config construction
* update prompt var usage
* add cache adapter
* use cache adapter in embeddings calls
* update embedding strategy
* add fnllm
* add pytest-dotenv
* fix some verb tests
* get verb tests running
* update ruff.toml for vscode
* enable ruff native server in vscode
* update artifact inspecting code
* remove local-test update
* use string.replace instead of string.format in community reports extractor
* bump timeout
* revert ruff.toml, vscode settings for another pr
* revert cspell config
* revert gitignore
* remove json-repair, update fnllm
* use fnllm generic type interfaces
* update load_llm to use target models
* consolidate chat parameters
* add 'extra_attributes' prop to community report response
* formatting
* update fnllm
* formatting
* formatting
* Add defaults to some llm params to avoid null on params hash
* Formatting

---------

Co-authored-by: Alonso Guevara
Co-authored-by: Josh Bradley
---
 .../patch-20241024210728482023.json           |   4 +
 dictionary.txt                                |   1 +
 graphrag/api/prompt_tune.py                   |   7 +-
 graphrag/config/create_graphrag_config.py     |   4 +-
 graphrag/config/defaults.py                   |   3 +
 graphrag/config/enums.py                      |   4 -
 graphrag/config/models/llm_parameters.py      |  17 +
 .../extractors/claims/claim_extractor.py      |  24 +-
 .../community_reports_extractor.py            |  89 +++--
 .../graph/extractors/graph/graph_extractor.py |  21 +-
 .../description_summary_extractor.py          |  16 +-
 graphrag/index/llm/load_llm.py                | 338 +++++++++---------
 graphrag/index/llm/mock_llm.py                |  45 +++
 .../embed_text/strategies/openai.py           |  31 +-
 .../extract_covariates/strategies.py          |  11 +-
 .../strategies/graph_intelligence.py          |  11 +-
 .../summarize_communities/strategies.py       |  39 +-
 .../summarize_descriptions/strategies.py      |  13 +-
 graphrag/index/validate_config.py             |  14 +-
 graphrag/llm/__init__.py                      |  91 -----
 graphrag/llm/base/__init__.py                 |  10 -
 graphrag/llm/base/_create_cache_key.py        |  43 ---
 graphrag/llm/base/base_llm.py                 |  70 ----
 graphrag/llm/base/caching_llm.py              | 109 ------
 graphrag/llm/base/rate_limiting_llm.py        | 208 -----------
 graphrag/llm/errors.py                        |  12 -
 graphrag/llm/limiting/__init__.py             |  18 -
 graphrag/llm/limiting/composite_limiter.py    |  26 --
 graphrag/llm/limiting/create_limiters.py      |  29 --
 graphrag/llm/limiting/llm_limiter.py          |  19 -
 graphrag/llm/limiting/noop_llm_limiter.py     |  19 -
 graphrag/llm/limiting/tpm_rpm_limiter.py      |  34 --
 graphrag/llm/mock/__init__.py                 |  12 -
 graphrag/llm/mock/mock_chat_llm.py            |  52 ---
 graphrag/llm/mock/mock_completion_llm.py      |  42 ---
 graphrag/llm/openai/__init__.py               |  28 --
 graphrag/llm/openai/_prompts.py               |  39 --
 graphrag/llm/openai/create_openai_client.py   |  66 ----
 graphrag/llm/openai/factories.py              | 140 --------
 graphrag/llm/openai/json_parsing_llm.py       |  38 --
 graphrag/llm/openai/openai_chat_llm.py        | 150 --------
 graphrag/llm/openai/openai_completion_llm.py  |  43 ---
 graphrag/llm/openai/openai_configuration.py   |
288 --------------- graphrag/llm/openai/openai_embeddings_llm.py | 40 --- .../llm/openai/openai_history_tracking_llm.py | 42 --- .../llm/openai/openai_token_replacing_llm.py | 37 -- graphrag/llm/openai/types.py | 11 - graphrag/llm/openai/utils.py | 160 --------- graphrag/llm/types/__init__.py | 46 --- graphrag/llm/types/llm.py | 28 -- graphrag/llm/types/llm_cache.py | 22 -- graphrag/llm/types/llm_callbacks.py | 20 -- graphrag/llm/types/llm_config.py | 35 -- graphrag/llm/types/llm_invocation_result.py | 35 -- graphrag/llm/types/llm_io.py | 50 --- graphrag/llm/types/llm_types.py | 16 - .../generator/community_report_rating.py | 7 +- .../generator/community_reporter_role.py | 7 +- graphrag/prompt_tune/generator/domain.py | 7 +- .../generator/entity_relationship.py | 9 +- .../prompt_tune/generator/entity_types.py | 23 +- graphrag/prompt_tune/generator/language.py | 7 +- graphrag/prompt_tune/generator/persona.py | 9 +- graphrag/prompt_tune/loader/input.py | 17 +- .../query/context_builder/rate_relevancy.py | 3 +- graphrag/query/llm/get_client.py | 5 +- graphrag/query/llm/text_utils.py | 64 ++++ .../structured_search/global_search/search.py | 11 +- poetry.lock | 259 ++++++++------ pyproject.toml | 10 +- tests/unit/indexing/verbs/helpers/mock_llm.py | 9 +- tests/unit/llm/__init__.py | 2 - tests/unit/llm/base/__init__.py | 2 - tests/unit/llm/base/test_caching_llm.py | 73 ---- tests/unit/llm/openai/__init__.py | 2 - .../llm/openai/test_history_tracking_llm.py | 32 -- .../test_create_final_community_reports.py | 39 +- 77 files changed, 670 insertions(+), 2747 deletions(-) create mode 100644 .semversioner/next-release/patch-20241024210728482023.json create mode 100644 graphrag/index/llm/mock_llm.py delete mode 100644 graphrag/llm/__init__.py delete mode 100644 graphrag/llm/base/__init__.py delete mode 100644 graphrag/llm/base/_create_cache_key.py delete mode 100644 graphrag/llm/base/base_llm.py delete mode 100644 graphrag/llm/base/caching_llm.py delete mode 100644 graphrag/llm/base/rate_limiting_llm.py delete mode 100644 graphrag/llm/errors.py delete mode 100644 graphrag/llm/limiting/__init__.py delete mode 100644 graphrag/llm/limiting/composite_limiter.py delete mode 100644 graphrag/llm/limiting/create_limiters.py delete mode 100644 graphrag/llm/limiting/llm_limiter.py delete mode 100644 graphrag/llm/limiting/noop_llm_limiter.py delete mode 100644 graphrag/llm/limiting/tpm_rpm_limiter.py delete mode 100644 graphrag/llm/mock/__init__.py delete mode 100644 graphrag/llm/mock/mock_chat_llm.py delete mode 100644 graphrag/llm/mock/mock_completion_llm.py delete mode 100644 graphrag/llm/openai/__init__.py delete mode 100644 graphrag/llm/openai/_prompts.py delete mode 100644 graphrag/llm/openai/create_openai_client.py delete mode 100644 graphrag/llm/openai/factories.py delete mode 100644 graphrag/llm/openai/json_parsing_llm.py delete mode 100644 graphrag/llm/openai/openai_chat_llm.py delete mode 100644 graphrag/llm/openai/openai_completion_llm.py delete mode 100644 graphrag/llm/openai/openai_configuration.py delete mode 100644 graphrag/llm/openai/openai_embeddings_llm.py delete mode 100644 graphrag/llm/openai/openai_history_tracking_llm.py delete mode 100644 graphrag/llm/openai/openai_token_replacing_llm.py delete mode 100644 graphrag/llm/openai/types.py delete mode 100644 graphrag/llm/openai/utils.py delete mode 100644 graphrag/llm/types/__init__.py delete mode 100644 graphrag/llm/types/llm.py delete mode 100644 graphrag/llm/types/llm_cache.py delete mode 100644 graphrag/llm/types/llm_callbacks.py delete 
mode 100644 graphrag/llm/types/llm_config.py delete mode 100644 graphrag/llm/types/llm_invocation_result.py delete mode 100644 graphrag/llm/types/llm_io.py delete mode 100644 graphrag/llm/types/llm_types.py delete mode 100644 tests/unit/llm/__init__.py delete mode 100644 tests/unit/llm/base/__init__.py delete mode 100644 tests/unit/llm/base/test_caching_llm.py delete mode 100644 tests/unit/llm/openai/__init__.py delete mode 100644 tests/unit/llm/openai/test_history_tracking_llm.py diff --git a/.semversioner/next-release/patch-20241024210728482023.json b/.semversioner/next-release/patch-20241024210728482023.json new file mode 100644 index 0000000000..48a498bbc6 --- /dev/null +++ b/.semversioner/next-release/patch-20241024210728482023.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "replace llm package with fnllm" +} diff --git a/dictionary.txt b/dictionary.txt index e2ea99f021..0a18cd9272 100644 --- a/dictionary.txt +++ b/dictionary.txt @@ -68,6 +68,7 @@ pypi nbformat semversioner mkdocs +fnllm typer # Library Methods diff --git a/graphrag/api/prompt_tune.py b/graphrag/api/prompt_tune.py index 917727214c..7945704a0a 100644 --- a/graphrag/api/prompt_tune.py +++ b/graphrag/api/prompt_tune.py @@ -97,10 +97,9 @@ async def generate_indexing_prompts( # Create LLM from config llm = load_llm( "prompt_tuning", - config.llm.type, - NoopVerbCallbacks(), - None, - config.llm.model_dump(), + config.llm, + cache=None, + callbacks=NoopVerbCallbacks(), ) if not domain: diff --git a/graphrag/config/create_graphrag_config.py b/graphrag/config/create_graphrag_config.py index 9f9c239f42..53ce3f9f3a 100644 --- a/graphrag/config/create_graphrag_config.py +++ b/graphrag/config/create_graphrag_config.py @@ -702,9 +702,7 @@ class Section(str, Enum): def _is_azure(llm_type: LLMType | None) -> bool: return ( - llm_type == LLMType.AzureOpenAI - or llm_type == LLMType.AzureOpenAIChat - or llm_type == LLMType.AzureOpenAIEmbedding + llm_type == LLMType.AzureOpenAIChat or llm_type == LLMType.AzureOpenAIEmbedding ) diff --git a/graphrag/config/defaults.py b/graphrag/config/defaults.py index 9fbb0ea899..e81a3758a0 100644 --- a/graphrag/config/defaults.py +++ b/graphrag/config/defaults.py @@ -20,9 +20,11 @@ ASYNC_MODE = AsyncType.Threaded ENCODING_MODEL = "cl100k_base" +AZURE_AUDIENCE = "https://cognitiveservices.azure.com/.default" # # LLM Parameters # +LLM_FREQUENCY_PENALTY = 0.0 LLM_TYPE = LLMType.OpenAIChat LLM_MODEL = "gpt-4-turbo-preview" LLM_MAX_TOKENS = 4000 @@ -34,6 +36,7 @@ LLM_REQUESTS_PER_MINUTE = 0 LLM_MAX_RETRIES = 10 LLM_MAX_RETRY_WAIT = 10.0 +LLM_PRESENCE_PENALTY = 0.0 LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION = True LLM_CONCURRENT_REQUESTS = 25 diff --git a/graphrag/config/enums.py b/graphrag/config/enums.py index 99d385dff0..4e27833395 100644 --- a/graphrag/config/enums.py +++ b/graphrag/config/enums.py @@ -100,10 +100,6 @@ class LLMType(str, Enum): OpenAIEmbedding = "openai_embedding" AzureOpenAIEmbedding = "azure_openai_embedding" - # Raw Completion - OpenAI = "openai" - AzureOpenAI = "azure_openai" - # Chat Completion OpenAIChat = "openai_chat" AzureOpenAIChat = "azure_openai_chat" diff --git a/graphrag/config/models/llm_parameters.py b/graphrag/config/models/llm_parameters.py index 4f18ded06f..300498ae76 100644 --- a/graphrag/config/models/llm_parameters.py +++ b/graphrag/config/models/llm_parameters.py @@ -20,7 +20,13 @@ class LLMParameters(BaseModel): type: LLMType = Field( description="The type of LLM model to use.", default=defs.LLM_TYPE ) + encoding_model: str | None = Field( + 
description="The encoding model to use", default=defs.ENCODING_MODEL + ) model: str = Field(description="The LLM model to use.", default=defs.LLM_MODEL) + embeddings_model: str | None = Field( + description="The embeddings model to use.", default=defs.EMBEDDING_MODEL + ) max_tokens: int | None = Field( description="The maximum number of tokens to generate.", default=defs.LLM_MAX_TOKENS, @@ -37,6 +43,14 @@ class LLMParameters(BaseModel): description="The number of completions to generate.", default=defs.LLM_N, ) + frequency_penalty: float | None = Field( + description="The frequency penalty to use for token generation.", + default=defs.LLM_FREQUENCY_PENALTY, + ) + presence_penalty: float | None = Field( + description="The presence penalty to use for token generation.", + default=defs.LLM_PRESENCE_PENALTY, + ) request_timeout: float = Field( description="The request timeout to use.", default=defs.LLM_REQUEST_TIMEOUT ) @@ -86,3 +100,6 @@ class LLMParameters(BaseModel): description="Whether to use concurrent requests for the LLM service.", default=defs.LLM_CONCURRENT_REQUESTS, ) + responses: list[str | BaseModel] | None = Field( + default=None, description="Static responses to use in mock mode." + ) diff --git a/graphrag/index/graph/extractors/claims/claim_extractor.py b/graphrag/index/graph/extractors/claims/claim_extractor.py index 2842ad7e1a..66162f8f12 100644 --- a/graphrag/index/graph/extractors/claims/claim_extractor.py +++ b/graphrag/index/graph/extractors/claims/claim_extractor.py @@ -9,10 +9,10 @@ from typing import Any import tiktoken +from fnllm import ChatLLM import graphrag.config.defaults as defs from graphrag.index.typing import ErrorHandlerFn -from graphrag.llm import CompletionLLM from graphrag.prompts.index.claim_extraction import ( CLAIM_EXTRACTION_PROMPT, CONTINUE_PROMPT, @@ -36,7 +36,7 @@ class ClaimExtractorResult: class ClaimExtractor: """Claim extractor class definition.""" - _llm: CompletionLLM + _llm: ChatLLM _extraction_prompt: str _summary_prompt: str _output_formatter_prompt: str @@ -48,10 +48,11 @@ class ClaimExtractor: _completion_delimiter_key: str _max_gleanings: int _on_error: ErrorHandlerFn + _loop_args: dict[str, Any] def __init__( self, - llm_invoker: CompletionLLM, + llm_invoker: ChatLLM, extraction_prompt: str | None = None, input_text_key: str | None = None, input_entity_spec_key: str | None = None, @@ -87,9 +88,9 @@ def __init__( # Construct the looping arguments encoding = tiktoken.get_encoding(encoding_model or "cl100k_base") - yes = encoding.encode("YES") - no = encoding.encode("NO") - self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1} + yes = f"{encoding.encode('YES')[0]}" + no = f"{encoding.encode('NO')[0]}" + self._loop_args = {"logit_bias": {yes: 100, no: 100}, "max_tokens": 1} async def __call__( self, inputs: dict[str, Any], prompt_variables: dict | None = None @@ -164,13 +165,12 @@ async def _process_document( ) response = await self._llm( - self._extraction_prompt, - variables={ + self._extraction_prompt.format(**{ self._input_text_key: doc, **prompt_args, - }, + }) ) - results = response.output or "" + results = response.output.content or "" claims = results.strip().removesuffix(completion_delimiter) # Repeat to ensure we maximize entity count @@ -180,7 +180,7 @@ async def _process_document( name=f"extract-continuation-{i}", history=response.history, ) - extension = response.output or "" + extension = response.output.content or "" claims += record_delimiter + extension.strip().removesuffix( completion_delimiter ) @@ 
-195,7 +195,7 @@ async def _process_document( history=response.history, model_parameters=self._loop_args, ) - if response.output != "YES": + if response.output.content != "YES": break return self._parse_claim_tuples(results, prompt_args) diff --git a/graphrag/index/graph/extractors/community_reports/community_reports_extractor.py b/graphrag/index/graph/extractors/community_reports/community_reports_extractor.py index a78064bd9b..7fa0b684fd 100644 --- a/graphrag/index/graph/extractors/community_reports/community_reports_extractor.py +++ b/graphrag/index/graph/extractors/community_reports/community_reports_extractor.py @@ -8,26 +8,50 @@ from dataclasses import dataclass from typing import Any +from fnllm import ChatLLM +from pydantic import BaseModel, Field + from graphrag.index.typing import ErrorHandlerFn -from graphrag.index.utils.dicts import dict_has_keys_with_types -from graphrag.llm import CompletionLLM from graphrag.prompts.index.community_report import COMMUNITY_REPORT_PROMPT log = logging.getLogger(__name__) +class FindingModel(BaseModel): + """A model for the expected LLM response shape.""" + + summary: str = Field(description="The summary of the finding.") + explanation: str = Field(description="An explanation of the finding.") + + +class CommunityReportResponse(BaseModel): + """A model for the expected LLM response shape.""" + + title: str = Field(description="The title of the report.") + summary: str = Field(description="A summary of the report.") + findings: list[FindingModel] = Field( + description="A list of findings in the report." + ) + rating: float = Field(description="The rating of the report.") + rating_explanation: str = Field(description="An explanation of the rating.") + + extra_attributes: dict[str, Any] = Field( + default_factory=dict, description="Extra attributes." 
+ ) + + @dataclass class CommunityReportsResult: """Community reports result class definition.""" output: str - structured_output: dict + structured_output: CommunityReportResponse | None class CommunityReportsExtractor: """Community reports extractor class definition.""" - _llm: CompletionLLM + _llm: ChatLLM _input_text_key: str _extraction_prompt: str _output_formatter_prompt: str @@ -36,7 +60,7 @@ class CommunityReportsExtractor: def __init__( self, - llm_invoker: CompletionLLM, + llm_invoker: ChatLLM, input_text_key: str | None = None, extraction_prompt: str | None = None, on_error: ErrorHandlerFn | None = None, @@ -53,55 +77,30 @@ async def __call__(self, inputs: dict[str, Any]): """Call method definition.""" output = None try: - response = ( - await self._llm( - self._extraction_prompt, - json=True, - name="create_community_report", - variables={self._input_text_key: inputs[self._input_text_key]}, - is_response_valid=lambda x: dict_has_keys_with_types( - x, - [ - ("title", str), - ("summary", str), - ("findings", list), - ("rating", float), - ("rating_explanation", str), - ], - inplace=True, - ), - model_parameters={"max_tokens": self._max_report_length}, - ) - or {} + input_text = inputs[self._input_text_key] + prompt = self._extraction_prompt.replace( + "{" + self._input_text_key + "}", input_text + ) + response = await self._llm( + prompt, + json=True, + name="create_community_report", + json_model=CommunityReportResponse, + model_parameters={"max_tokens": self._max_report_length}, ) - output = response.json or {} + output = response.parsed_json except Exception as e: log.exception("error generating community report") self._on_error(e, traceback.format_exc(), None) - output = {} - text_output = self._get_text_output(output) + text_output = self._get_text_output(output) if output else "" return CommunityReportsResult( structured_output=output, output=text_output, ) - def _get_text_output(self, parsed_output: dict) -> str: - title = parsed_output.get("title", "Report") - summary = parsed_output.get("summary", "") - findings = parsed_output.get("findings", []) - - def finding_summary(finding: dict): - if isinstance(finding, str): - return finding - return finding.get("summary") - - def finding_explanation(finding: dict): - if isinstance(finding, str): - return "" - return finding.get("explanation") - + def _get_text_output(self, report: CommunityReportResponse) -> str: report_sections = "\n\n".join( - f"## {finding_summary(f)}\n\n{finding_explanation(f)}" for f in findings + f"## {f.summary}\n\n{f.explanation}" for f in report.findings ) - return f"# {title}\n\n{summary}\n\n{report_sections}" + return f"# {report.title}\n\n{report.summary}\n\n{report_sections}" diff --git a/graphrag/index/graph/extractors/graph/graph_extractor.py b/graphrag/index/graph/extractors/graph/graph_extractor.py index 7374e77c24..fb28f528c9 100644 --- a/graphrag/index/graph/extractors/graph/graph_extractor.py +++ b/graphrag/index/graph/extractors/graph/graph_extractor.py @@ -12,11 +12,11 @@ import networkx as nx import tiktoken +from fnllm import ChatLLM import graphrag.config.defaults as defs from graphrag.index.typing import ErrorHandlerFn from graphrag.index.utils.string import clean_str -from graphrag.llm import CompletionLLM from graphrag.prompts.index.entity_extraction import ( CONTINUE_PROMPT, GRAPH_EXTRACTION_PROMPT, @@ -40,7 +40,7 @@ class GraphExtractionResult: class GraphExtractor: """Unipartite graph extractor class definition.""" - _llm: CompletionLLM + _llm: ChatLLM _join_descriptions: bool 
_tuple_delimiter_key: str _record_delimiter_key: str @@ -57,7 +57,7 @@ class GraphExtractor: def __init__( self, - llm_invoker: CompletionLLM, + llm_invoker: ChatLLM, tuple_delimiter_key: str | None = None, record_delimiter_key: str | None = None, input_text_key: str | None = None, @@ -90,9 +90,9 @@ def __init__( # Construct the looping arguments encoding = tiktoken.get_encoding(encoding_model or "cl100k_base") - yes = encoding.encode("YES") - no = encoding.encode("NO") - self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1} + yes = f"{encoding.encode('YES')[0]}" + no = f"{encoding.encode('NO')[0]}" + self._loop_args = {"logit_bias": {yes: 100, no: 100}, "max_tokens": 1} async def __call__( self, texts: list[str], prompt_variables: dict[str, Any] | None = None @@ -151,13 +151,12 @@ async def _process_document( self, text: str, prompt_variables: dict[str, str] ) -> str: response = await self._llm( - self._extraction_prompt, - variables={ + self._extraction_prompt.format(**{ **prompt_variables, self._input_text_key: text, - }, + }), ) - results = response.output or "" + results = response.output.content or "" # Repeat to ensure we maximize entity count for i in range(self._max_gleanings): @@ -166,7 +165,7 @@ async def _process_document( name=f"extract-continuation-{i}", history=response.history, ) - results += response.output or "" + results += response.output.content or "" # if this is the final glean, don't bother updating the continuation flag if i >= self._max_gleanings - 1: diff --git a/graphrag/index/graph/extractors/summarize/description_summary_extractor.py b/graphrag/index/graph/extractors/summarize/description_summary_extractor.py index 8459497fd5..ee59c649c6 100644 --- a/graphrag/index/graph/extractors/summarize/description_summary_extractor.py +++ b/graphrag/index/graph/extractors/summarize/description_summary_extractor.py @@ -6,9 +6,10 @@ import json from dataclasses import dataclass +from fnllm import ChatLLM + from graphrag.index.typing import ErrorHandlerFn from graphrag.index.utils.tokens import num_tokens_from_string -from graphrag.llm import CompletionLLM from graphrag.prompts.index.summarize_descriptions import SUMMARIZE_PROMPT # Max token size for input prompts @@ -28,7 +29,7 @@ class SummarizationResult: class SummarizeExtractor: """Unipartite graph extractor class definition.""" - _llm: CompletionLLM + _llm: ChatLLM _entity_name_key: str _input_descriptions_key: str _summarization_prompt: str @@ -38,7 +39,7 @@ class SummarizeExtractor: def __init__( self, - llm_invoker: CompletionLLM, + llm_invoker: ChatLLM, entity_name_key: str | None = None, input_descriptions_key: str | None = None, summarization_prompt: str | None = None, @@ -126,15 +127,14 @@ async def _summarize_descriptions_with_llm( ): """Summarize descriptions using the LLM.""" response = await self._llm( - self._summarization_prompt, - name="summarize", - variables={ + self._summarization_prompt.format(**{ self._entity_name_key: json.dumps(id, ensure_ascii=False), self._input_descriptions_key: json.dumps( sorted(descriptions), ensure_ascii=False ), - }, + }), + name="summarize", model_parameters={"max_tokens": self._max_summary_length}, ) # Calculate result - return str(response.output) + return str(response.output.content) diff --git a/graphrag/index/llm/load_llm.py b/graphrag/index/llm/load_llm.py index 4354bd3f3c..07b774c434 100644 --- a/graphrag/index/llm/load_llm.py +++ b/graphrag/index/llm/load_llm.py @@ -5,24 +5,27 @@ from __future__ import annotations -import asyncio import 
logging from typing import TYPE_CHECKING, Any -from graphrag.config.enums import LLMType -from graphrag.llm import ( - CompletionLLM, - EmbeddingLLM, - LLMCache, - LLMLimiter, - MockCompletionLLM, - OpenAIConfiguration, +from fnllm import ChatLLM, EmbeddingsLLM, JsonStrategy, LLMEvents +from fnllm.caching import Cache as LLMCache +from fnllm.openai import ( + AzureOpenAIConfig, + OpenAIConfig, + PublicOpenAIConfig, create_openai_chat_llm, create_openai_client, - create_openai_completion_llm, - create_openai_embedding_llm, - create_tpm_rpm_limiters, + create_openai_embeddings_llm, ) +from fnllm.openai.types.chat.parameters import OpenAIChatParameters +from pydantic import TypeAdapter + +import graphrag.config.defaults as defs +from graphrag.config.enums import LLMType +from graphrag.config.models.llm_parameters import LLMParameters + +from .mock_llm import MockChatLLM if TYPE_CHECKING: from datashaper import VerbCallbacks @@ -32,30 +35,91 @@ log = logging.getLogger(__name__) -_semaphores: dict[str, asyncio.Semaphore] = {} -_rate_limiters: dict[str, LLMLimiter] = {} + +class GraphRagLLMEvents(LLMEvents): + """LLM events handler that calls the error handler.""" + + def __init__(self, on_error: ErrorHandlerFn): + self._on_error = on_error + + async def on_error( + self, + error: BaseException | None, + traceback: str | None = None, + arguments: dict[str, Any] | None = None, + ) -> None: + """Handle an fnllm error.""" + self._on_error(error, traceback, arguments) + + +class GraphRagLLMCache(LLMCache): + """A cache for the pipeline.""" + + def __init__(self, cache: PipelineCache): + self._cache = cache + + async def has(self, key: str) -> bool: + """Check if the cache has a value.""" + return await self._cache.has(key) + + async def get(self, key: str) -> Any | None: + """Retrieve a value from the cache.""" + return await self._cache.get(key) + + async def set( + self, key: str, value: Any, metadata: dict[str, Any] | None = None + ) -> None: + """Write a value into the cache.""" + await self._cache.set(key, value, metadata) + + async def remove(self, key: str) -> None: + """Remove a value from the cache.""" + await self._cache.delete(key) + + async def clear(self) -> None: + """Clear the cache.""" + await self._cache.clear() + + def child(self, key: str): + """Create a child cache.""" + child_cache = self._cache.child(key) + return GraphRagLLMCache(child_cache) + + +def create_cache(cache: PipelineCache | None, name: str) -> LLMCache | None: + """Create an LLM cache from a pipeline cache.""" + if cache is None: + return None + return GraphRagLLMCache(cache).child(name) + + +def read_llm_params(llm_args: dict[str, Any]) -> LLMParameters: + """Read the LLM parameters from the arguments.""" + if llm_args == {}: + msg = "LLM arguments are required" + raise ValueError(msg) + return TypeAdapter(LLMParameters).validate_python(llm_args) def load_llm( name: str, - llm_type: LLMType, + config: LLMParameters, + *, callbacks: VerbCallbacks, cache: PipelineCache | None, - llm_config: dict[str, Any] | None = None, chat_only=False, -) -> CompletionLLM: +) -> ChatLLM: """Load the LLM for the entity extraction chain.""" on_error = _create_error_handler(callbacks) + llm_type = config.type if llm_type in loaders: if chat_only and not loaders[llm_type]["chat"]: msg = f"LLM type {llm_type} does not support chat" raise ValueError(msg) - if cache is not None: - cache = cache.child(name) loader = loaders[llm_type] - return loader["load"](on_error, cache, llm_config or {}) + return loader["load"](on_error, 
create_cache(cache, name), config) msg = f"Unknown LLM type {llm_type}" raise ValueError(msg) @@ -63,21 +127,22 @@ def load_llm( def load_llm_embeddings( name: str, - llm_type: LLMType, + llm_config: LLMParameters, + *, callbacks: VerbCallbacks, cache: PipelineCache | None, - llm_config: dict[str, Any] | None = None, chat_only=False, -) -> EmbeddingLLM: +) -> EmbeddingsLLM: """Load the LLM for the entity extraction chain.""" on_error = _create_error_handler(callbacks) + llm_type = llm_config.type if llm_type in loaders: if chat_only and not loaders[llm_type]["chat"]: msg = f"LLM type {llm_type} does not support chat" raise ValueError(msg) - if cache is not None: - cache = cache.child(name) - return loaders[llm_type]["load"](on_error, cache, llm_config or {}) + return loaders[llm_type]["load"]( + on_error, create_cache(cache, name), llm_config or {} + ) msg = f"Unknown LLM type {llm_type}" raise ValueError(msg) @@ -94,130 +159,108 @@ def on_error( return on_error -def _load_openai_completion_llm( - on_error: ErrorHandlerFn, - cache: LLMCache, - config: dict[str, Any], - azure=False, -): - return _create_openai_completion_llm( - OpenAIConfiguration({ - **_get_base_config(config), - "model": config.get("model", "gpt-4-turbo-preview"), - "deployment_name": config.get("deployment_name"), - "temperature": config.get("temperature", 0.0), - "frequency_penalty": config.get("frequency_penalty", 0), - "presence_penalty": config.get("presence_penalty", 0), - "top_p": config.get("top_p", 1), - "max_tokens": config.get("max_tokens", 4000), - "n": config.get("n"), - }), - on_error, - cache, - azure, - ) - - def _load_openai_chat_llm( on_error: ErrorHandlerFn, cache: LLMCache, - config: dict[str, Any], + config: LLMParameters, azure=False, ): return _create_openai_chat_llm( - OpenAIConfiguration({ - # Set default values - **_get_base_config(config), - "model": config.get("model", "gpt-4-turbo-preview"), - "deployment_name": config.get("deployment_name"), - "temperature": config.get("temperature", 0.0), - "frequency_penalty": config.get("frequency_penalty", 0), - "presence_penalty": config.get("presence_penalty", 0), - "top_p": config.get("top_p", 1), - "max_tokens": config.get("max_tokens"), - "n": config.get("n"), - }), + _create_openai_config(config, azure), on_error, cache, - azure, ) def _load_openai_embeddings_llm( on_error: ErrorHandlerFn, cache: LLMCache, - config: dict[str, Any], + config: LLMParameters, azure=False, ): - # TODO: Inject Cache return _create_openai_embeddings_llm( - OpenAIConfiguration({ - **_get_base_config(config), - "model": config.get( - "embeddings_model", config.get("model", "text-embedding-3-small") - ), - "deployment_name": config.get("deployment_name"), - }), + _create_openai_config(config, azure), on_error, cache, - azure, ) -def _load_azure_openai_completion_llm( - on_error: ErrorHandlerFn, cache: LLMCache, config: dict[str, Any] -): - return _load_openai_completion_llm(on_error, cache, config, True) +def _create_openai_config(config: LLMParameters, azure: bool) -> OpenAIConfig: + encoding_model = config.encoding_model or defs.ENCODING_MODEL + json_strategy = ( + JsonStrategy.VALID if config.model_supports_json else JsonStrategy.LOOSE + ) + chat_parameters = OpenAIChatParameters( + frequency_penalty=config.frequency_penalty, + presence_penalty=config.presence_penalty, + top_p=config.top_p, + max_tokens=config.max_tokens, + n=config.n, + temperature=config.temperature, + ) + if azure: + if config.api_base is None: + msg = "Azure OpenAI Chat LLM requires an API base" + 
raise ValueError(msg) + + audience = config.audience or defs.AZURE_AUDIENCE + return AzureOpenAIConfig( + api_key=config.api_key, + endpoint=config.api_base, + json_strategy=json_strategy, + api_version=config.api_version, + organization=config.organization, + max_retries=config.max_retries, + max_retry_wait=config.max_retry_wait, + requests_per_minute=config.requests_per_minute, + tokens_per_minute=config.tokens_per_minute, + cognitive_services_endpoint=audience, + timeout=config.request_timeout, + max_concurrency=config.concurrent_requests, + model=config.model, + encoding=encoding_model, + deployment=config.deployment_name, + chat_parameters=chat_parameters, + ) + return PublicOpenAIConfig( + api_key=config.api_key, + base_url=config.proxy, + json_strategy=json_strategy, + organization=config.organization, + max_retries=config.max_retries, + max_retry_wait=config.max_retry_wait, + requests_per_minute=config.requests_per_minute, + tokens_per_minute=config.tokens_per_minute, + timeout=config.request_timeout, + max_concurrency=config.concurrent_requests, + model=config.model, + encoding=encoding_model, + chat_parameters=chat_parameters, + ) def _load_azure_openai_chat_llm( - on_error: ErrorHandlerFn, cache: LLMCache, config: dict[str, Any] + on_error: ErrorHandlerFn, cache: LLMCache, config: LLMParameters ): return _load_openai_chat_llm(on_error, cache, config, True) def _load_azure_openai_embeddings_llm( - on_error: ErrorHandlerFn, cache: LLMCache, config: dict[str, Any] + on_error: ErrorHandlerFn, cache: LLMCache, config: LLMParameters ): return _load_openai_embeddings_llm(on_error, cache, config, True) -def _get_base_config(config: dict[str, Any]) -> dict[str, Any]: - api_key = config.get("api_key") - - return { - # Pass in all parameterized values - **config, - # Set default values - "api_key": api_key, - "api_base": config.get("api_base"), - "api_version": config.get("api_version"), - "organization": config.get("organization"), - "proxy": config.get("proxy"), - "max_retries": config.get("max_retries", 10), - "request_timeout": config.get("request_timeout", 60.0), - "model_supports_json": config.get("model_supports_json"), - "concurrent_requests": config.get("concurrent_requests", 4), - "encoding_model": config.get("encoding_model", "cl100k_base"), - "audience": config.get("audience"), - } - - def _load_static_response( - _on_error: ErrorHandlerFn, _cache: PipelineCache, config: dict[str, Any] -) -> CompletionLLM: - return MockCompletionLLM(config.get("responses", [])) + _on_error: ErrorHandlerFn, _cache: PipelineCache, config: LLMParameters +) -> ChatLLM: + if config.responses is None: + msg = "Static response LLM requires responses" + raise ValueError(msg) + return MockChatLLM(config.responses or []) loaders = { - LLMType.OpenAI: { - "load": _load_openai_completion_llm, - "chat": False, - }, - LLMType.AzureOpenAI: { - "load": _load_azure_openai_completion_llm, - "chat": False, - }, LLMType.OpenAIChat: { "load": _load_openai_chat_llm, "chat": True, @@ -242,71 +285,30 @@ def _load_static_response( def _create_openai_chat_llm( - configuration: OpenAIConfiguration, + configuration: OpenAIConfig, on_error: ErrorHandlerFn, cache: LLMCache, - azure=False, -) -> CompletionLLM: +) -> ChatLLM: """Create an openAI chat llm.""" - client = create_openai_client(configuration=configuration, azure=azure) - limiter = _create_limiter(configuration) - semaphore = _create_semaphore(configuration) + client = create_openai_client(configuration) return create_openai_chat_llm( - client, configuration, 
cache, limiter, semaphore, on_error=on_error - ) - - -def _create_openai_completion_llm( - configuration: OpenAIConfiguration, - on_error: ErrorHandlerFn, - cache: LLMCache, - azure=False, -) -> CompletionLLM: - """Create an openAI completion llm.""" - client = create_openai_client(configuration=configuration, azure=azure) - limiter = _create_limiter(configuration) - semaphore = _create_semaphore(configuration) - return create_openai_completion_llm( - client, configuration, cache, limiter, semaphore, on_error=on_error + configuration, + client=client, + cache=cache, + events=GraphRagLLMEvents(on_error), ) def _create_openai_embeddings_llm( - configuration: OpenAIConfiguration, + configuration: OpenAIConfig, on_error: ErrorHandlerFn, cache: LLMCache, - azure=False, -) -> EmbeddingLLM: +) -> EmbeddingsLLM: """Create an openAI embeddings llm.""" - client = create_openai_client(configuration=configuration, azure=azure) - limiter = _create_limiter(configuration) - semaphore = _create_semaphore(configuration) - return create_openai_embedding_llm( - client, configuration, cache, limiter, semaphore, on_error=on_error + client = create_openai_client(configuration) + return create_openai_embeddings_llm( + configuration, + client=client, + cache=cache, + events=GraphRagLLMEvents(on_error), ) - - -def _create_limiter(configuration: OpenAIConfiguration) -> LLMLimiter: - limit_name = configuration.model or configuration.deployment_name or "default" - if limit_name not in _rate_limiters: - tpm = configuration.tokens_per_minute - rpm = configuration.requests_per_minute - log.info("create TPM/RPM limiter for %s: TPM=%s, RPM=%s", limit_name, tpm, rpm) - _rate_limiters[limit_name] = create_tpm_rpm_limiters(configuration) - return _rate_limiters[limit_name] - - -def _create_semaphore(configuration: OpenAIConfiguration) -> asyncio.Semaphore | None: - limit_name = configuration.model or configuration.deployment_name or "default" - concurrency = configuration.concurrent_requests - - # bypass the semaphore if concurrency is zero - if not concurrency: - log.info("no concurrency limiter for %s", limit_name) - return None - - if limit_name not in _semaphores: - log.info("create concurrency limiter for %s: %s", limit_name, concurrency) - _semaphores[limit_name] = asyncio.Semaphore(concurrency) - - return _semaphores[limit_name] diff --git a/graphrag/index/llm/mock_llm.py b/graphrag/index/llm/mock_llm.py new file mode 100644 index 0000000000..eeae24df00 --- /dev/null +++ b/graphrag/index/llm/mock_llm.py @@ -0,0 +1,45 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License +"""A mock LLM that returns the given responses.""" + +from dataclasses import dataclass +from typing import Any, cast + +from fnllm import ChatLLM, LLMInput, LLMOutput +from fnllm.types.generics import THistoryEntry, TJsonModel, TModelParameters +from pydantic import BaseModel +from typing_extensions import Unpack + + +@dataclass +class ContentResponse: + """A mock content-only response.""" + + content: str + + +class MockChatLLM(ChatLLM): + """A mock LLM that returns the given responses.""" + + def __init__(self, responses: list[str | BaseModel], json: bool = False): + self.responses = responses + self.response_index = 0 + + async def __call__( + self, + prompt: str, + **kwargs: Unpack[LLMInput[TJsonModel, THistoryEntry, TModelParameters]], + ) -> LLMOutput[Any, TJsonModel, THistoryEntry]: + """Return the next response in the list.""" + response = self.responses[self.response_index % len(self.responses)] + self.response_index += 1 + + parsed_json = response if isinstance(response, BaseModel) else None + response = ( + response.model_dump_json() if isinstance(response, BaseModel) else response + ) + + return LLMOutput( + output=ContentResponse(content=response), + parsed_json=cast(TJsonModel, parsed_json), + ) diff --git a/graphrag/index/operations/embed_text/strategies/openai.py b/graphrag/index/operations/embed_text/strategies/openai.py index 15ce0a1ff9..36be774203 100644 --- a/graphrag/index/operations/embed_text/strategies/openai.py +++ b/graphrag/index/operations/embed_text/strategies/openai.py @@ -9,14 +9,16 @@ import numpy as np from datashaper import ProgressTicker, VerbCallbacks, progress_ticker +from fnllm import EmbeddingsLLM +from pydantic import TypeAdapter import graphrag.config.defaults as defs from graphrag.cache.pipeline_cache import PipelineCache +from graphrag.config.models.llm_parameters import LLMParameters from graphrag.index.llm.load_llm import load_llm_embeddings from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingResult from graphrag.index.text_splitting.text_splitting import TokenTextSplitter from graphrag.index.utils.is_null import is_null -from graphrag.llm import EmbeddingLLM, OpenAIConfiguration log = logging.getLogger(__name__) @@ -31,12 +33,11 @@ async def run( if is_null(input): return TextEmbeddingResult(embeddings=None) - llm_config = args.get("llm", {}) + llm_config = TypeAdapter(LLMParameters).validate_python(args.get("llm", {})) batch_size = args.get("batch_size", 16) batch_max_tokens = args.get("batch_max_tokens", 8191) - oai_config = OpenAIConfiguration(llm_config) - splitter = _get_splitter(oai_config, batch_max_tokens) - llm = _get_llm(oai_config, callbacks, cache) + splitter = _get_splitter(llm_config, batch_max_tokens) + llm = _get_llm(llm_config, callbacks, cache) semaphore: asyncio.Semaphore = asyncio.Semaphore(args.get("num_threads", 4)) # Break up the input texts. 
The sizes here indicate how many snippets are in each input text @@ -64,9 +65,7 @@ async def run( return TextEmbeddingResult(embeddings=embeddings) -def _get_splitter( - config: OpenAIConfiguration, batch_max_tokens: int -) -> TokenTextSplitter: +def _get_splitter(config: LLMParameters, batch_max_tokens: int) -> TokenTextSplitter: return TokenTextSplitter( encoding_name=config.encoding_model or defs.ENCODING_MODEL, chunk_size=batch_max_tokens, @@ -74,22 +73,20 @@ def _get_splitter( def _get_llm( - config: OpenAIConfiguration, + config: LLMParameters, callbacks: VerbCallbacks, cache: PipelineCache, -) -> EmbeddingLLM: - llm_type = config.lookup("type", "Unknown") +) -> EmbeddingsLLM: return load_llm_embeddings( "text_embedding", - llm_type, - callbacks, - cache, - config.raw_config, + config, + callbacks=callbacks, + cache=cache, ) async def _execute( - llm: EmbeddingLLM, + llm: EmbeddingsLLM, chunks: list[list[str]], tick: ProgressTicker, semaphore: asyncio.Semaphore, @@ -97,7 +94,7 @@ async def _execute( async def embed(chunk: list[str]): async with semaphore: chunk_embeddings = await llm(chunk) - result = np.array(chunk_embeddings.output) + result = np.array(chunk_embeddings.output.embeddings) tick(1) return result diff --git a/graphrag/index/operations/extract_covariates/strategies.py b/graphrag/index/operations/extract_covariates/strategies.py index 46d0ca0c2b..edf4c3a670 100644 --- a/graphrag/index/operations/extract_covariates/strategies.py +++ b/graphrag/index/operations/extract_covariates/strategies.py @@ -7,16 +7,16 @@ from typing import Any from datashaper import VerbCallbacks +from fnllm import ChatLLM import graphrag.config.defaults as defs from graphrag.cache.pipeline_cache import PipelineCache from graphrag.index.graph.extractors.claims import ClaimExtractor -from graphrag.index.llm.load_llm import load_llm +from graphrag.index.llm.load_llm import load_llm, read_llm_params from graphrag.index.operations.extract_covariates.typing import ( Covariate, CovariateExtractionResult, ) -from graphrag.llm import CompletionLLM async def run_graph_intelligence( @@ -28,16 +28,15 @@ async def run_graph_intelligence( strategy_config: dict[str, Any], ) -> CovariateExtractionResult: """Run the Claim extraction chain.""" - llm_config = strategy_config.get("llm", {}) - llm_type = llm_config.get("type") - llm = load_llm("claim_extraction", llm_type, callbacks, cache, llm_config) + llm_config = read_llm_params(strategy_config.get("llm", {})) + llm = load_llm("claim_extraction", llm_config, callbacks=callbacks, cache=cache) return await _execute( llm, input, entity_types, resolved_entities_map, callbacks, strategy_config ) async def _execute( - llm: CompletionLLM, + llm: ChatLLM, texts: Iterable[str], entity_types: list[str], resolved_entities_map: dict[str, str], diff --git a/graphrag/index/operations/extract_entities/strategies/graph_intelligence.py b/graphrag/index/operations/extract_entities/strategies/graph_intelligence.py index eaca006825..308ff29d05 100644 --- a/graphrag/index/operations/extract_entities/strategies/graph_intelligence.py +++ b/graphrag/index/operations/extract_entities/strategies/graph_intelligence.py @@ -5,11 +5,12 @@ import networkx as nx from datashaper import VerbCallbacks +from fnllm import ChatLLM import graphrag.config.defaults as defs from graphrag.cache.pipeline_cache import PipelineCache from graphrag.index.graph.extractors import GraphExtractor -from graphrag.index.llm.load_llm import load_llm +from graphrag.index.llm.load_llm import load_llm, read_llm_params from 
graphrag.index.operations.extract_entities.strategies.typing import ( Document, EntityExtractionResult, @@ -21,7 +22,6 @@ TextSplitter, TokenTextSplitter, ) -from graphrag.llm import CompletionLLM async def run_graph_intelligence( @@ -32,14 +32,13 @@ async def run_graph_intelligence( args: StrategyConfig, ) -> EntityExtractionResult: """Run the graph intelligence entity extraction strategy.""" - llm_config = args.get("llm", {}) - llm_type = llm_config.get("type") - llm = load_llm("entity_extraction", llm_type, callbacks, cache, llm_config) + llm_config = read_llm_params(args.get("llm", {})) + llm = load_llm("entity_extraction", llm_config, callbacks=callbacks, cache=cache) return await run_extract_entities(llm, docs, entity_types, callbacks, args) async def run_extract_entities( - llm: CompletionLLM, + llm: ChatLLM, docs: list[Document], entity_types: EntityTypes, callbacks: VerbCallbacks | None, diff --git a/graphrag/index/operations/summarize_communities/strategies.py b/graphrag/index/operations/summarize_communities/strategies.py index 4c3d7a3fe4..8142572b6d 100644 --- a/graphrag/index/operations/summarize_communities/strategies.py +++ b/graphrag/index/operations/summarize_communities/strategies.py @@ -3,23 +3,23 @@ """A module containing run, _run_extractor and _load_nodes_edges_for_claim_chain methods definition.""" -import json import logging import traceback from datashaper import VerbCallbacks +from fnllm import ChatLLM from graphrag.cache.pipeline_cache import PipelineCache from graphrag.index.graph.extractors.community_reports import ( CommunityReportsExtractor, ) -from graphrag.index.llm.load_llm import load_llm +from graphrag.index.llm.load_llm import load_llm, read_llm_params from graphrag.index.operations.summarize_communities.typing import ( CommunityReport, + Finding, StrategyConfig, ) from graphrag.index.utils.rate_limiter import RateLimiter -from graphrag.llm import CompletionLLM DEFAULT_CHUNK_SIZE = 3000 @@ -35,14 +35,13 @@ async def run_graph_intelligence( args: StrategyConfig, ) -> CommunityReport | None: """Run the graph intelligence entity extraction strategy.""" - llm_config = args.get("llm", {}) - llm_type = llm_config.get("type") - llm = load_llm("community_reporting", llm_type, callbacks, cache, llm_config) + llm_config = read_llm_params(args.get("llm", {})) + llm = load_llm("community_reporting", llm_config, callbacks=callbacks, cache=cache) return await _run_extractor(llm, community, input, level, args, callbacks) async def _run_extractor( - llm: CompletionLLM, + llm: ChatLLM, community: str | int, input: str, level: int, @@ -64,7 +63,7 @@ async def _run_extractor( await rate_limiter.acquire() results = await extractor({"input_text": input}) report = results.structured_output - if report is None or len(report.keys()) == 0: + if report is None: log.warning("No report found for community: %s", community) return None @@ -72,23 +71,17 @@ async def _run_extractor( community=community, full_content=results.output, level=level, - rank=_parse_rank(report), - title=report.get("title", f"Community Report: {community}"), - rank_explanation=report.get("rating_explanation", ""), - summary=report.get("summary", ""), - findings=report.get("findings", []), - full_content_json=json.dumps(report, indent=4, ensure_ascii=False), + rank=report.rating, + title=report.title, + rank_explanation=report.rating_explanation, + summary=report.summary, + findings=[ + Finding(explanation=f.explanation, summary=f.summary) + for f in report.findings + ], + 
full_content_json=report.model_dump_json(indent=4), ) except Exception as e: log.exception("Error processing community: %s", community) callbacks.error("Community Report Extraction Error", e, traceback.format_exc()) return None - - -def _parse_rank(report: dict) -> float: - rank = report.get("rating", -1) - try: - return float(rank) - except ValueError: - log.exception("Error parsing rank: %s defaulting to -1", rank) - return -1 diff --git a/graphrag/index/operations/summarize_descriptions/strategies.py b/graphrag/index/operations/summarize_descriptions/strategies.py index 0990c5e884..0538f0e225 100644 --- a/graphrag/index/operations/summarize_descriptions/strategies.py +++ b/graphrag/index/operations/summarize_descriptions/strategies.py @@ -4,15 +4,15 @@ """A module containing run_graph_intelligence, run_resolve_entities and _create_text_list_splitter methods to run graph intelligence.""" from datashaper import VerbCallbacks +from fnllm import ChatLLM from graphrag.cache.pipeline_cache import PipelineCache from graphrag.index.graph.extractors.summarize import SummarizeExtractor -from graphrag.index.llm.load_llm import load_llm +from graphrag.index.llm.load_llm import load_llm, read_llm_params from graphrag.index.operations.summarize_descriptions.typing import ( StrategyConfig, SummarizedDescriptionResult, ) -from graphrag.llm import CompletionLLM async def run_graph_intelligence( @@ -23,14 +23,15 @@ async def run_graph_intelligence( args: StrategyConfig, ) -> SummarizedDescriptionResult: """Run the graph intelligence entity extraction strategy.""" - llm_config = args.get("llm", {}) - llm_type = llm_config.get("type") - llm = load_llm("summarize_descriptions", llm_type, callbacks, cache, llm_config) + llm_config = read_llm_params(args.get("llm", {})) + llm = load_llm( + "summarize_descriptions", llm_config, callbacks=callbacks, cache=cache + ) return await run_summarize_descriptions(llm, id, descriptions, callbacks, args) async def run_summarize_descriptions( - llm: CompletionLLM, + llm: ChatLLM, id: str | tuple[str, str], descriptions: list[str], callbacks: VerbCallbacks, diff --git a/graphrag/index/validate_config.py b/graphrag/index/validate_config.py index 11d7fd8390..4ea6a56299 100644 --- a/graphrag/index/validate_config.py +++ b/graphrag/index/validate_config.py @@ -20,10 +20,9 @@ def validate_config_names( # Validate Chat LLM configs llm = load_llm( "test-llm", - parameters.llm.type, - NoopVerbCallbacks(), - None, - parameters.llm.model_dump(), + parameters.llm, + callbacks=NoopVerbCallbacks(), + cache=None, ) try: asyncio.run(llm("This is an LLM connectivity test. Say Hello World")) @@ -35,10 +34,9 @@ def validate_config_names( # Validate Embeddings LLM configs embed_llm = load_llm_embeddings( "test-embed-llm", - parameters.embeddings.llm.type, - NoopVerbCallbacks(), - None, - parameters.embeddings.llm.model_dump(), + parameters.embeddings.llm, + callbacks=NoopVerbCallbacks(), + cache=None, ) try: asyncio.run(embed_llm(["This is an LLM Embedding Test String"])) diff --git a/graphrag/llm/__init__.py b/graphrag/llm/__init__.py deleted file mode 100644 index 609be951b2..0000000000 --- a/graphrag/llm/__init__.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
-# Licensed under the MIT License - -"""The Datashaper OpenAI Utilities package.""" - -from .base import BaseLLM, CachingLLM, RateLimitingLLM -from .errors import RetriesExhaustedError -from .limiting import ( - CompositeLLMLimiter, - LLMLimiter, - NoopLLMLimiter, - TpmRpmLLMLimiter, - create_tpm_rpm_limiters, -) -from .mock import MockChatLLM, MockCompletionLLM -from .openai import ( - OpenAIChatLLM, - OpenAIClientTypes, - OpenAICompletionLLM, - OpenAIConfiguration, - OpenAIEmbeddingsLLM, - create_openai_chat_llm, - create_openai_client, - create_openai_completion_llm, - create_openai_embedding_llm, -) -from .types import ( - LLM, - CompletionInput, - CompletionLLM, - CompletionOutput, - EmbeddingInput, - EmbeddingLLM, - EmbeddingOutput, - ErrorHandlerFn, - IsResponseValidFn, - LLMCache, - LLMConfig, - LLMInput, - LLMInvocationFn, - LLMInvocationResult, - LLMOutput, - OnCacheActionFn, -) - -__all__ = [ - # LLM Types - "LLM", - "BaseLLM", - "CachingLLM", - "CompletionInput", - "CompletionLLM", - "CompletionOutput", - "CompositeLLMLimiter", - "EmbeddingInput", - "EmbeddingLLM", - "EmbeddingOutput", - # Callbacks - "ErrorHandlerFn", - "IsResponseValidFn", - # Cache - "LLMCache", - "LLMConfig", - # LLM I/O Types - "LLMInput", - "LLMInvocationFn", - "LLMInvocationResult", - "LLMLimiter", - "LLMOutput", - "MockChatLLM", - # Mock - "MockCompletionLLM", - "NoopLLMLimiter", - "OnCacheActionFn", - "OpenAIChatLLM", - "OpenAIClientTypes", - "OpenAICompletionLLM", - # OpenAI - "OpenAIConfiguration", - "OpenAIEmbeddingsLLM", - "RateLimitingLLM", - # Errors - "RetriesExhaustedError", - "TpmRpmLLMLimiter", - "create_openai_chat_llm", - "create_openai_client", - "create_openai_completion_llm", - "create_openai_embedding_llm", - # Limiters - "create_tpm_rpm_limiters", -] diff --git a/graphrag/llm/base/__init__.py b/graphrag/llm/base/__init__.py deleted file mode 100644 index dd5ebf9050..0000000000 --- a/graphrag/llm/base/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""Base LLM Implementations.""" - -from .base_llm import BaseLLM -from .caching_llm import CachingLLM -from .rate_limiting_llm import RateLimitingLLM - -__all__ = ["BaseLLM", "CachingLLM", "RateLimitingLLM"] diff --git a/graphrag/llm/base/_create_cache_key.py b/graphrag/llm/base/_create_cache_key.py deleted file mode 100644 index b5fdd839bc..0000000000 --- a/graphrag/llm/base/_create_cache_key.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""Cache key generation utils.""" - -import hashlib -import json - - -def _llm_string(params: dict) -> str: - # New version of the cache is not including n in the params dictionary - # This avoids creating a new cache key for the same prompt - if "max_tokens" in params and "n" not in params: - params["n"] = None - return str(sorted((k, v) for k, v in params.items())) - - -def _hash(_input: str) -> str: - """Use a deterministic hashing approach.""" - return hashlib.md5(_input.encode()).hexdigest() # noqa S324 - - -def create_hash_key( - operation: str, prompt: str, parameters: dict, history: list[dict] | None -) -> str: - """Compute cache key from prompt and associated model and settings. - - Args: - prompt (str): The prompt run through the language model. - llm_string (str): The language model version and settings. - - Returns - ------- - str: The cache key. 
- """ - llm_string = _llm_string(parameters) - history_string = _hash(json.dumps(history)) if history else None - hash_string = ( - _hash(prompt + llm_string + history_string) - if history_string - else _hash(prompt + llm_string) - ) - return f"{operation}-{hash_string}" diff --git a/graphrag/llm/base/base_llm.py b/graphrag/llm/base/base_llm.py deleted file mode 100644 index 65fffd1b95..0000000000 --- a/graphrag/llm/base/base_llm.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""Base LLM class definition.""" - -import traceback -from abc import ABC, abstractmethod -from typing import Generic, TypeVar - -from openai import RateLimitError -from typing_extensions import Unpack - -from graphrag.llm.types import ( - LLM, - ErrorHandlerFn, - LLMInput, - LLMOutput, -) - -TIn = TypeVar("TIn") -TOut = TypeVar("TOut") - - -class BaseLLM(ABC, LLM[TIn, TOut], Generic[TIn, TOut]): - """LLM Implementation class definition.""" - - _on_error: ErrorHandlerFn | None - - def on_error(self, on_error: ErrorHandlerFn | None) -> None: - """Set the error handler function.""" - self._on_error = on_error - - @abstractmethod - async def _execute_llm( - self, - input: TIn, - **kwargs: Unpack[LLMInput], - ) -> TOut | None: - pass - - async def __call__( - self, - input: TIn, - **kwargs: Unpack[LLMInput], - ) -> LLMOutput[TOut]: - """Invoke the LLM.""" - is_json = kwargs.get("json") or False - if is_json: - return await self._invoke_json(input, **kwargs) - return await self._invoke(input, **kwargs) - - async def _invoke(self, input: TIn, **kwargs: Unpack[LLMInput]) -> LLMOutput[TOut]: - try: - output = await self._execute_llm(input, **kwargs) - return LLMOutput(output=output) - except RateLimitError: - # for improved readability, do not log rate limit exceptions, - # they are logged/handled elsewhere - raise - except Exception as e: - stack_trace = traceback.format_exc() - if self._on_error: - self._on_error(e, stack_trace, {"input": input}) - raise - - async def _invoke_json( - self, input: TIn, **kwargs: Unpack[LLMInput] - ) -> LLMOutput[TOut]: - msg = "JSON output not supported by this LLM" - raise NotImplementedError(msg) diff --git a/graphrag/llm/base/caching_llm.py b/graphrag/llm/base/caching_llm.py deleted file mode 100644 index c039de5122..0000000000 --- a/graphrag/llm/base/caching_llm.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
-# Licensed under the MIT License - -"""A class to interact with the cache.""" - -import json -from typing import Generic, TypeVar - -from typing_extensions import Unpack - -from graphrag.llm.types import LLM, LLMCache, LLMInput, LLMOutput, OnCacheActionFn - -from ._create_cache_key import create_hash_key - -# If there's a breaking change in what we cache, we should increment this version number to invalidate existing caches -_cache_strategy_version = 2 - -TIn = TypeVar("TIn") -TOut = TypeVar("TOut") - - -def _noop_cache_fn(_k: str, _v: str | None): - pass - - -class CachingLLM(LLM[TIn, TOut], Generic[TIn, TOut]): - """A class to interact with the cache.""" - - _cache: LLMCache - _delegate: LLM[TIn, TOut] - _operation: str - _llm_parameters: dict - _on_cache_hit: OnCacheActionFn - _on_cache_miss: OnCacheActionFn - - def __init__( - self, - delegate: LLM[TIn, TOut], - llm_parameters: dict, - operation: str, - cache: LLMCache, - ): - self._delegate = delegate - self._llm_parameters = llm_parameters - self._cache = cache - self._operation = operation - self._on_cache_hit = _noop_cache_fn - self._on_cache_miss = _noop_cache_fn - - def set_delegate(self, delegate: LLM[TIn, TOut]) -> None: - """Set the delegate LLM. (for testing).""" - self._delegate = delegate - - def on_cache_hit(self, fn: OnCacheActionFn | None) -> None: - """Set the function to call when a cache hit occurs.""" - self._on_cache_hit = fn or _noop_cache_fn - - def on_cache_miss(self, fn: OnCacheActionFn | None) -> None: - """Set the function to call when a cache miss occurs.""" - self._on_cache_miss = fn or _noop_cache_fn - - def _cache_key( - self, input: TIn, name: str | None, args: dict, history: list[dict] | None - ) -> str: - json_input = json.dumps(input) - tag = ( - f"{name}-{self._operation}-v{_cache_strategy_version}" - if name is not None - else self._operation - ) - return create_hash_key(tag, json_input, args, history) - - async def __call__( - self, - input: TIn, - **kwargs: Unpack[LLMInput], - ) -> LLMOutput[TOut]: - """Execute the LLM.""" - # Check for an Existing cache item - name = kwargs.get("name") - history_in = kwargs.get("history") or None - llm_args = {**self._llm_parameters, **(kwargs.get("model_parameters") or {})} - cache_key = self._cache_key(input, name, llm_args, history_in) - cached_result = await self._cache.get(cache_key) - - if cached_result: - self._on_cache_hit(cache_key, name) - return LLMOutput( - output=cached_result, - ) - - # Report the Cache Miss - self._on_cache_miss(cache_key, name) - - # Compute the new result - result = await self._delegate(input, **kwargs) - - # Cache the new result - if result.output is not None: - await self._cache.set( - cache_key, - result.output, - { - "input": input, - "parameters": llm_args, - "history": history_in, - }, - ) - return result diff --git a/graphrag/llm/base/rate_limiting_llm.py b/graphrag/llm/base/rate_limiting_llm.py deleted file mode 100644 index 5e2082475f..0000000000 --- a/graphrag/llm/base/rate_limiting_llm.py +++ /dev/null @@ -1,208 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
-# Licensed under the MIT License - -"""Rate limiting LLM implementation.""" - -import asyncio -import logging -from collections.abc import Callable -from typing import Any, Generic, TypeVar - -from tenacity import ( - AsyncRetrying, - retry_if_exception_type, - stop_after_attempt, - wait_exponential_jitter, -) -from typing_extensions import Unpack - -from graphrag.llm.errors import RetriesExhaustedError -from graphrag.llm.limiting import LLMLimiter -from graphrag.llm.types import ( - LLM, - LLMConfig, - LLMInput, - LLMInvocationFn, - LLMInvocationResult, - LLMOutput, -) - -TIn = TypeVar("TIn") -TOut = TypeVar("TOut") -TRateLimitError = TypeVar("TRateLimitError", bound=BaseException) - -_CANNOT_MEASURE_INPUT_TOKENS_MSG = "cannot measure input tokens" -_CANNOT_MEASURE_OUTPUT_TOKENS_MSG = "cannot measure output tokens" - -log = logging.getLogger(__name__) - - -class RateLimitingLLM(LLM[TIn, TOut], Generic[TIn, TOut]): - """A class to interact with the cache.""" - - _delegate: LLM[TIn, TOut] - _rate_limiter: LLMLimiter | None - _semaphore: asyncio.Semaphore | None - _count_tokens: Callable[[str], int] - _config: LLMConfig - _operation: str - _retryable_errors: list[type[Exception]] - _rate_limit_errors: list[type[Exception]] - _on_invoke: LLMInvocationFn - _extract_sleep_recommendation: Callable[[Any], float] - - def __init__( - self, - delegate: LLM[TIn, TOut], - config: LLMConfig, - operation: str, - retryable_errors: list[type[Exception]], - rate_limit_errors: list[type[Exception]], - rate_limiter: LLMLimiter | None = None, - semaphore: asyncio.Semaphore | None = None, - count_tokens: Callable[[str], int] | None = None, - get_sleep_time: Callable[[BaseException], float] | None = None, - ): - self._delegate = delegate - self._rate_limiter = rate_limiter - self._semaphore = semaphore - self._config = config - self._operation = operation - self._retryable_errors = retryable_errors - self._rate_limit_errors = rate_limit_errors - self._count_tokens = count_tokens or (lambda _s: -1) - self._extract_sleep_recommendation = get_sleep_time or (lambda _e: 0.0) - self._on_invoke = lambda _v: None - - def on_invoke(self, fn: LLMInvocationFn | None) -> None: - """Set the on_invoke function.""" - self._on_invoke = fn or (lambda _v: None) - - def count_request_tokens(self, input: TIn) -> int: - """Count the request tokens on an input request.""" - if isinstance(input, str): - return self._count_tokens(input) - if isinstance(input, list): - result = 0 - for item in input: - if isinstance(item, str): - result += self._count_tokens(item) - elif isinstance(item, dict): - result += self._count_tokens(item.get("content", "")) - else: - raise TypeError(_CANNOT_MEASURE_INPUT_TOKENS_MSG) - return result - raise TypeError(_CANNOT_MEASURE_INPUT_TOKENS_MSG) - - def count_response_tokens(self, output: TOut | None) -> int: - """Count the request tokens on an output response.""" - if output is None: - return 0 - if isinstance(output, str): - return self._count_tokens(output) - if isinstance(output, list) and all(isinstance(x, str) for x in output): - return sum(self._count_tokens(item) for item in output) - if isinstance(output, list): - # Embedding response, don't count it - return 0 - raise TypeError(_CANNOT_MEASURE_OUTPUT_TOKENS_MSG) - - async def __call__( - self, - input: TIn, - **kwargs: Unpack[LLMInput], - ) -> LLMOutput[TOut]: - """Execute the LLM with semaphore & rate limiting.""" - name = kwargs.get("name", "Process") - attempt_number = 0 - call_times: list[float] = [] - input_tokens = 
self.count_request_tokens(input) - max_retries = self._config.max_retries or 10 - max_retry_wait = self._config.max_retry_wait or 10 - follow_recommendation = self._config.sleep_on_rate_limit_recommendation - retryer = AsyncRetrying( - stop=stop_after_attempt(max_retries), - wait=wait_exponential_jitter(max=max_retry_wait), - reraise=True, - retry=retry_if_exception_type(tuple(self._retryable_errors)), - ) - - async def sleep_for(time: float | None) -> None: - log.warning( - "%s failed to invoke LLM %s/%s attempts. Cause: rate limit exceeded, will retry. Recommended sleep for %d seconds. Follow recommendation? %s", - name, - attempt_number, - max_retries, - time, - follow_recommendation, - ) - if follow_recommendation and time: - await asyncio.sleep(time) - raise - - async def do_attempt() -> LLMOutput[TOut]: - nonlocal call_times - call_start = asyncio.get_event_loop().time() - try: - return await self._delegate(input, **kwargs) - except BaseException as e: - if isinstance(e, tuple(self._rate_limit_errors)): - sleep_time = self._extract_sleep_recommendation(e) - await sleep_for(sleep_time) - raise - finally: - call_end = asyncio.get_event_loop().time() - call_times.append(call_end - call_start) - - async def execute_with_retry() -> tuple[LLMOutput[TOut], float]: - nonlocal attempt_number - async for attempt in retryer: - with attempt: - if self._rate_limiter and input_tokens > 0: - await self._rate_limiter.acquire(input_tokens) - start = asyncio.get_event_loop().time() - attempt_number += 1 - return await do_attempt(), start - - log.error("Retries exhausted for %s", name) - raise RetriesExhaustedError(name, max_retries) - - result: LLMOutput[TOut] - start = 0.0 - - if self._semaphore is None: - result, start = await execute_with_retry() - else: - async with self._semaphore: - result, start = await execute_with_retry() - - end = asyncio.get_event_loop().time() - output_tokens = self.count_response_tokens(result.output) - if self._rate_limiter and output_tokens > 0: - await self._rate_limiter.acquire(output_tokens) - - invocation_result = LLMInvocationResult( - result=result, - name=name, - num_retries=attempt_number - 1, - total_time=end - start, - call_times=call_times, - input_tokens=input_tokens, - output_tokens=output_tokens, - ) - self._handle_invoke_result(invocation_result) - return result - - def _handle_invoke_result( - self, result: LLMInvocationResult[LLMOutput[TOut]] - ) -> None: - log.info( - 'perf - llm.%s "%s" with %s retries took %s. input_tokens=%d, output_tokens=%d', - self._operation, - result.name, - result.num_retries, - result.total_time, - result.input_tokens, - result.output_tokens, - ) - self._on_invoke(result) diff --git a/graphrag/llm/errors.py b/graphrag/llm/errors.py deleted file mode 100644 index 01136359de..0000000000 --- a/graphrag/llm/errors.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""Error definitions for the OpenAI DataShaper package.""" - - -class RetriesExhaustedError(RuntimeError): - """Retries exhausted error.""" - - def __init__(self, name: str, num_retries: int) -> None: - """Init method definition.""" - super().__init__(f"Operation '{name}' failed - {num_retries} retries exhausted") diff --git a/graphrag/llm/limiting/__init__.py b/graphrag/llm/limiting/__init__.py deleted file mode 100644 index 4f7933d1a8..0000000000 --- a/graphrag/llm/limiting/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
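The RateLimitingLLM deleted above builds its retry loop from tenacity's AsyncRetrying with exponential jitter, retrying only on the retryable OpenAI error types. A standalone sketch of the same retry shape, with a hypothetical `call_model` coroutine standing in for the delegate LLM:

```python
from openai import APIConnectionError, InternalServerError, RateLimitError
from tenacity import (
    AsyncRetrying,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential_jitter,
)

RETRYABLE_ERRORS = (RateLimitError, APIConnectionError, InternalServerError)


async def invoke_with_retries(call_model, prompt: str, max_retries: int = 10, max_retry_wait: float = 10.0):
    # Same retryer construction as the deleted RateLimitingLLM.__call__.
    retryer = AsyncRetrying(
        stop=stop_after_attempt(max_retries),
        wait=wait_exponential_jitter(max=max_retry_wait),
        reraise=True,
        retry=retry_if_exception_type(RETRYABLE_ERRORS),
    )
    async for attempt in retryer:
        with attempt:
            return await call_model(prompt)  # re-raised on final failure (reraise=True)
    return None
```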
-# Licensed under the MIT License - -"""LLM limiters module.""" - -from .composite_limiter import CompositeLLMLimiter -from .create_limiters import create_tpm_rpm_limiters -from .llm_limiter import LLMLimiter -from .noop_llm_limiter import NoopLLMLimiter -from .tpm_rpm_limiter import TpmRpmLLMLimiter - -__all__ = [ - "CompositeLLMLimiter", - "LLMLimiter", - "NoopLLMLimiter", - "TpmRpmLLMLimiter", - "create_tpm_rpm_limiters", -] diff --git a/graphrag/llm/limiting/composite_limiter.py b/graphrag/llm/limiting/composite_limiter.py deleted file mode 100644 index 7bcf9195b2..0000000000 --- a/graphrag/llm/limiting/composite_limiter.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""A module containing Composite Limiter class definition.""" - -from .llm_limiter import LLMLimiter - - -class CompositeLLMLimiter(LLMLimiter): - """Composite Limiter class definition.""" - - _limiters: list[LLMLimiter] - - def __init__(self, limiters: list[LLMLimiter]): - """Init method definition.""" - self._limiters = limiters - - @property - def needs_token_count(self) -> bool: - """Whether this limiter needs the token count to be passed in.""" - return any(limiter.needs_token_count for limiter in self._limiters) - - async def acquire(self, num_tokens: int = 1) -> None: - """Call method definition.""" - for limiter in self._limiters: - await limiter.acquire(num_tokens) diff --git a/graphrag/llm/limiting/create_limiters.py b/graphrag/llm/limiting/create_limiters.py deleted file mode 100644 index 92df11c1a6..0000000000 --- a/graphrag/llm/limiting/create_limiters.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""Create limiters for OpenAI API requests.""" - -import logging - -from aiolimiter import AsyncLimiter - -from graphrag.llm.types import LLMConfig - -from .llm_limiter import LLMLimiter -from .tpm_rpm_limiter import TpmRpmLLMLimiter - -log = logging.getLogger(__name__) - -"""The global TPM limiters.""" - - -def create_tpm_rpm_limiters( - configuration: LLMConfig, -) -> LLMLimiter: - """Get the limiters for a given model name.""" - tpm = configuration.tokens_per_minute - rpm = configuration.requests_per_minute - return TpmRpmLLMLimiter( - None if tpm == 0 else AsyncLimiter(tpm or 50_000), - None if rpm == 0 else AsyncLimiter(rpm or 10_000), - ) diff --git a/graphrag/llm/limiting/llm_limiter.py b/graphrag/llm/limiting/llm_limiter.py deleted file mode 100644 index 1264a84be5..0000000000 --- a/graphrag/llm/limiting/llm_limiter.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""Limiting types.""" - -from abc import ABC, abstractmethod - - -class LLMLimiter(ABC): - """LLM Limiter Interface.""" - - @property - @abstractmethod - def needs_token_count(self) -> bool: - """Whether this limiter needs the token count to be passed in.""" - - @abstractmethod - async def acquire(self, num_tokens: int = 1) -> None: - """Acquire a pass through the limiter.""" diff --git a/graphrag/llm/limiting/noop_llm_limiter.py b/graphrag/llm/limiting/noop_llm_limiter.py deleted file mode 100644 index 5147055255..0000000000 --- a/graphrag/llm/limiting/noop_llm_limiter.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
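The limiter stack removed here composes one aiolimiter.AsyncLimiter for tokens per minute with another for requests per minute; AsyncLimiter's default window is 60 seconds, which is why the per-minute budgets are passed straight through. A small usage sketch under the same 50k TPM / 10k RPM defaults used by create_tpm_rpm_limiters:

```python
import asyncio

from aiolimiter import AsyncLimiter


async def main() -> None:
    tpm_limiter = AsyncLimiter(50_000)  # tokens per 60-second window (default above)
    rpm_limiter = AsyncLimiter(10_000)  # requests per 60-second window (default above)

    async def acquire(num_tokens: int) -> None:
        # Same order as TpmRpmLLMLimiter.acquire: charge the token budget, then a request slot.
        await tpm_limiter.acquire(num_tokens)
        await rpm_limiter.acquire()

    await acquire(1200)  # suspends here if either budget is exhausted


asyncio.run(main())
```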
-# Licensed under the MIT License - -"""TPM RPM Limiter module.""" - -from .llm_limiter import LLMLimiter - - -class NoopLLMLimiter(LLMLimiter): - """TPM RPM Limiter class definition.""" - - @property - def needs_token_count(self) -> bool: - """Whether this limiter needs the token count to be passed in.""" - return False - - async def acquire(self, num_tokens: int = 1) -> None: - """Call method definition.""" - # do nothing diff --git a/graphrag/llm/limiting/tpm_rpm_limiter.py b/graphrag/llm/limiting/tpm_rpm_limiter.py deleted file mode 100644 index cb6d84e377..0000000000 --- a/graphrag/llm/limiting/tpm_rpm_limiter.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""TPM RPM Limiter module.""" - -from aiolimiter import AsyncLimiter - -from .llm_limiter import LLMLimiter - - -class TpmRpmLLMLimiter(LLMLimiter): - """TPM RPM Limiter class definition.""" - - _tpm_limiter: AsyncLimiter | None - _rpm_limiter: AsyncLimiter | None - - def __init__( - self, tpm_limiter: AsyncLimiter | None, rpm_limiter: AsyncLimiter | None - ): - """Init method definition.""" - self._tpm_limiter = tpm_limiter - self._rpm_limiter = rpm_limiter - - @property - def needs_token_count(self) -> bool: - """Whether this limiter needs the token count to be passed in.""" - return self._tpm_limiter is not None - - async def acquire(self, num_tokens: int = 1) -> None: - """Call method definition.""" - if self._tpm_limiter is not None: - await self._tpm_limiter.acquire(num_tokens) - if self._rpm_limiter is not None: - await self._rpm_limiter.acquire() diff --git a/graphrag/llm/mock/__init__.py b/graphrag/llm/mock/__init__.py deleted file mode 100644 index cd1f000dd1..0000000000 --- a/graphrag/llm/mock/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""Mock LLM Implementations.""" - -from .mock_chat_llm import MockChatLLM -from .mock_completion_llm import MockCompletionLLM - -__all__ = [ - "MockChatLLM", - "MockCompletionLLM", -] diff --git a/graphrag/llm/mock/mock_chat_llm.py b/graphrag/llm/mock/mock_chat_llm.py deleted file mode 100644 index b8a6650b31..0000000000 --- a/graphrag/llm/mock/mock_chat_llm.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
-# Licensed under the MIT License - -"""A mock ChatLLM that returns the given responses.""" - -from typing_extensions import Unpack - -from graphrag.llm.base import BaseLLM -from graphrag.llm.types import ( - CompletionInput, - CompletionOutput, - LLMInput, - LLMOutput, -) - - -class MockChatLLM( - BaseLLM[ - CompletionInput, - CompletionOutput, - ] -): - """A mock LLM that returns the given responses.""" - - responses: list[str] - i: int = 0 - - def __init__(self, responses: list[str]): - self.i = 0 - self.responses = responses - - def _create_output( - self, - output: CompletionOutput | None, - **kwargs: Unpack[LLMInput], - ) -> LLMOutput[CompletionOutput]: - history = kwargs.get("history") or [] - return LLMOutput[CompletionOutput]( - output=output, history=[*history, {"content": output}] - ) - - async def _execute_llm( - self, - input: CompletionInput, - **kwargs: Unpack[LLMInput], - ) -> CompletionOutput: - if self.i >= len(self.responses): - msg = f"No more responses, requested {self.i} but only have {len(self.responses)}" - raise ValueError(msg) - response = self.responses[self.i] - self.i += 1 - return response diff --git a/graphrag/llm/mock/mock_completion_llm.py b/graphrag/llm/mock/mock_completion_llm.py deleted file mode 100644 index 7eba4ca7c5..0000000000 --- a/graphrag/llm/mock/mock_completion_llm.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""LLM Static Response method definition.""" - -import json -import logging - -from typing_extensions import Unpack - -from graphrag.llm.base import BaseLLM -from graphrag.llm.types import ( - CompletionInput, - CompletionOutput, - LLMInput, - LLMOutput, -) - -log = logging.getLogger(__name__) - - -class MockCompletionLLM( - BaseLLM[ - CompletionInput, - CompletionOutput, - ] -): - """Mock Completion LLM for testing purposes.""" - - def __init__(self, responses: list[str]): - self.responses = responses - self._on_error = None - - async def _execute_llm( - self, - input: CompletionInput, - **kwargs: Unpack[LLMInput], - ) -> CompletionOutput: - return self.responses[0] - - async def _invoke_json(self, input: CompletionInput, **kwargs: Unpack[LLMInput]): - return LLMOutput(output=self.responses[0], json=json.loads(self.responses[0])) diff --git a/graphrag/llm/openai/__init__.py b/graphrag/llm/openai/__init__.py deleted file mode 100644 index 9478e146d2..0000000000 --- a/graphrag/llm/openai/__init__.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""OpenAI LLM implementations.""" - -from .create_openai_client import create_openai_client -from .factories import ( - create_openai_chat_llm, - create_openai_completion_llm, - create_openai_embedding_llm, -) -from .openai_chat_llm import OpenAIChatLLM -from .openai_completion_llm import OpenAICompletionLLM -from .openai_configuration import OpenAIConfiguration -from .openai_embeddings_llm import OpenAIEmbeddingsLLM -from .types import OpenAIClientTypes - -__all__ = [ - "OpenAIChatLLM", - "OpenAIClientTypes", - "OpenAICompletionLLM", - "OpenAIConfiguration", - "OpenAIEmbeddingsLLM", - "create_openai_chat_llm", - "create_openai_client", - "create_openai_completion_llm", - "create_openai_embedding_llm", -] diff --git a/graphrag/llm/openai/_prompts.py b/graphrag/llm/openai/_prompts.py deleted file mode 100644 index 37d9f0fc70..0000000000 --- a/graphrag/llm/openai/_prompts.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
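The mock LLMs deleted above exist purely so extractor unit tests can script responses without network calls (a replacement lands in graphrag/index/llm/mock_llm.py). A hypothetical, framework-free sketch of the same pattern; `CannedLLM` is not part of either implementation:

```python
import asyncio


class CannedLLM:
    """Hypothetical stand-in that returns scripted responses in order, like the removed MockChatLLM."""

    def __init__(self, responses: list[str]):
        self._responses = responses
        self._i = 0

    async def __call__(self, prompt: str, **kwargs) -> str:
        if self._i >= len(self._responses):
            msg = f"No more responses, requested {self._i} but only have {len(self._responses)}"
            raise ValueError(msg)
        response = self._responses[self._i]
        self._i += 1
        return response


async def test_scripted_responses() -> None:
    llm = CannedLLM(['{"entities": []}'])
    assert await llm("extract entities from: ...") == '{"entities": []}'


asyncio.run(test_scripted_responses())
```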
-# Licensed under the MIT License - -"""Utility prompts for low-level LLM invocations.""" - -JSON_CHECK_PROMPT = """ -You are going to be given a malformed JSON string that threw an error during json.loads. -It probably contains unnecessary escape sequences, or it is missing a comma or colon somewhere. -Your task is to fix this string and return a well-formed JSON string containing a single object. -Eliminate any unnecessary escape sequences. -Only return valid JSON, parseable with json.loads, without commentary. - -# Examples ------------ -Text: {{ \\"title\\": \\"abc\\", \\"summary\\": \\"def\\" }} -Output: {{"title": "abc", "summary": "def"}} ------------ -Text: {{"title": "abc", "summary": "def" -Output: {{"title": "abc", "summary": "def"}} ------------ -Text: {{"title': "abc", 'summary": "def" -Output: {{"title": "abc", "summary": "def"}} ------------ -Text: "{{"title": "abc", "summary": "def"}}" -Output: {{"title": "abc", "summary": "def"}} ------------ -Text: [{{"title": "abc", "summary": "def"}}] -Output: [{{"title": "abc", "summary": "def"}}] ------------ -Text: [{{"title": "abc", "summary": "def"}}, {{ \\"title\\": \\"abc\\", \\"summary\\": \\"def\\" }}] -Output: [{{"title": "abc", "summary": "def"}}, {{"title": "abc", "summary": "def"}}] ------------ -Text: ```json\n[{{"title": "abc", "summary": "def"}}, {{ \\"title\\": \\"abc\\", \\"summary\\": \\"def\\" }}]``` -Output: [{{"title": "abc", "summary": "def"}}, {{"title": "abc", "summary": "def"}}] - - -# Real Data -Text: {input_text} -Output:""" diff --git a/graphrag/llm/openai/create_openai_client.py b/graphrag/llm/openai/create_openai_client.py deleted file mode 100644 index 40d15cad0a..0000000000 --- a/graphrag/llm/openai/create_openai_client.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
-# Licensed under the MIT License - -"""Create OpenAI client instance.""" - -import logging -from functools import cache - -from azure.identity import DefaultAzureCredential, get_bearer_token_provider -from openai import AsyncAzureOpenAI, AsyncOpenAI - -from .openai_configuration import OpenAIConfiguration -from .types import OpenAIClientTypes - -log = logging.getLogger(__name__) - -API_BASE_REQUIRED_FOR_AZURE = "api_base is required for Azure OpenAI client" - - -@cache -def create_openai_client( - configuration: OpenAIConfiguration, azure: bool -) -> OpenAIClientTypes: - """Create a new OpenAI client instance.""" - if azure: - api_base = configuration.api_base - if api_base is None: - raise ValueError(API_BASE_REQUIRED_FOR_AZURE) - - log.info( - "Creating Azure OpenAI client api_base=%s, deployment_name=%s", - api_base, - configuration.deployment_name, - ) - audience = ( - configuration.audience - if configuration.audience - else "https://cognitiveservices.azure.com/.default" - ) - - return AsyncAzureOpenAI( - api_key=configuration.api_key if configuration.api_key else None, - azure_ad_token_provider=get_bearer_token_provider( - DefaultAzureCredential(), audience - ) - if not configuration.api_key - else None, - organization=configuration.organization, - # Azure-Specifics - api_version=configuration.api_version, - azure_endpoint=api_base, - azure_deployment=configuration.deployment_name, - # Timeout/Retry Configuration - Use Tenacity for Retries, so disable them here - timeout=configuration.request_timeout or 180.0, - max_retries=0, - ) - - log.info("Creating OpenAI client base_url=%s", configuration.api_base) - return AsyncOpenAI( - api_key=configuration.api_key, - base_url=configuration.api_base, - organization=configuration.organization, - # Timeout/Retry Configuration - Use Tenacity for Retries, so disable them here - timeout=configuration.request_timeout or 180.0, - max_retries=0, - ) diff --git a/graphrag/llm/openai/factories.py b/graphrag/llm/openai/factories.py deleted file mode 100644 index e595e2e55b..0000000000 --- a/graphrag/llm/openai/factories.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
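create_openai_client, deleted above, builds an AsyncAzureOpenAI with either an API key or a DefaultAzureCredential bearer-token provider, and turns off the SDK's own retries because the tenacity layer handles them. A minimal keyless-auth sketch using the same openai and azure-identity calls; the endpoint, deployment, and api_version values are placeholders:

```python
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
from openai import AsyncAzureOpenAI

token_provider = get_bearer_token_provider(
    DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
)

client = AsyncAzureOpenAI(
    azure_ad_token_provider=token_provider,               # keyless auth; pass api_key=... instead if available
    api_version="2024-02-15-preview",                     # placeholder API version
    azure_endpoint="https://my-openai.openai.azure.com",  # placeholder endpoint
    azure_deployment="gpt-4o",                            # placeholder deployment name
    timeout=180.0,
    max_retries=0,  # retries are handled by the wrapper layer, not the SDK
)
```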
-# Licensed under the MIT License - -"""Factory functions for creating OpenAI LLMs.""" - -import asyncio - -from graphrag.llm.base import CachingLLM, RateLimitingLLM -from graphrag.llm.limiting import LLMLimiter -from graphrag.llm.types import ( - LLM, - CompletionLLM, - EmbeddingLLM, - ErrorHandlerFn, - LLMCache, - LLMInvocationFn, - OnCacheActionFn, -) - -from .json_parsing_llm import JsonParsingLLM -from .openai_chat_llm import OpenAIChatLLM -from .openai_completion_llm import OpenAICompletionLLM -from .openai_configuration import OpenAIConfiguration -from .openai_embeddings_llm import OpenAIEmbeddingsLLM -from .openai_history_tracking_llm import OpenAIHistoryTrackingLLM -from .openai_token_replacing_llm import OpenAITokenReplacingLLM -from .types import OpenAIClientTypes -from .utils import ( - RATE_LIMIT_ERRORS, - RETRYABLE_ERRORS, - get_completion_cache_args, - get_sleep_time_from_error, - get_token_counter, -) - - -def create_openai_chat_llm( - client: OpenAIClientTypes, - config: OpenAIConfiguration, - cache: LLMCache | None = None, - limiter: LLMLimiter | None = None, - semaphore: asyncio.Semaphore | None = None, - on_invoke: LLMInvocationFn | None = None, - on_error: ErrorHandlerFn | None = None, - on_cache_hit: OnCacheActionFn | None = None, - on_cache_miss: OnCacheActionFn | None = None, -) -> CompletionLLM: - """Create an OpenAI chat LLM.""" - operation = "chat" - result = OpenAIChatLLM(client, config) - result.on_error(on_error) - if limiter is not None or semaphore is not None: - result = _rate_limited(result, config, operation, limiter, semaphore, on_invoke) - if cache is not None: - result = _cached(result, config, operation, cache, on_cache_hit, on_cache_miss) - result = OpenAIHistoryTrackingLLM(result) - result = OpenAITokenReplacingLLM(result) - return JsonParsingLLM(result) - - -def create_openai_completion_llm( - client: OpenAIClientTypes, - config: OpenAIConfiguration, - cache: LLMCache | None = None, - limiter: LLMLimiter | None = None, - semaphore: asyncio.Semaphore | None = None, - on_invoke: LLMInvocationFn | None = None, - on_error: ErrorHandlerFn | None = None, - on_cache_hit: OnCacheActionFn | None = None, - on_cache_miss: OnCacheActionFn | None = None, -) -> CompletionLLM: - """Create an OpenAI completion LLM.""" - operation = "completion" - result = OpenAICompletionLLM(client, config) - result.on_error(on_error) - if limiter is not None or semaphore is not None: - result = _rate_limited(result, config, operation, limiter, semaphore, on_invoke) - if cache is not None: - result = _cached(result, config, operation, cache, on_cache_hit, on_cache_miss) - return OpenAITokenReplacingLLM(result) - - -def create_openai_embedding_llm( - client: OpenAIClientTypes, - config: OpenAIConfiguration, - cache: LLMCache | None = None, - limiter: LLMLimiter | None = None, - semaphore: asyncio.Semaphore | None = None, - on_invoke: LLMInvocationFn | None = None, - on_error: ErrorHandlerFn | None = None, - on_cache_hit: OnCacheActionFn | None = None, - on_cache_miss: OnCacheActionFn | None = None, -) -> EmbeddingLLM: - """Create an OpenAI embeddings LLM.""" - operation = "embedding" - result = OpenAIEmbeddingsLLM(client, config) - result.on_error(on_error) - if limiter is not None or semaphore is not None: - result = _rate_limited(result, config, operation, limiter, semaphore, on_invoke) - if cache is not None: - result = _cached(result, config, operation, cache, on_cache_hit, on_cache_miss) - return result - - -def _rate_limited( - delegate: LLM, - config: OpenAIConfiguration, - 
operation: str, - limiter: LLMLimiter | None, - semaphore: asyncio.Semaphore | None, - on_invoke: LLMInvocationFn | None, -): - result = RateLimitingLLM( - delegate, - config, - operation, - RETRYABLE_ERRORS, - RATE_LIMIT_ERRORS, - limiter, - semaphore, - get_token_counter(config), - get_sleep_time_from_error, - ) - result.on_invoke(on_invoke) - return result - - -def _cached( - delegate: LLM, - config: OpenAIConfiguration, - operation: str, - cache: LLMCache, - on_cache_hit: OnCacheActionFn | None, - on_cache_miss: OnCacheActionFn | None, -): - cache_args = get_completion_cache_args(config) - result = CachingLLM(delegate, cache_args, operation, cache) - result.on_cache_hit(on_cache_hit) - result.on_cache_miss(on_cache_miss) - return result diff --git a/graphrag/llm/openai/json_parsing_llm.py b/graphrag/llm/openai/json_parsing_llm.py deleted file mode 100644 index 009c1da42e..0000000000 --- a/graphrag/llm/openai/json_parsing_llm.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""An LLM that unpacks cached JSON responses.""" - -from typing_extensions import Unpack - -from graphrag.llm.types import ( - LLM, - CompletionInput, - CompletionLLM, - CompletionOutput, - LLMInput, - LLMOutput, -) - -from .utils import try_parse_json_object - - -class JsonParsingLLM(LLM[CompletionInput, CompletionOutput]): - """An OpenAI History-Tracking LLM.""" - - _delegate: CompletionLLM - - def __init__(self, delegate: CompletionLLM): - self._delegate = delegate - - async def __call__( - self, - input: CompletionInput, - **kwargs: Unpack[LLMInput], - ) -> LLMOutput[CompletionOutput]: - """Call the LLM with the input and kwargs.""" - result = await self._delegate(input, **kwargs) - if kwargs.get("json") and result.json is None and result.output is not None: - _, parsed_json = try_parse_json_object(result.output) - result.json = parsed_json - return result diff --git a/graphrag/llm/openai/openai_chat_llm.py b/graphrag/llm/openai/openai_chat_llm.py deleted file mode 100644 index bd821ac661..0000000000 --- a/graphrag/llm/openai/openai_chat_llm.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
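The factory above shows how the wrappers nested before this change: the raw chat model sits inside the rate limiter, then the cache, then history tracking, token replacement, and finally JSON parsing. A sketch of calling the removed factory as it was wired pre-patch (only valid against the old graphrag.llm package; the api_key value is a placeholder):

```python
import asyncio

from graphrag.llm.openai import OpenAIConfiguration, create_openai_chat_llm, create_openai_client

config = OpenAIConfiguration({"api_key": "sk-placeholder", "model": "gpt-4-turbo-preview"})
client = create_openai_client(config, azure=False)

chat_llm = create_openai_chat_llm(
    client,
    config,
    cache=None,                       # supply an LLMCache to enable the CachingLLM layer
    limiter=None,                     # supply an LLMLimiter to enable TPM/RPM limiting
    semaphore=asyncio.Semaphore(25),  # bounds concurrent in-flight requests
)
# With a cache and limiter supplied, the layers nest as:
# JsonParsingLLM(OpenAITokenReplacingLLM(OpenAIHistoryTrackingLLM(
#     CachingLLM(RateLimitingLLM(OpenAIChatLLM(client, config))))))
```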
-# Licensed under the MIT License - -"""The Chat-based language model.""" - -import logging - -from typing_extensions import Unpack - -from graphrag.llm.base import BaseLLM -from graphrag.llm.types import ( - CompletionInput, - CompletionOutput, - LLMInput, - LLMOutput, -) - -from ._prompts import JSON_CHECK_PROMPT -from .openai_configuration import OpenAIConfiguration -from .types import OpenAIClientTypes -from .utils import ( - get_completion_llm_args, - try_parse_json_object, -) - -log = logging.getLogger(__name__) - -_MAX_GENERATION_RETRIES = 3 -FAILED_TO_CREATE_JSON_ERROR = "Failed to generate valid JSON output" - - -class OpenAIChatLLM(BaseLLM[CompletionInput, CompletionOutput]): - """A Chat-based LLM.""" - - _client: OpenAIClientTypes - _configuration: OpenAIConfiguration - - def __init__(self, client: OpenAIClientTypes, configuration: OpenAIConfiguration): - self.client = client - self.configuration = configuration - - async def _execute_llm( - self, input: CompletionInput, **kwargs: Unpack[LLMInput] - ) -> CompletionOutput | None: - args = get_completion_llm_args( - kwargs.get("model_parameters"), self.configuration - ) - history = kwargs.get("history") or [] - messages = [ - *history, - {"role": "user", "content": input}, - ] - completion = await self.client.chat.completions.create( - messages=messages, **args - ) - return completion.choices[0].message.content - - async def _invoke_json( - self, - input: CompletionInput, - **kwargs: Unpack[LLMInput], - ) -> LLMOutput[CompletionOutput]: - """Generate JSON output.""" - name = kwargs.get("name") or "unknown" - is_response_valid = kwargs.get("is_response_valid") or (lambda _x: True) - - async def generate( - attempt: int | None = None, - ) -> LLMOutput[CompletionOutput]: - call_name = name if attempt is None else f"{name}@{attempt}" - return ( - await self._native_json(input, **{**kwargs, "name": call_name}) - if self.configuration.model_supports_json - else await self._manual_json(input, **{**kwargs, "name": call_name}) - ) - - def is_valid(x: dict | None) -> bool: - return x is not None and is_response_valid(x) - - result = await generate() - retry = 0 - while not is_valid(result.json) and retry < _MAX_GENERATION_RETRIES: - result = await generate(retry) - retry += 1 - - if is_valid(result.json): - return result - - error_msg = f"{FAILED_TO_CREATE_JSON_ERROR} - Faulty JSON: {result.json!s}" - raise RuntimeError(error_msg) - - async def _native_json( - self, input: CompletionInput, **kwargs: Unpack[LLMInput] - ) -> LLMOutput[CompletionOutput]: - """Generate JSON output using a model's native JSON-output support.""" - result = await self._invoke( - input, - **{ - **kwargs, - "model_parameters": { - **(kwargs.get("model_parameters") or {}), - "response_format": {"type": "json_object"}, - }, - }, - ) - - output, json_output = try_parse_json_object(result.output or "") - - return LLMOutput[CompletionOutput]( - output=output, - json=json_output, - history=result.history, - ) - - async def _manual_json( - self, input: CompletionInput, **kwargs: Unpack[LLMInput] - ) -> LLMOutput[CompletionOutput]: - # Otherwise, clean up the output and try to parse it as json - result = await self._invoke(input, **kwargs) - history = result.history or [] - output, json_output = try_parse_json_object(result.output or "") - if json_output: - return LLMOutput[CompletionOutput]( - output=result.output, json=json_output, history=history - ) - # if not return correct formatted json, retry - log.warning("error parsing llm json, retrying") - - # If cleaned up json 
is unparsable, use the LLM to reformat it (may throw) - result = await self._try_clean_json_with_llm(output, **kwargs) - output, json_output = try_parse_json_object(result.output or "") - - return LLMOutput[CompletionOutput]( - output=output, - json=json_output, - history=history, - ) - - async def _try_clean_json_with_llm( - self, output: str, **kwargs: Unpack[LLMInput] - ) -> LLMOutput[CompletionOutput]: - name = kwargs.get("name") or "unknown" - return await self._invoke( - JSON_CHECK_PROMPT, - **{ - **kwargs, - "variables": {"input_text": output}, - "name": f"fix_json@{name}", - }, - ) diff --git a/graphrag/llm/openai/openai_completion_llm.py b/graphrag/llm/openai/openai_completion_llm.py deleted file mode 100644 index 74511c02a2..0000000000 --- a/graphrag/llm/openai/openai_completion_llm.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""A text-completion based LLM.""" - -import logging - -from typing_extensions import Unpack - -from graphrag.llm.base import BaseLLM -from graphrag.llm.types import ( - CompletionInput, - CompletionOutput, - LLMInput, -) - -from .openai_configuration import OpenAIConfiguration -from .types import OpenAIClientTypes -from .utils import get_completion_llm_args - -log = logging.getLogger(__name__) - - -class OpenAICompletionLLM(BaseLLM[CompletionInput, CompletionOutput]): - """A text-completion based LLM.""" - - _client: OpenAIClientTypes - _configuration: OpenAIConfiguration - - def __init__(self, client: OpenAIClientTypes, configuration: OpenAIConfiguration): - self.client = client - self.configuration = configuration - - async def _execute_llm( - self, - input: CompletionInput, - **kwargs: Unpack[LLMInput], - ) -> CompletionOutput | None: - args = get_completion_llm_args( - kwargs.get("model_parameters"), self.configuration - ) - completion = await self.client.completions.create(prompt=input, **args) - return completion.choices[0].text diff --git a/graphrag/llm/openai/openai_configuration.py b/graphrag/llm/openai/openai_configuration.py deleted file mode 100644 index cbcc54093d..0000000000 --- a/graphrag/llm/openai/openai_configuration.py +++ /dev/null @@ -1,288 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
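The _native_json path above relies on OpenAI's JSON mode by passing response_format={"type": "json_object"}; only when that is unavailable does the manual path fall back to cleanup and the JSON_CHECK_PROMPT repair round-trip. A bare sketch of the native path against the plain openai client (the model name is a placeholder and must support JSON mode):

```python
import asyncio
import json

from openai import AsyncOpenAI


async def chat_json(prompt: str) -> dict:
    client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
    completion = await client.chat.completions.create(
        model="gpt-4o-mini",  # placeholder model name
        response_format={"type": "json_object"},
        messages=[
            # JSON mode requires that the conversation mention JSON explicitly.
            {"role": "system", "content": "Respond only with a JSON object."},
            {"role": "user", "content": prompt},
        ],
    )
    return json.loads(completion.choices[0].message.content or "{}")


result = asyncio.run(chat_json("List three entity types for a news-article dataset."))
```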
-# Licensed under the MIT License - -"""OpenAI Configuration class definition.""" - -import json -from collections.abc import Hashable -from typing import Any, cast - -from graphrag.llm.types import LLMConfig - - -def _non_blank(value: str | None) -> str | None: - if value is None: - return None - stripped = value.strip() - return None if stripped == "" else value - - -class OpenAIConfiguration(Hashable, LLMConfig): - """OpenAI Configuration class definition.""" - - # Core Configuration - _api_key: str - _model: str - - _api_base: str | None - _api_version: str | None - _audience: str | None - _deployment_name: str | None - _organization: str | None - _proxy: str | None - - # Operation Configuration - _n: int | None - _temperature: float | None - _frequency_penalty: float | None - _presence_penalty: float | None - _top_p: float | None - _max_tokens: int | None - _response_format: str | None - _logit_bias: dict[str, float] | None - _stop: list[str] | None - - # Retry Logic - _max_retries: int | None - _max_retry_wait: float | None - _request_timeout: float | None - - # The raw configuration object - _raw_config: dict - - # Feature Flags - _model_supports_json: bool | None - - # Custom Configuration - _tokens_per_minute: int | None - _requests_per_minute: int | None - _concurrent_requests: int | None - _encoding_model: str | None - _sleep_on_rate_limit_recommendation: bool | None - - def __init__( - self, - config: dict, - ): - """Init method definition.""" - - def lookup_required(key: str) -> str: - return cast(str, config.get(key)) - - def lookup_str(key: str) -> str | None: - return cast(str | None, config.get(key)) - - def lookup_int(key: str) -> int | None: - result = config.get(key) - if result is None: - return None - return int(cast(int, result)) - - def lookup_float(key: str) -> float | None: - result = config.get(key) - if result is None: - return None - return float(cast(float, result)) - - def lookup_dict(key: str) -> dict | None: - return cast(dict | None, config.get(key)) - - def lookup_list(key: str) -> list | None: - return cast(list | None, config.get(key)) - - def lookup_bool(key: str) -> bool | None: - value = config.get(key) - if isinstance(value, str): - return value.upper() == "TRUE" - if isinstance(value, int): - return value > 0 - return cast(bool | None, config.get(key)) - - self._api_key = lookup_required("api_key") - self._model = lookup_required("model") - self._deployment_name = lookup_str("deployment_name") - self._api_base = lookup_str("api_base") - self._api_version = lookup_str("api_version") - self._audience = lookup_str("audience") - self._organization = lookup_str("organization") - self._proxy = lookup_str("proxy") - self._n = lookup_int("n") - self._temperature = lookup_float("temperature") - self._frequency_penalty = lookup_float("frequency_penalty") - self._presence_penalty = lookup_float("presence_penalty") - self._top_p = lookup_float("top_p") - self._max_tokens = lookup_int("max_tokens") - self._response_format = lookup_str("response_format") - self._logit_bias = lookup_dict("logit_bias") - self._stop = lookup_list("stop") - self._max_retries = lookup_int("max_retries") - self._request_timeout = lookup_float("request_timeout") - self._model_supports_json = lookup_bool("model_supports_json") - self._tokens_per_minute = lookup_int("tokens_per_minute") - self._requests_per_minute = lookup_int("requests_per_minute") - self._concurrent_requests = lookup_int("concurrent_requests") - self._encoding_model = lookup_str("encoding_model") - self._max_retry_wait 
= lookup_float("max_retry_wait") - self._sleep_on_rate_limit_recommendation = lookup_bool( - "sleep_on_rate_limit_recommendation" - ) - self._raw_config = config - - @property - def api_key(self) -> str: - """API key property definition.""" - return self._api_key - - @property - def model(self) -> str: - """Model property definition.""" - return self._model - - @property - def deployment_name(self) -> str | None: - """Deployment name property definition.""" - return _non_blank(self._deployment_name) - - @property - def api_base(self) -> str | None: - """API base property definition.""" - result = _non_blank(self._api_base) - # Remove trailing slash - return result[:-1] if result and result.endswith("/") else result - - @property - def api_version(self) -> str | None: - """API version property definition.""" - return _non_blank(self._api_version) - - @property - def audience(self) -> str | None: - """API version property definition.""" - return _non_blank(self._audience) - - @property - def organization(self) -> str | None: - """Organization property definition.""" - return _non_blank(self._organization) - - @property - def proxy(self) -> str | None: - """Proxy property definition.""" - return _non_blank(self._proxy) - - @property - def n(self) -> int | None: - """N property definition.""" - return self._n - - @property - def temperature(self) -> float | None: - """Temperature property definition.""" - return self._temperature - - @property - def frequency_penalty(self) -> float | None: - """Frequency penalty property definition.""" - return self._frequency_penalty - - @property - def presence_penalty(self) -> float | None: - """Presence penalty property definition.""" - return self._presence_penalty - - @property - def top_p(self) -> float | None: - """Top p property definition.""" - return self._top_p - - @property - def max_tokens(self) -> int | None: - """Max tokens property definition.""" - return self._max_tokens - - @property - def response_format(self) -> str | None: - """Response format property definition.""" - return _non_blank(self._response_format) - - @property - def logit_bias(self) -> dict[str, float] | None: - """Logit bias property definition.""" - return self._logit_bias - - @property - def stop(self) -> list[str] | None: - """Stop property definition.""" - return self._stop - - @property - def max_retries(self) -> int | None: - """Max retries property definition.""" - return self._max_retries - - @property - def max_retry_wait(self) -> float | None: - """Max retry wait property definition.""" - return self._max_retry_wait - - @property - def request_timeout(self) -> float | None: - """Request timeout property definition.""" - return self._request_timeout - - @property - def model_supports_json(self) -> bool | None: - """Model supports json property definition.""" - return self._model_supports_json - - @property - def tokens_per_minute(self) -> int | None: - """Tokens per minute property definition.""" - return self._tokens_per_minute - - @property - def requests_per_minute(self) -> int | None: - """Requests per minute property definition.""" - return self._requests_per_minute - - @property - def concurrent_requests(self) -> int | None: - """Concurrent requests property definition.""" - return self._concurrent_requests - - @property - def encoding_model(self) -> str | None: - """Encoding model property definition.""" - return _non_blank(self._encoding_model) - - @property - def sleep_on_rate_limit_recommendation(self) -> bool | None: - """Whether to sleep for seconds when 
recommended by 429 errors (azure-specific).""" - return self._sleep_on_rate_limit_recommendation - - @property - def raw_config(self) -> dict: - """Raw config method definition.""" - return self._raw_config - - def lookup(self, name: str, default_value: Any = None) -> Any: - """Lookup method definition.""" - return self._raw_config.get(name, default_value) - - def __str__(self) -> str: - """Str method definition.""" - return json.dumps(self.raw_config, indent=4) - - def __repr__(self) -> str: - """Repr method definition.""" - return f"OpenAIConfiguration({self._raw_config})" - - def __eq__(self, other: object) -> bool: - """Eq method definition.""" - if not isinstance(other, OpenAIConfiguration): - return False - return self._raw_config == other._raw_config - - def __hash__(self) -> int: - """Hash method definition.""" - return hash(tuple(sorted(self._raw_config.items()))) diff --git a/graphrag/llm/openai/openai_embeddings_llm.py b/graphrag/llm/openai/openai_embeddings_llm.py deleted file mode 100644 index 558afe8437..0000000000 --- a/graphrag/llm/openai/openai_embeddings_llm.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""The EmbeddingsLLM class.""" - -from typing_extensions import Unpack - -from graphrag.llm.base import BaseLLM -from graphrag.llm.types import ( - EmbeddingInput, - EmbeddingOutput, - LLMInput, -) - -from .openai_configuration import OpenAIConfiguration -from .types import OpenAIClientTypes - - -class OpenAIEmbeddingsLLM(BaseLLM[EmbeddingInput, EmbeddingOutput]): - """A text-embedding generator LLM.""" - - _client: OpenAIClientTypes - _configuration: OpenAIConfiguration - - def __init__(self, client: OpenAIClientTypes, configuration: OpenAIConfiguration): - self.client = client - self.configuration = configuration - - async def _execute_llm( - self, input: EmbeddingInput, **kwargs: Unpack[LLMInput] - ) -> EmbeddingOutput | None: - args = { - "model": self.configuration.model, - **(kwargs.get("model_parameters") or {}), - } - embedding = await self.client.embeddings.create( - input=input, - **args, - ) - return [d.embedding for d in embedding.data] diff --git a/graphrag/llm/openai/openai_history_tracking_llm.py b/graphrag/llm/openai/openai_history_tracking_llm.py deleted file mode 100644 index ab903c2d2a..0000000000 --- a/graphrag/llm/openai/openai_history_tracking_llm.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
-# Licensed under the MIT License - -"""The Chat-based language model.""" - -from typing_extensions import Unpack - -from graphrag.llm.types import ( - LLM, - CompletionInput, - CompletionLLM, - CompletionOutput, - LLMInput, - LLMOutput, -) - - -class OpenAIHistoryTrackingLLM(LLM[CompletionInput, CompletionOutput]): - """An OpenAI History-Tracking LLM.""" - - _delegate: CompletionLLM - - def __init__(self, delegate: CompletionLLM): - self._delegate = delegate - - async def __call__( - self, - input: CompletionInput, - **kwargs: Unpack[LLMInput], - ) -> LLMOutput[CompletionOutput]: - """Call the LLM.""" - history = kwargs.get("history") or [] - output = await self._delegate(input, **kwargs) - return LLMOutput( - output=output.output, - json=output.json, - history=[ - *history, - {"role": "user", "content": input}, - {"role": "assistant", "content": output.output}, - ], - ) diff --git a/graphrag/llm/openai/openai_token_replacing_llm.py b/graphrag/llm/openai/openai_token_replacing_llm.py deleted file mode 100644 index 7385b84059..0000000000 --- a/graphrag/llm/openai/openai_token_replacing_llm.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""The Chat-based language model.""" - -from typing_extensions import Unpack - -from graphrag.llm.types import ( - LLM, - CompletionInput, - CompletionLLM, - CompletionOutput, - LLMInput, - LLMOutput, -) - -from .utils import perform_variable_replacements - - -class OpenAITokenReplacingLLM(LLM[CompletionInput, CompletionOutput]): - """An OpenAI History-Tracking LLM.""" - - _delegate: CompletionLLM - - def __init__(self, delegate: CompletionLLM): - self._delegate = delegate - - async def __call__( - self, - input: CompletionInput, - **kwargs: Unpack[LLMInput], - ) -> LLMOutput[CompletionOutput]: - """Call the LLM with the input and kwargs.""" - variables = kwargs.get("variables") - history = kwargs.get("history") or [] - input = perform_variable_replacements(input, history, variables) - return await self._delegate(input, **kwargs) diff --git a/graphrag/llm/openai/types.py b/graphrag/llm/openai/types.py deleted file mode 100644 index 4aacf18c1c..0000000000 --- a/graphrag/llm/openai/types.py +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""A base class for OpenAI-based LLMs.""" - -from openai import ( - AsyncAzureOpenAI, - AsyncOpenAI, -) - -OpenAIClientTypes = AsyncOpenAI | AsyncAzureOpenAI diff --git a/graphrag/llm/openai/utils.py b/graphrag/llm/openai/utils.py deleted file mode 100644 index 64b7118d9b..0000000000 --- a/graphrag/llm/openai/utils.py +++ /dev/null @@ -1,160 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
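The token-replacing wrapper deleted above substitutes `{variable}` placeholders in the prompt, and in any system messages already in the history, before delegating; the helper that does this appears in the utils module below. A quick illustration of that substitution behavior:

```python
def perform_variable_replacements(input: str, history: list[dict], variables: dict | None) -> str:
    # Same behavior as the deleted helper: replace {name} tokens in the prompt
    # and in system messages already present in the chat history.
    def replace_all(text: str) -> str:
        if variables:
            for name, value in variables.items():
                text = text.replace(f"{{{name}}}", value)
        return text

    for i, entry in enumerate(history):
        if entry.get("role") == "system":
            history[i]["content"] = replace_all(entry.get("content") or "")
    return replace_all(input)


history = [{"role": "system", "content": "You are analyzing the {domain} domain."}]
prompt = perform_variable_replacements(
    "Extract {entity_types} from the text.",
    history,
    {"domain": "news", "entity_types": "people and places"},
)
# prompt == "Extract people and places from the text."
# history[0]["content"] == "You are analyzing the news domain."
```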
-# Licensed under the MIT License - -"""Utility functions for the OpenAI API.""" - -import json -import logging -import re -from collections.abc import Callable -from typing import Any - -import tiktoken -from json_repair import repair_json -from openai import ( - APIConnectionError, - InternalServerError, - RateLimitError, -) - -from .openai_configuration import OpenAIConfiguration - -DEFAULT_ENCODING = "cl100k_base" - -_encoders: dict[str, tiktoken.Encoding] = {} - -RETRYABLE_ERRORS: list[type[Exception]] = [ - RateLimitError, - APIConnectionError, - InternalServerError, -] -RATE_LIMIT_ERRORS: list[type[Exception]] = [RateLimitError] - -log = logging.getLogger(__name__) - - -def get_token_counter(config: OpenAIConfiguration) -> Callable[[str], int]: - """Get a function that counts the number of tokens in a string.""" - model = config.encoding_model or "cl100k_base" - enc = _encoders.get(model) - if enc is None: - enc = tiktoken.get_encoding(model) - _encoders[model] = enc - - return lambda s: len(enc.encode(s)) - - -def perform_variable_replacements( - input: str, history: list[dict], variables: dict | None -) -> str: - """Perform variable replacements on the input string and in a chat log.""" - result = input - - def replace_all(input: str) -> str: - result = input - if variables: - for entry in variables: - result = result.replace(f"{{{entry}}}", variables[entry]) - return result - - result = replace_all(result) - for i in range(len(history)): - entry = history[i] - if entry.get("role") == "system": - history[i]["content"] = replace_all(entry.get("content") or "") - - return result - - -def get_completion_cache_args(configuration: OpenAIConfiguration) -> dict: - """Get the cache arguments for a completion LLM.""" - return { - "model": configuration.model, - "temperature": configuration.temperature, - "frequency_penalty": configuration.frequency_penalty, - "presence_penalty": configuration.presence_penalty, - "top_p": configuration.top_p, - "max_tokens": configuration.max_tokens, - "n": configuration.n, - } - - -def get_completion_llm_args( - parameters: dict | None, configuration: OpenAIConfiguration -) -> dict: - """Get the arguments for a completion LLM.""" - return { - **get_completion_cache_args(configuration), - **(parameters or {}), - } - - -def try_parse_json_object(input: str) -> tuple[str, dict]: - """JSON cleaning and formatting utilities.""" - # Sometimes, the LLM returns a json string with some extra description, this function will clean it up. - - result = None - try: - # Try parse first - result = json.loads(input) - except json.JSONDecodeError: - log.info("Warning: Error decoding faulty json, attempting repair") - - if result: - return input, result - - _pattern = r"\{(.*)\}" - _match = re.search(_pattern, input, re.DOTALL) - input = "{" + _match.group(1) + "}" if _match else input - - # Clean up json string. - input = ( - input.replace("{{", "{") - .replace("}}", "}") - .replace('"[{', "[{") - .replace('}]"', "}]") - .replace("\\", " ") - .replace("\\n", " ") - .replace("\n", " ") - .replace("\r", "") - .strip() - ) - - # Remove JSON Markdown Frame - if input.startswith("```json"): - input = input[len("```json") :] - if input.endswith("```"): - input = input[: len(input) - len("```")] - - try: - result = json.loads(input) - except json.JSONDecodeError: - # Fixup potentially malformed json string using json_repair. - input = str(repair_json(json_str=input, return_objects=False)) - - # Generate JSON-string output using best-attempt prompting & parsing techniques. 
- try: - result = json.loads(input) - except json.JSONDecodeError: - log.exception("error loading json, json=%s", input) - return input, {} - else: - if not isinstance(result, dict): - log.exception("not expected dict type. type=%s:", type(result)) - return input, {} - return input, result - else: - return input, result - - -def get_sleep_time_from_error(e: Any) -> float: - """Extract the sleep time value from a RateLimitError. This is usually only available in Azure.""" - sleep_time = 0.0 - if isinstance(e, RateLimitError) and _please_retry_after in str(e): - # could be second or seconds - sleep_time = int(str(e).split(_please_retry_after)[1].split(" second")[0]) - - return sleep_time - - -_please_retry_after = "Please retry after " diff --git a/graphrag/llm/types/__init__.py b/graphrag/llm/types/__init__.py deleted file mode 100644 index c8277661d5..0000000000 --- a/graphrag/llm/types/__init__.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""LLM Typings.""" - -from .llm import LLM -from .llm_cache import LLMCache -from .llm_callbacks import ( - ErrorHandlerFn, - IsResponseValidFn, - LLMInvocationFn, - OnCacheActionFn, -) -from .llm_config import LLMConfig -from .llm_invocation_result import LLMInvocationResult -from .llm_io import ( - LLMInput, - LLMOutput, -) -from .llm_types import ( - CompletionInput, - CompletionLLM, - CompletionOutput, - EmbeddingInput, - EmbeddingLLM, - EmbeddingOutput, -) - -__all__ = [ - "LLM", - "CompletionInput", - "CompletionLLM", - "CompletionOutput", - "EmbeddingInput", - "EmbeddingLLM", - "EmbeddingOutput", - "ErrorHandlerFn", - "IsResponseValidFn", - "LLMCache", - "LLMConfig", - "LLMInput", - "LLMInvocationFn", - "LLMInvocationResult", - "LLMOutput", - "OnCacheActionFn", -] diff --git a/graphrag/llm/types/llm.py b/graphrag/llm/types/llm.py deleted file mode 100644 index fd8407e50e..0000000000 --- a/graphrag/llm/types/llm.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""LLM Types.""" - -from typing import Generic, Protocol, TypeVar - -from typing_extensions import Unpack - -from .llm_io import ( - LLMInput, - LLMOutput, -) - -TIn = TypeVar("TIn", contravariant=True) -TOut = TypeVar("TOut") - - -class LLM(Protocol, Generic[TIn, TOut]): - """LLM Protocol definition.""" - - async def __call__( - self, - input: TIn, - **kwargs: Unpack[LLMInput], - ) -> LLMOutput[TOut]: - """Invoke the LLM, treating the LLM as a function.""" - ... diff --git a/graphrag/llm/types/llm_cache.py b/graphrag/llm/types/llm_cache.py deleted file mode 100644 index 952b8d346d..0000000000 --- a/graphrag/llm/types/llm_cache.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""Typing definitions for the OpenAI DataShaper package.""" - -from typing import Any, Protocol - - -class LLMCache(Protocol): - """LLM Cache interface.""" - - async def has(self, key: str) -> bool: - """Check if the cache has a value.""" - ... - - async def get(self, key: str) -> Any | None: - """Retrieve a value from the cache.""" - ... - - async def set(self, key: str, value: Any, debug_data: dict | None = None) -> None: - """Write a value into the cache.""" - ... 
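try_parse_json_object, removed above along with the json-repair dependency, strips Markdown fences and falls back to json_repair before giving up. A usage sketch of that removed fallback, using the same repair_json call; the malformed sample string is illustrative only:

```python
import json

from json_repair import repair_json

# Typical malformed model output: single quotes and a trailing comma.
raw = "{'title': 'abc', 'summary': 'def',}"

try:
    parsed = json.loads(raw)
except json.JSONDecodeError:
    # Fall back to json_repair, as the deleted try_parse_json_object does.
    parsed = json.loads(str(repair_json(json_str=raw, return_objects=False)))

print(parsed["title"])  # "abc"
```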
diff --git a/graphrag/llm/types/llm_callbacks.py b/graphrag/llm/types/llm_callbacks.py deleted file mode 100644 index dc06dbff06..0000000000 --- a/graphrag/llm/types/llm_callbacks.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""Typing definitions for the OpenAI DataShaper package.""" - -from collections.abc import Callable - -from .llm_invocation_result import LLMInvocationResult - -ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None] -"""Error handler function type definition.""" - -LLMInvocationFn = Callable[[LLMInvocationResult], None] -"""Handler for LLM invocation results""" - -OnCacheActionFn = Callable[[str, str | None], None] -"""Handler for cache hits""" - -IsResponseValidFn = Callable[[dict], bool] -"""A function that checks if an LLM response is valid.""" diff --git a/graphrag/llm/types/llm_config.py b/graphrag/llm/types/llm_config.py deleted file mode 100644 index cd7ec255b2..0000000000 --- a/graphrag/llm/types/llm_config.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""LLM Configuration Protocol definition.""" - -from typing import Protocol - - -class LLMConfig(Protocol): - """LLM Configuration Protocol definition.""" - - @property - def max_retries(self) -> int | None: - """Get the maximum number of retries.""" - ... - - @property - def max_retry_wait(self) -> float | None: - """Get the maximum retry wait time.""" - ... - - @property - def sleep_on_rate_limit_recommendation(self) -> bool | None: - """Get whether to sleep on rate limit recommendation.""" - ... - - @property - def tokens_per_minute(self) -> int | None: - """Get the number of tokens per minute.""" - ... - - @property - def requests_per_minute(self) -> int | None: - """Get the number of requests per minute.""" - ... diff --git a/graphrag/llm/types/llm_invocation_result.py b/graphrag/llm/types/llm_invocation_result.py deleted file mode 100644 index 1769aeb96d..0000000000 --- a/graphrag/llm/types/llm_invocation_result.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""Typing definitions for the OpenAI DataShaper package.""" - -from dataclasses import dataclass -from typing import Generic, TypeVar - -T = TypeVar("T") - - -@dataclass -class LLMInvocationResult(Generic[T]): - """The result of an LLM invocation.""" - - result: T | None - """The result of the LLM invocation.""" - - name: str - """The operation name of the result""" - - num_retries: int - """The number of retries the invocation took.""" - - total_time: float - """The total time of the LLM invocation.""" - - call_times: list[float] - """The network times of individual invocations.""" - - input_tokens: int - """The number of input tokens.""" - - output_tokens: int - """The number of output tokens.""" diff --git a/graphrag/llm/types/llm_io.py b/graphrag/llm/types/llm_io.py deleted file mode 100644 index 256f3c8ce8..0000000000 --- a/graphrag/llm/types/llm_io.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
-# Licensed under the MIT License - -"""LLM Types.""" - -from dataclasses import dataclass, field -from typing import Generic, TypeVar - -from typing_extensions import NotRequired, TypedDict - -from .llm_callbacks import IsResponseValidFn - - -class LLMInput(TypedDict): - """The input of an LLM invocation.""" - - name: NotRequired[str] - """The name of the LLM invocation, if available.""" - - json: NotRequired[bool] - """If true, will attempt to elicit JSON from the LLM. Parsed JSON will be returned in the `json_output` field.""" - - is_response_valid: NotRequired[IsResponseValidFn] - """A function that checks if an LLM response is valid. Only valid if `json=True`.""" - - variables: NotRequired[dict] - """The variable replacements to use in the prompt.""" - - history: NotRequired[list[dict] | None] - """The history of the LLM invocation, if available (e.g. chat mode)""" - - model_parameters: NotRequired[dict] - """Additional model parameters to use in the LLM invocation.""" - - -T = TypeVar("T") - - -@dataclass -class LLMOutput(Generic[T]): - """The output of an LLM invocation.""" - - output: T | None - """The output of the LLM invocation.""" - - json: dict | None = field(default=None) - """The JSON output from the LLM, if available.""" - - history: list[dict] | None = field(default=None) - """The history of the LLM invocation, if available (e.g. chat mode)""" diff --git a/graphrag/llm/types/llm_types.py b/graphrag/llm/types/llm_types.py deleted file mode 100644 index 7ae76ef9be..0000000000 --- a/graphrag/llm/types/llm_types.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""LLM Types.""" - -from typing import TypeAlias - -from .llm import LLM - -EmbeddingInput: TypeAlias = list[str] -EmbeddingOutput: TypeAlias = list[list[float]] -CompletionInput: TypeAlias = str -CompletionOutput: TypeAlias = str - -EmbeddingLLM: TypeAlias = LLM[EmbeddingInput, EmbeddingOutput] -CompletionLLM: TypeAlias = LLM[CompletionInput, CompletionOutput] diff --git a/graphrag/prompt_tune/generator/community_report_rating.py b/graphrag/prompt_tune/generator/community_report_rating.py index 23d7cc6832..4c8c18f93c 100644 --- a/graphrag/prompt_tune/generator/community_report_rating.py +++ b/graphrag/prompt_tune/generator/community_report_rating.py @@ -3,14 +3,15 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -from graphrag.llm.types.llm_types import CompletionLLM +from fnllm import ChatLLM + from graphrag.prompt_tune.prompt.community_report_rating import ( GENERATE_REPORT_RATING_PROMPT, ) async def generate_community_report_rating( - llm: CompletionLLM, domain: str, persona: str, docs: str | list[str] + llm: ChatLLM, domain: str, persona: str, docs: str | list[str] ) -> str: """Generate an LLM persona to use for GraphRAG prompts. 
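The prompt-tuning generators below switch from the removed CompletionLLM to fnllm's ChatLLM: response text now lives on `response.output.content`, and structured output goes through a Pydantic `json_model` with the parsed result on `response.parsed_json`, as the entity_types change shows. A condensed sketch of the new calling convention; constructing the ChatLLM instance is assumed to happen elsewhere (e.g. via load_llm):

```python
from fnllm import ChatLLM
from pydantic import BaseModel


class EntityTypesResponse(BaseModel):
    """Entity types response model (as added in entity_types.py)."""

    entity_types: list[str]


async def demo(llm: ChatLLM, prompt: str) -> None:
    # Plain completion: the text is now on response.output.content.
    response = await llm(prompt, history=[{"role": "system", "content": "persona"}])
    text = str(response.output.content).strip()

    # Structured output: pass a Pydantic model via json_model and read parsed_json.
    typed = await llm(prompt, json_model=EntityTypesResponse)
    entity_types = typed.parsed_json.entity_types if typed.parsed_json else []
    print(text, entity_types)
```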
@@ -32,4 +33,4 @@ async def generate_community_report_rating( response = await llm(domain_prompt) - return str(response.output).strip() + return str(response.output.content).strip() diff --git a/graphrag/prompt_tune/generator/community_reporter_role.py b/graphrag/prompt_tune/generator/community_reporter_role.py index f16a6c3dd4..d2e85f813f 100644 --- a/graphrag/prompt_tune/generator/community_reporter_role.py +++ b/graphrag/prompt_tune/generator/community_reporter_role.py @@ -3,14 +3,15 @@ """Generate a community reporter role for community summarization.""" -from graphrag.llm.types.llm_types import CompletionLLM +from fnllm import ChatLLM + from graphrag.prompt_tune.prompt.community_reporter_role import ( GENERATE_COMMUNITY_REPORTER_ROLE_PROMPT, ) async def generate_community_reporter_role( - llm: CompletionLLM, domain: str, persona: str, docs: str | list[str] + llm: ChatLLM, domain: str, persona: str, docs: str | list[str] ) -> str: """Generate an LLM persona to use for GraphRAG prompts. @@ -32,4 +33,4 @@ async def generate_community_reporter_role( response = await llm(domain_prompt) - return str(response.output) + return str(response.output.content) diff --git a/graphrag/prompt_tune/generator/domain.py b/graphrag/prompt_tune/generator/domain.py index 49c698d1b4..2e129c0cd2 100644 --- a/graphrag/prompt_tune/generator/domain.py +++ b/graphrag/prompt_tune/generator/domain.py @@ -3,11 +3,12 @@ """Domain generation for GraphRAG prompts.""" -from graphrag.llm.types.llm_types import CompletionLLM +from fnllm import ChatLLM + from graphrag.prompt_tune.prompt.domain import GENERATE_DOMAIN_PROMPT -async def generate_domain(llm: CompletionLLM, docs: str | list[str]) -> str: +async def generate_domain(llm: ChatLLM, docs: str | list[str]) -> str: """Generate an LLM persona to use for GraphRAG prompts. Parameters @@ -24,4 +25,4 @@ async def generate_domain(llm: CompletionLLM, docs: str | list[str]) -> str: response = await llm(domain_prompt) - return str(response.output) + return str(response.output.content) diff --git a/graphrag/prompt_tune/generator/entity_relationship.py b/graphrag/prompt_tune/generator/entity_relationship.py index f8862bd6ef..38873bd5a7 100644 --- a/graphrag/prompt_tune/generator/entity_relationship.py +++ b/graphrag/prompt_tune/generator/entity_relationship.py @@ -6,7 +6,8 @@ import asyncio import json -from graphrag.llm.types.llm_types import CompletionLLM +from fnllm import ChatLLM + from graphrag.prompt_tune.prompt.entity_relationship import ( ENTITY_RELATIONSHIPS_GENERATION_JSON_PROMPT, ENTITY_RELATIONSHIPS_GENERATION_PROMPT, @@ -17,7 +18,7 @@ async def generate_entity_relationship_examples( - llm: CompletionLLM, + llm: ChatLLM, persona: str, entity_types: str | list[str] | None, docs: str | list[str], @@ -30,7 +31,7 @@ async def generate_entity_relationship_examples( on the json_mode parameter. 
""" docs_list = [docs] if isinstance(docs, str) else docs - history = [{"role": "system", "content": persona}] + history = [{"content": persona, "role": "system"}] if entity_types: entity_types_str = ( @@ -62,6 +63,6 @@ async def generate_entity_relationship_examples( responses = await asyncio.gather(*tasks) return [ - json.dumps(response.json or "") if json_mode else str(response.output) + json.dumps(response.json or "") if json_mode else str(response.output.content) for response in responses ] diff --git a/graphrag/prompt_tune/generator/entity_types.py b/graphrag/prompt_tune/generator/entity_types.py index 51ac0020e0..b64f1a68fa 100644 --- a/graphrag/prompt_tune/generator/entity_types.py +++ b/graphrag/prompt_tune/generator/entity_types.py @@ -3,7 +3,9 @@ """Entity type generation module for fine-tuning.""" -from graphrag.llm.types.llm_types import CompletionLLM +from fnllm import ChatLLM +from pydantic import BaseModel + from graphrag.prompt_tune.defaults import DEFAULT_TASK from graphrag.prompt_tune.prompt.entity_types import ( ENTITY_TYPE_GENERATION_JSON_PROMPT, @@ -11,8 +13,14 @@ ) +class EntityTypesResponse(BaseModel): + """Entity types response model.""" + + entity_types: list[str] + + async def generate_entity_types( - llm: CompletionLLM, + llm: ChatLLM, domain: str, persona: str, docs: str | list[str], @@ -37,9 +45,12 @@ async def generate_entity_types( history = [{"role": "system", "content": persona}] - response = await llm(entity_types_prompt, history=history, json=json_mode) - if json_mode: - return (response.json or {}).get("entity_types", []) + response = await llm( + entity_types_prompt, history=history, json_model=EntityTypesResponse + ) + model = response.parsed_json + return model.entity_types if model else [] - return str(response.output) + response = await llm(entity_types_prompt, history=history, json=json_mode) + return str(response.output.content) diff --git a/graphrag/prompt_tune/generator/language.py b/graphrag/prompt_tune/generator/language.py index d803df9c54..7327e12531 100644 --- a/graphrag/prompt_tune/generator/language.py +++ b/graphrag/prompt_tune/generator/language.py @@ -3,11 +3,12 @@ """Language detection for GraphRAG prompts.""" -from graphrag.llm.types.llm_types import CompletionLLM +from fnllm import ChatLLM + from graphrag.prompt_tune.prompt.language import DETECT_LANGUAGE_PROMPT -async def detect_language(llm: CompletionLLM, docs: str | list[str]) -> str: +async def detect_language(llm: ChatLLM, docs: str | list[str]) -> str: """Detect input language to use for GraphRAG prompts. Parameters @@ -24,4 +25,4 @@ async def detect_language(llm: CompletionLLM, docs: str | list[str]) -> str: response = await llm(language_prompt) - return str(response.output) + return str(response.output.content) diff --git a/graphrag/prompt_tune/generator/persona.py b/graphrag/prompt_tune/generator/persona.py index c66cc4a717..d5fc5e59b3 100644 --- a/graphrag/prompt_tune/generator/persona.py +++ b/graphrag/prompt_tune/generator/persona.py @@ -3,14 +3,13 @@ """Persona generating module for fine-tuning GraphRAG prompts.""" -from graphrag.llm.types.llm_types import CompletionLLM +from fnllm import ChatLLM + from graphrag.prompt_tune.defaults import DEFAULT_TASK from graphrag.prompt_tune.prompt.persona import GENERATE_PERSONA_PROMPT -async def generate_persona( - llm: CompletionLLM, domain: str, task: str = DEFAULT_TASK -) -> str: +async def generate_persona(llm: ChatLLM, domain: str, task: str = DEFAULT_TASK) -> str: """Generate an LLM persona to use for GraphRAG prompts. 
Parameters @@ -24,4 +23,4 @@ async def generate_persona( response = await llm(persona_prompt) - return str(response.output) + return str(response.output.content) diff --git a/graphrag/prompt_tune/loader/input.py b/graphrag/prompt_tune/loader/input.py index f7e29ea42d..bbf9c5b17d 100644 --- a/graphrag/prompt_tune/loader/input.py +++ b/graphrag/prompt_tune/loader/input.py @@ -6,13 +6,15 @@ import numpy as np import pandas as pd from datashaper import NoopVerbCallbacks +from fnllm import ChatLLM +from pydantic import TypeAdapter import graphrag.config.defaults as defs from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.config.models.llm_parameters import LLMParameters from graphrag.index.input.factory import create_input from graphrag.index.llm.load_llm import load_llm_embeddings from graphrag.index.operations.chunk_text import chunk_text -from graphrag.llm.types.llm_types import EmbeddingLLM from graphrag.logging.base import ProgressReporter from graphrag.prompt_tune.defaults import ( MIN_CHUNK_OVERLAP, @@ -25,13 +27,13 @@ async def _embed_chunks( text_chunks: pd.DataFrame, - embedding_llm: EmbeddingLLM, + embedding_llm: ChatLLM, n_subset_max: int = N_SUBSET_MAX, ) -> tuple[pd.DataFrame, np.ndarray]: """Convert text chunks into dense text embeddings.""" sampled_text_chunks = text_chunks.sample(n=min(n_subset_max, len(text_chunks))) embeddings = await embedding_llm(sampled_text_chunks["chunks"].tolist()) - return text_chunks, np.array(embeddings.output) + return text_chunks, np.array(embeddings.output.embeddings) def _sample_chunks_from_embeddings( @@ -58,6 +60,10 @@ async def load_docs_in_chunks( k: int = K, ) -> list[str]: """Load docs into chunks for generating prompts.""" + llm_config = TypeAdapter(LLMParameters).validate_python( + config.embeddings.resolved_strategy()["llm"] + ) + dataset = await create_input(config.input, reporter, root) # covert to text units @@ -91,11 +97,10 @@ async def load_docs_in_chunks( msg = "k must be an integer > 0" raise ValueError(msg) embedding_llm = load_llm_embeddings( - name="prompt_tuning_embeddings", - llm_type=config.embeddings.resolved_strategy()["llm"]["type"], + "prompt_tuning_embeddings", + llm_config, callbacks=NoopVerbCallbacks(), cache=None, - llm_config=config.embeddings.resolved_strategy()["llm"], ) chunks_df, embeddings = await _embed_chunks( diff --git a/graphrag/query/context_builder/rate_relevancy.py b/graphrag/query/context_builder/rate_relevancy.py index 5ae67f38c2..42fa723bb7 100644 --- a/graphrag/query/context_builder/rate_relevancy.py +++ b/graphrag/query/context_builder/rate_relevancy.py @@ -11,10 +11,9 @@ import numpy as np import tiktoken -from graphrag.llm.openai.utils import try_parse_json_object from graphrag.query.context_builder.rate_prompt import RATE_QUERY from graphrag.query.llm.base import BaseLLM -from graphrag.query.llm.text_utils import num_tokens +from graphrag.query.llm.text_utils import num_tokens, try_parse_json_object log = logging.getLogger(__name__) diff --git a/graphrag/query/llm/get_client.py b/graphrag/query/llm/get_client.py index 5b9dbfbbc2..b32a1497b4 100644 --- a/graphrag/query/llm/get_client.py +++ b/graphrag/query/llm/get_client.py @@ -14,10 +14,7 @@ def get_llm(config: GraphRagConfig) -> ChatOpenAI: """Get the LLM client.""" - is_azure_client = ( - config.llm.type == LLMType.AzureOpenAIChat - or config.llm.type == LLMType.AzureOpenAI - ) + is_azure_client = config.llm.type == LLMType.AzureOpenAIChat debug_llm_key = config.llm.api_key or "" llm_debug_info = { 
**config.llm.model_dump(), diff --git a/graphrag/query/llm/text_utils.py b/graphrag/query/llm/text_utils.py index c9c6fbc5ac..2d5e0c371e 100644 --- a/graphrag/query/llm/text_utils.py +++ b/graphrag/query/llm/text_utils.py @@ -3,10 +3,16 @@ """Text Utilities for LLM.""" +import json +import logging +import re from collections.abc import Iterator from itertools import islice import tiktoken +from json_repair import repair_json + +log = logging.getLogger(__name__) def num_tokens(text: str, token_encoder: tiktoken.Encoding | None = None) -> int: @@ -40,3 +46,61 @@ def chunk_text( tokens = token_encoder.encode(text) # type: ignore chunk_iterator = batched(iter(tokens), max_tokens) yield from (token_encoder.decode(list(chunk)) for chunk in chunk_iterator) + + +def try_parse_json_object(input: str) -> tuple[str, dict]: + """JSON cleaning and formatting utilities.""" + # Sometimes, the LLM returns a json string with some extra description, this function will clean it up. + + result = None + try: + # Try parse first + result = json.loads(input) + except json.JSONDecodeError: + log.info("Warning: Error decoding faulty json, attempting repair") + + if result: + return input, result + + _pattern = r"\{(.*)\}" + _match = re.search(_pattern, input, re.DOTALL) + input = "{" + _match.group(1) + "}" if _match else input + + # Clean up json string. + input = ( + input.replace("{{", "{") + .replace("}}", "}") + .replace('"[{', "[{") + .replace('}]"', "}]") + .replace("\\", " ") + .replace("\\n", " ") + .replace("\n", " ") + .replace("\r", "") + .strip() + ) + + # Remove JSON Markdown Frame + if input.startswith("```json"): + input = input[len("```json") :] + if input.endswith("```"): + input = input[: len(input) - len("```")] + + try: + result = json.loads(input) + except json.JSONDecodeError: + # Fixup potentially malformed json string using json_repair. + input = str(repair_json(json_str=input, return_objects=False)) + + # Generate JSON-string output using best-attempt prompting & parsing techniques. + try: + result = json.loads(input) + except json.JSONDecodeError: + log.exception("error loading json, json=%s", input) + return input, {} + else: + if not isinstance(result, dict): + log.exception("not expected dict type. 
type=%s:", type(result)) + return input, {} + return input, result + else: + return input, result diff --git a/graphrag/query/structured_search/global_search/search.py b/graphrag/query/structured_search/global_search/search.py index f19b8726b6..d0f0a7c574 100644 --- a/graphrag/query/structured_search/global_search/search.py +++ b/graphrag/query/structured_search/global_search/search.py @@ -15,7 +15,6 @@ import tiktoken from graphrag.callbacks.global_search_callbacks import GlobalSearchLLMCallback -from graphrag.llm.openai.utils import try_parse_json_object from graphrag.prompts.query.global_search_knowledge_system_prompt import ( GENERAL_KNOWLEDGE_INSTRUCTION, ) @@ -31,7 +30,7 @@ ConversationHistory, ) from graphrag.query.llm.base import BaseLLM -from graphrag.query.llm.text_utils import num_tokens +from graphrag.query.llm.text_utils import num_tokens, try_parse_json_object from graphrag.query.structured_search.base import BaseSearch, SearchResult DEFAULT_MAP_LLM_PARAMS = { @@ -355,10 +354,10 @@ async def _reduce_response( for point in filtered_key_points: formatted_response_data = [] formatted_response_data.append( - f'----Analyst {point["analyst"] + 1}----' + f"----Analyst {point['analyst'] + 1}----" ) formatted_response_data.append( - f'Importance Score: {point["score"]}' # type: ignore + f"Importance Score: {point['score']}" # type: ignore ) formatted_response_data.append(point["answer"]) # type: ignore formatted_response_text = "\n".join(formatted_response_data) @@ -456,8 +455,8 @@ async def _stream_reduce_response( total_tokens = 0 for point in filtered_key_points: formatted_response_data = [ - f'----Analyst {point["analyst"] + 1}----', - f'Importance Score: {point["score"]}', + f"----Analyst {point['analyst'] + 1}----", + f"Importance Score: {point['score']}", point["answer"], ] formatted_response_text = "\n".join(formatted_response_data) diff --git a/poetry.lock b/poetry.lock index 47985d7218..20f682466b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1072,6 +1072,28 @@ files = [ [package.extras] devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benchmark", "pytest-cache", "validictory"] +[[package]] +name = "fnllm" +version = "0.0.8" +description = "A function-based LLM protocol and wrapper." 
+optional = false +python-versions = ">=3.10" +files = [ + {file = "fnllm-0.0.8-py3-none-any.whl", hash = "sha256:f8dbfa97ad6d7102db06c841348ffe57009f1e174f491cd5c9284be73aec4b33"}, + {file = "fnllm-0.0.8.tar.gz", hash = "sha256:67c1fb9c00680fae0b9348f07f9b382c194616e33be214080a6ef07d539a49c3"}, +] + +[package.dependencies] +aiolimiter = ">=1.1.0" +httpx = ">=0.27.0" +json-repair = ">=0.30.0" +pydantic = ">=2.8.2" +tenacity = ">=8.5.0" + +[package.extras] +azure = ["azure-identity (>=1.17.1)", "azure-storage-blob (>=12.20.0)"] +openai = ["openai (>=1.35.12)", "tiktoken (>=0.7.0)"] + [[package]] name = "fonttools" version = "4.54.1" @@ -2506,13 +2528,13 @@ pygments = ">2.12.0" [[package]] name = "mkdocs-material" -version = "9.5.43" +version = "9.5.44" description = "Documentation that simply works" optional = false python-versions = ">=3.8" files = [ - {file = "mkdocs_material-9.5.43-py3-none-any.whl", hash = "sha256:4aae0664c456fd12837a3192e0225c17960ba8bf55d7f0a7daef7e4b0b914a34"}, - {file = "mkdocs_material-9.5.43.tar.gz", hash = "sha256:83be7ff30b65a1e4930dfa4ab911e75780a3afc9583d162692e434581cb46979"}, + {file = "mkdocs_material-9.5.44-py3-none-any.whl", hash = "sha256:47015f9c167d58a5ff5e682da37441fc4d66a1c79334bfc08d774763cacf69ca"}, + {file = "mkdocs_material-9.5.44.tar.gz", hash = "sha256:f3a6c968e524166b3f3ed1fb97d3ed3e0091183b0545cedf7156a2a6804c56c0"}, ] [package.dependencies] @@ -2860,13 +2882,13 @@ files = [ [[package]] name = "openai" -version = "1.54.0" +version = "1.54.1" description = "The official Python library for the openai API" optional = false python-versions = ">=3.8" files = [ - {file = "openai-1.54.0-py3-none-any.whl", hash = "sha256:24ed8874b56e919f0fbb80b7136c3fb022dc82ce9f5f21579b7b280ea4bba249"}, - {file = "openai-1.54.0.tar.gz", hash = "sha256:df2a84384314165b706722a7ac8988dc33eba20dd7fc3b939d138110e608b1ce"}, + {file = "openai-1.54.1-py3-none-any.whl", hash = "sha256:3cb49ccb6bfdc724ad01cc397d323ef8314fc7d45e19e9de2afdd6484a533324"}, + {file = "openai-1.54.1.tar.gz", hash = "sha256:5b832bf82002ba8c4f6e5e25c1c0f5d468c22f043711544c716eaffdb30dd6f1"}, ] [package.dependencies] @@ -3712,13 +3734,13 @@ diagrams = ["jinja2", "railroad-diagrams"] [[package]] name = "pyright" -version = "1.1.387" +version = "1.1.388" description = "Command line wrapper for pyright" optional = false python-versions = ">=3.7" files = [ - {file = "pyright-1.1.387-py3-none-any.whl", hash = "sha256:6a1f495a261a72e12ad17e20d1ae3df4511223c773b19407cfa006229b1b08a5"}, - {file = "pyright-1.1.387.tar.gz", hash = "sha256:577de60224f7fe36505d5b181231e3a395d427b7873be0bbcaa962a29ea93a60"}, + {file = "pyright-1.1.388-py3-none-any.whl", hash = "sha256:c7068e9f2c23539c6ac35fc9efac6c6c1b9aa5a0ce97a9a8a6cf0090d7cbf84c"}, + {file = "pyright-1.1.388.tar.gz", hash = "sha256:0166d19b716b77fd2d9055de29f71d844874dbc6b9d3472ccd22df91db3dfa34"}, ] [package.dependencies] @@ -3770,6 +3792,21 @@ pytest = ">=8.2,<9" docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] +[[package]] +name = "pytest-dotenv" +version = "0.5.2" +description = "A py.test plugin that parses environment files before running tests" +optional = false +python-versions = "*" +files = [ + {file = "pytest-dotenv-0.5.2.tar.gz", hash = "sha256:2dc6c3ac6d8764c71c6d2804e902d0ff810fa19692e95fe138aefc9b1aa73732"}, + {file = "pytest_dotenv-0.5.2-py3-none-any.whl", hash = "sha256:40a2cece120a213898afaa5407673f6bd924b1fa7eafce6bda0e8abffe2f710f"}, +] + +[package.dependencies] +pytest = ">=5.0.0" 
+python-dotenv = ">=0.9.1" + [[package]] name = "pytest-timeout" version = "2.3.1" @@ -4273,114 +4310,101 @@ jupyter = ["ipywidgets (>=7.5.1,<9)"] [[package]] name = "rpds-py" -version = "0.20.1" +version = "0.21.0" description = "Python bindings to Rust's persistent data structures (rpds)" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "rpds_py-0.20.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:a649dfd735fff086e8a9d0503a9f0c7d01b7912a333c7ae77e1515c08c146dad"}, - {file = "rpds_py-0.20.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f16bc1334853e91ddaaa1217045dd7be166170beec337576818461268a3de67f"}, - {file = "rpds_py-0.20.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:14511a539afee6f9ab492b543060c7491c99924314977a55c98bfa2ee29ce78c"}, - {file = "rpds_py-0.20.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3ccb8ac2d3c71cda472b75af42818981bdacf48d2e21c36331b50b4f16930163"}, - {file = "rpds_py-0.20.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c142b88039b92e7e0cb2552e8967077e3179b22359e945574f5e2764c3953dcf"}, - {file = "rpds_py-0.20.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f19169781dddae7478a32301b499b2858bc52fc45a112955e798ee307e294977"}, - {file = "rpds_py-0.20.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:13c56de6518e14b9bf6edde23c4c39dac5b48dcf04160ea7bce8fca8397cdf86"}, - {file = "rpds_py-0.20.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:925d176a549f4832c6f69fa6026071294ab5910e82a0fe6c6228fce17b0706bd"}, - {file = "rpds_py-0.20.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:78f0b6877bfce7a3d1ff150391354a410c55d3cdce386f862926a4958ad5ab7e"}, - {file = "rpds_py-0.20.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:3dd645e2b0dcb0fd05bf58e2e54c13875847687d0b71941ad2e757e5d89d4356"}, - {file = "rpds_py-0.20.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:4f676e21db2f8c72ff0936f895271e7a700aa1f8d31b40e4e43442ba94973899"}, - {file = "rpds_py-0.20.1-cp310-none-win32.whl", hash = "sha256:648386ddd1e19b4a6abab69139b002bc49ebf065b596119f8f37c38e9ecee8ff"}, - {file = "rpds_py-0.20.1-cp310-none-win_amd64.whl", hash = "sha256:d9ecb51120de61e4604650666d1f2b68444d46ae18fd492245a08f53ad2b7711"}, - {file = "rpds_py-0.20.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:762703bdd2b30983c1d9e62b4c88664df4a8a4d5ec0e9253b0231171f18f6d75"}, - {file = "rpds_py-0.20.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0b581f47257a9fce535c4567782a8976002d6b8afa2c39ff616edf87cbeff712"}, - {file = "rpds_py-0.20.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:842c19a6ce894493563c3bd00d81d5100e8e57d70209e84d5491940fdb8b9e3a"}, - {file = "rpds_py-0.20.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:42cbde7789f5c0bcd6816cb29808e36c01b960fb5d29f11e052215aa85497c93"}, - {file = "rpds_py-0.20.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6c8e9340ce5a52f95fa7d3b552b35c7e8f3874d74a03a8a69279fd5fca5dc751"}, - {file = "rpds_py-0.20.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8ba6f89cac95c0900d932c9efb7f0fb6ca47f6687feec41abcb1bd5e2bd45535"}, - {file = "rpds_py-0.20.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4a916087371afd9648e1962e67403c53f9c49ca47b9680adbeef79da3a7811b0"}, - {file = 
"rpds_py-0.20.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:200a23239781f46149e6a415f1e870c5ef1e712939fe8fa63035cd053ac2638e"}, - {file = "rpds_py-0.20.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:58b1d5dd591973d426cbb2da5e27ba0339209832b2f3315928c9790e13f159e8"}, - {file = "rpds_py-0.20.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:6b73c67850ca7cae0f6c56f71e356d7e9fa25958d3e18a64927c2d930859b8e4"}, - {file = "rpds_py-0.20.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:d8761c3c891cc51e90bc9926d6d2f59b27beaf86c74622c8979380a29cc23ac3"}, - {file = "rpds_py-0.20.1-cp311-none-win32.whl", hash = "sha256:cd945871335a639275eee904caef90041568ce3b42f402c6959b460d25ae8732"}, - {file = "rpds_py-0.20.1-cp311-none-win_amd64.whl", hash = "sha256:7e21b7031e17c6b0e445f42ccc77f79a97e2687023c5746bfb7a9e45e0921b84"}, - {file = "rpds_py-0.20.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:36785be22066966a27348444b40389f8444671630063edfb1a2eb04318721e17"}, - {file = "rpds_py-0.20.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:142c0a5124d9bd0e2976089484af5c74f47bd3298f2ed651ef54ea728d2ea42c"}, - {file = "rpds_py-0.20.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dbddc10776ca7ebf2a299c41a4dde8ea0d8e3547bfd731cb87af2e8f5bf8962d"}, - {file = "rpds_py-0.20.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:15a842bb369e00295392e7ce192de9dcbf136954614124a667f9f9f17d6a216f"}, - {file = "rpds_py-0.20.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:be5ef2f1fc586a7372bfc355986226484e06d1dc4f9402539872c8bb99e34b01"}, - {file = "rpds_py-0.20.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dbcf360c9e3399b056a238523146ea77eeb2a596ce263b8814c900263e46031a"}, - {file = "rpds_py-0.20.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ecd27a66740ffd621d20b9a2f2b5ee4129a56e27bfb9458a3bcc2e45794c96cb"}, - {file = "rpds_py-0.20.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d0b937b2a1988f184a3e9e577adaa8aede21ec0b38320d6009e02bd026db04fa"}, - {file = "rpds_py-0.20.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6889469bfdc1eddf489729b471303739bf04555bb151fe8875931f8564309afc"}, - {file = "rpds_py-0.20.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:19b73643c802f4eaf13d97f7855d0fb527fbc92ab7013c4ad0e13a6ae0ed23bd"}, - {file = "rpds_py-0.20.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3c6afcf2338e7f374e8edc765c79fbcb4061d02b15dd5f8f314a4af2bdc7feb5"}, - {file = "rpds_py-0.20.1-cp312-none-win32.whl", hash = "sha256:dc73505153798c6f74854aba69cc75953888cf9866465196889c7cdd351e720c"}, - {file = "rpds_py-0.20.1-cp312-none-win_amd64.whl", hash = "sha256:8bbe951244a838a51289ee53a6bae3a07f26d4e179b96fc7ddd3301caf0518eb"}, - {file = "rpds_py-0.20.1-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:6ca91093a4a8da4afae7fe6a222c3b53ee4eef433ebfee4d54978a103435159e"}, - {file = "rpds_py-0.20.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:b9c2fe36d1f758b28121bef29ed1dee9b7a2453e997528e7d1ac99b94892527c"}, - {file = "rpds_py-0.20.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f009c69bc8c53db5dfab72ac760895dc1f2bc1b62ab7408b253c8d1ec52459fc"}, - {file = "rpds_py-0.20.1-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6740a3e8d43a32629bb9b009017ea5b9e713b7210ba48ac8d4cb6d99d86c8ee8"}, - {file = 
"rpds_py-0.20.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:32b922e13d4c0080d03e7b62991ad7f5007d9cd74e239c4b16bc85ae8b70252d"}, - {file = "rpds_py-0.20.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fe00a9057d100e69b4ae4a094203a708d65b0f345ed546fdef86498bf5390982"}, - {file = "rpds_py-0.20.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49fe9b04b6fa685bd39237d45fad89ba19e9163a1ccaa16611a812e682913496"}, - {file = "rpds_py-0.20.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:aa7ac11e294304e615b43f8c441fee5d40094275ed7311f3420d805fde9b07b4"}, - {file = "rpds_py-0.20.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6aa97af1558a9bef4025f8f5d8c60d712e0a3b13a2fe875511defc6ee77a1ab7"}, - {file = "rpds_py-0.20.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:483b29f6f7ffa6af845107d4efe2e3fa8fb2693de8657bc1849f674296ff6a5a"}, - {file = "rpds_py-0.20.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:37fe0f12aebb6a0e3e17bb4cd356b1286d2d18d2e93b2d39fe647138458b4bcb"}, - {file = "rpds_py-0.20.1-cp313-none-win32.whl", hash = "sha256:a624cc00ef2158e04188df5e3016385b9353638139a06fb77057b3498f794782"}, - {file = "rpds_py-0.20.1-cp313-none-win_amd64.whl", hash = "sha256:b71b8666eeea69d6363248822078c075bac6ed135faa9216aa85f295ff009b1e"}, - {file = "rpds_py-0.20.1-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:5b48e790e0355865197ad0aca8cde3d8ede347831e1959e158369eb3493d2191"}, - {file = "rpds_py-0.20.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3e310838a5801795207c66c73ea903deda321e6146d6f282e85fa7e3e4854804"}, - {file = "rpds_py-0.20.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2249280b870e6a42c0d972339e9cc22ee98730a99cd7f2f727549af80dd5a963"}, - {file = "rpds_py-0.20.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e79059d67bea28b53d255c1437b25391653263f0e69cd7dec170d778fdbca95e"}, - {file = "rpds_py-0.20.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2b431c777c9653e569986ecf69ff4a5dba281cded16043d348bf9ba505486f36"}, - {file = "rpds_py-0.20.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:da584ff96ec95e97925174eb8237e32f626e7a1a97888cdd27ee2f1f24dd0ad8"}, - {file = "rpds_py-0.20.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:02a0629ec053fc013808a85178524e3cb63a61dbc35b22499870194a63578fb9"}, - {file = "rpds_py-0.20.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:fbf15aff64a163db29a91ed0868af181d6f68ec1a3a7d5afcfe4501252840bad"}, - {file = "rpds_py-0.20.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:07924c1b938798797d60c6308fa8ad3b3f0201802f82e4a2c41bb3fafb44cc28"}, - {file = "rpds_py-0.20.1-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:4a5a844f68776a7715ecb30843b453f07ac89bad393431efbf7accca3ef599c1"}, - {file = "rpds_py-0.20.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:518d2ca43c358929bf08f9079b617f1c2ca6e8848f83c1225c88caeac46e6cbc"}, - {file = "rpds_py-0.20.1-cp38-none-win32.whl", hash = "sha256:3aea7eed3e55119635a74bbeb80b35e776bafccb70d97e8ff838816c124539f1"}, - {file = "rpds_py-0.20.1-cp38-none-win_amd64.whl", hash = "sha256:7dca7081e9a0c3b6490a145593f6fe3173a94197f2cb9891183ef75e9d64c425"}, - {file = "rpds_py-0.20.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:b41b6321805c472f66990c2849e152aff7bc359eb92f781e3f606609eac877ad"}, - {file = 
"rpds_py-0.20.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0a90c373ea2975519b58dece25853dbcb9779b05cc46b4819cb1917e3b3215b6"}, - {file = "rpds_py-0.20.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:16d4477bcb9fbbd7b5b0e4a5d9b493e42026c0bf1f06f723a9353f5153e75d30"}, - {file = "rpds_py-0.20.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:84b8382a90539910b53a6307f7c35697bc7e6ffb25d9c1d4e998a13e842a5e83"}, - {file = "rpds_py-0.20.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4888e117dd41b9d34194d9e31631af70d3d526efc363085e3089ab1a62c32ed1"}, - {file = "rpds_py-0.20.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5265505b3d61a0f56618c9b941dc54dc334dc6e660f1592d112cd103d914a6db"}, - {file = "rpds_py-0.20.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e75ba609dba23f2c95b776efb9dd3f0b78a76a151e96f96cc5b6b1b0004de66f"}, - {file = "rpds_py-0.20.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1791ff70bc975b098fe6ecf04356a10e9e2bd7dc21fa7351c1742fdeb9b4966f"}, - {file = "rpds_py-0.20.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:d126b52e4a473d40232ec2052a8b232270ed1f8c9571aaf33f73a14cc298c24f"}, - {file = "rpds_py-0.20.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:c14937af98c4cc362a1d4374806204dd51b1e12dded1ae30645c298e5a5c4cb1"}, - {file = "rpds_py-0.20.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:3d089d0b88996df627693639d123c8158cff41c0651f646cd8fd292c7da90eaf"}, - {file = "rpds_py-0.20.1-cp39-none-win32.whl", hash = "sha256:653647b8838cf83b2e7e6a0364f49af96deec64d2a6578324db58380cff82aca"}, - {file = "rpds_py-0.20.1-cp39-none-win_amd64.whl", hash = "sha256:fa41a64ac5b08b292906e248549ab48b69c5428f3987b09689ab2441f267d04d"}, - {file = "rpds_py-0.20.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:7a07ced2b22f0cf0b55a6a510078174c31b6d8544f3bc00c2bcee52b3d613f74"}, - {file = "rpds_py-0.20.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:68cb0a499f2c4a088fd2f521453e22ed3527154136a855c62e148b7883b99f9a"}, - {file = "rpds_py-0.20.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fa3060d885657abc549b2a0f8e1b79699290e5d83845141717c6c90c2df38311"}, - {file = "rpds_py-0.20.1-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:95f3b65d2392e1c5cec27cff08fdc0080270d5a1a4b2ea1d51d5f4a2620ff08d"}, - {file = "rpds_py-0.20.1-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2cc3712a4b0b76a1d45a9302dd2f53ff339614b1c29603a911318f2357b04dd2"}, - {file = "rpds_py-0.20.1-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d4eea0761e37485c9b81400437adb11c40e13ef513375bbd6973e34100aeb06"}, - {file = "rpds_py-0.20.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f5179583d7a6cdb981151dd349786cbc318bab54963a192692d945dd3f6435d"}, - {file = "rpds_py-0.20.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2fbb0ffc754490aff6dabbf28064be47f0f9ca0b9755976f945214965b3ace7e"}, - {file = "rpds_py-0.20.1-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:a94e52537a0e0a85429eda9e49f272ada715506d3b2431f64b8a3e34eb5f3e75"}, - {file = "rpds_py-0.20.1-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:92b68b79c0da2a980b1c4197e56ac3dd0c8a149b4603747c4378914a68706979"}, - {file = 
"rpds_py-0.20.1-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:93da1d3db08a827eda74356f9f58884adb254e59b6664f64cc04cdff2cc19b0d"}, - {file = "rpds_py-0.20.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:754bbed1a4ca48479e9d4182a561d001bbf81543876cdded6f695ec3d465846b"}, - {file = "rpds_py-0.20.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:ca449520e7484534a2a44faf629362cae62b660601432d04c482283c47eaebab"}, - {file = "rpds_py-0.20.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:9c4cb04a16b0f199a8c9bf807269b2f63b7b5b11425e4a6bd44bd6961d28282c"}, - {file = "rpds_py-0.20.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb63804105143c7e24cee7db89e37cb3f3941f8e80c4379a0b355c52a52b6780"}, - {file = "rpds_py-0.20.1-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:55cd1fa4ecfa6d9f14fbd97ac24803e6f73e897c738f771a9fe038f2f11ff07c"}, - {file = "rpds_py-0.20.1-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0f8f741b6292c86059ed175d80eefa80997125b7c478fb8769fd9ac8943a16c0"}, - {file = "rpds_py-0.20.1-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0fc212779bf8411667234b3cdd34d53de6c2b8b8b958e1e12cb473a5f367c338"}, - {file = "rpds_py-0.20.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ad56edabcdb428c2e33bbf24f255fe2b43253b7d13a2cdbf05de955217313e6"}, - {file = "rpds_py-0.20.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0a3a1e9ee9728b2c1734f65d6a1d376c6f2f6fdcc13bb007a08cc4b1ff576dc5"}, - {file = "rpds_py-0.20.1-pp39-pypy39_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:e13de156137b7095442b288e72f33503a469aa1980ed856b43c353ac86390519"}, - {file = "rpds_py-0.20.1-pp39-pypy39_pp73-musllinux_1_2_i686.whl", hash = "sha256:07f59760ef99f31422c49038964b31c4dfcfeb5d2384ebfc71058a7c9adae2d2"}, - {file = "rpds_py-0.20.1-pp39-pypy39_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:59240685e7da61fb78f65a9f07f8108e36a83317c53f7b276b4175dc44151684"}, - {file = "rpds_py-0.20.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:83cba698cfb3c2c5a7c3c6bac12fe6c6a51aae69513726be6411076185a8b24a"}, - {file = "rpds_py-0.20.1.tar.gz", hash = "sha256:e1791c4aabd117653530dccd24108fa03cc6baf21f58b950d0a73c3b3b29a350"}, + {file = "rpds_py-0.21.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:a017f813f24b9df929674d0332a374d40d7f0162b326562daae8066b502d0590"}, + {file = "rpds_py-0.21.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:20cc1ed0bcc86d8e1a7e968cce15be45178fd16e2ff656a243145e0b439bd250"}, + {file = "rpds_py-0.21.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad116dda078d0bc4886cb7840e19811562acdc7a8e296ea6ec37e70326c1b41c"}, + {file = "rpds_py-0.21.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:808f1ac7cf3b44f81c9475475ceb221f982ef548e44e024ad5f9e7060649540e"}, + {file = "rpds_py-0.21.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de552f4a1916e520f2703ec474d2b4d3f86d41f353e7680b597512ffe7eac5d0"}, + {file = "rpds_py-0.21.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:efec946f331349dfc4ae9d0e034c263ddde19414fe5128580f512619abed05f1"}, + {file = "rpds_py-0.21.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b80b4690bbff51a034bfde9c9f6bf9357f0a8c61f548942b80f7b66356508bf5"}, + {file = 
"rpds_py-0.21.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:085ed25baac88953d4283e5b5bd094b155075bb40d07c29c4f073e10623f9f2e"}, + {file = "rpds_py-0.21.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:daa8efac2a1273eed2354397a51216ae1e198ecbce9036fba4e7610b308b6153"}, + {file = "rpds_py-0.21.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:95a5bad1ac8a5c77b4e658671642e4af3707f095d2b78a1fdd08af0dfb647624"}, + {file = "rpds_py-0.21.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:3e53861b29a13d5b70116ea4230b5f0f3547b2c222c5daa090eb7c9c82d7f664"}, + {file = "rpds_py-0.21.0-cp310-none-win32.whl", hash = "sha256:ea3a6ac4d74820c98fcc9da4a57847ad2cc36475a8bd9683f32ab6d47a2bd682"}, + {file = "rpds_py-0.21.0-cp310-none-win_amd64.whl", hash = "sha256:b8f107395f2f1d151181880b69a2869c69e87ec079c49c0016ab96860b6acbe5"}, + {file = "rpds_py-0.21.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:5555db3e618a77034954b9dc547eae94166391a98eb867905ec8fcbce1308d95"}, + {file = "rpds_py-0.21.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:97ef67d9bbc3e15584c2f3c74bcf064af36336c10d2e21a2131e123ce0f924c9"}, + {file = "rpds_py-0.21.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ab2c2a26d2f69cdf833174f4d9d86118edc781ad9a8fa13970b527bf8236027"}, + {file = "rpds_py-0.21.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4e8921a259f54bfbc755c5bbd60c82bb2339ae0324163f32868f63f0ebb873d9"}, + {file = "rpds_py-0.21.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8a7ff941004d74d55a47f916afc38494bd1cfd4b53c482b77c03147c91ac0ac3"}, + {file = "rpds_py-0.21.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5145282a7cd2ac16ea0dc46b82167754d5e103a05614b724457cffe614f25bd8"}, + {file = "rpds_py-0.21.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de609a6f1b682f70bb7163da745ee815d8f230d97276db049ab447767466a09d"}, + {file = "rpds_py-0.21.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:40c91c6e34cf016fa8e6b59d75e3dbe354830777fcfd74c58b279dceb7975b75"}, + {file = "rpds_py-0.21.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d2132377f9deef0c4db89e65e8bb28644ff75a18df5293e132a8d67748397b9f"}, + {file = "rpds_py-0.21.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:0a9e0759e7be10109645a9fddaaad0619d58c9bf30a3f248a2ea57a7c417173a"}, + {file = "rpds_py-0.21.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:9e20da3957bdf7824afdd4b6eeb29510e83e026473e04952dca565170cd1ecc8"}, + {file = "rpds_py-0.21.0-cp311-none-win32.whl", hash = "sha256:f71009b0d5e94c0e86533c0b27ed7cacc1239cb51c178fd239c3cfefefb0400a"}, + {file = "rpds_py-0.21.0-cp311-none-win_amd64.whl", hash = "sha256:e168afe6bf6ab7ab46c8c375606298784ecbe3ba31c0980b7dcbb9631dcba97e"}, + {file = "rpds_py-0.21.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:30b912c965b2aa76ba5168fd610087bad7fcde47f0a8367ee8f1876086ee6d1d"}, + {file = "rpds_py-0.21.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ca9989d5d9b1b300bc18e1801c67b9f6d2c66b8fd9621b36072ed1df2c977f72"}, + {file = "rpds_py-0.21.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6f54e7106f0001244a5f4cf810ba8d3f9c542e2730821b16e969d6887b664266"}, + {file = "rpds_py-0.21.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fed5dfefdf384d6fe975cc026886aece4f292feaf69d0eeb716cfd3c5a4dd8be"}, + {file = 
"rpds_py-0.21.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:590ef88db231c9c1eece44dcfefd7515d8bf0d986d64d0caf06a81998a9e8cab"}, + {file = "rpds_py-0.21.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f983e4c2f603c95dde63df633eec42955508eefd8d0f0e6d236d31a044c882d7"}, + {file = "rpds_py-0.21.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b229ce052ddf1a01c67d68166c19cb004fb3612424921b81c46e7ea7ccf7c3bf"}, + {file = "rpds_py-0.21.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ebf64e281a06c904a7636781d2e973d1f0926a5b8b480ac658dc0f556e7779f4"}, + {file = "rpds_py-0.21.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:998a8080c4495e4f72132f3d66ff91f5997d799e86cec6ee05342f8f3cda7dca"}, + {file = "rpds_py-0.21.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:98486337f7b4f3c324ab402e83453e25bb844f44418c066623db88e4c56b7c7b"}, + {file = "rpds_py-0.21.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:a78d8b634c9df7f8d175451cfeac3810a702ccb85f98ec95797fa98b942cea11"}, + {file = "rpds_py-0.21.0-cp312-none-win32.whl", hash = "sha256:a58ce66847711c4aa2ecfcfaff04cb0327f907fead8945ffc47d9407f41ff952"}, + {file = "rpds_py-0.21.0-cp312-none-win_amd64.whl", hash = "sha256:e860f065cc4ea6f256d6f411aba4b1251255366e48e972f8a347cf88077b24fd"}, + {file = "rpds_py-0.21.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:ee4eafd77cc98d355a0d02f263efc0d3ae3ce4a7c24740010a8b4012bbb24937"}, + {file = "rpds_py-0.21.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:688c93b77e468d72579351a84b95f976bd7b3e84aa6686be6497045ba84be560"}, + {file = "rpds_py-0.21.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c38dbf31c57032667dd5a2f0568ccde66e868e8f78d5a0d27dcc56d70f3fcd3b"}, + {file = "rpds_py-0.21.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2d6129137f43f7fa02d41542ffff4871d4aefa724a5fe38e2c31a4e0fd343fb0"}, + {file = "rpds_py-0.21.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:520ed8b99b0bf86a176271f6fe23024323862ac674b1ce5b02a72bfeff3fff44"}, + {file = "rpds_py-0.21.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:aaeb25ccfb9b9014a10eaf70904ebf3f79faaa8e60e99e19eef9f478651b9b74"}, + {file = "rpds_py-0.21.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:af04ac89c738e0f0f1b913918024c3eab6e3ace989518ea838807177d38a2e94"}, + {file = "rpds_py-0.21.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b9b76e2afd585803c53c5b29e992ecd183f68285b62fe2668383a18e74abe7a3"}, + {file = "rpds_py-0.21.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:5afb5efde74c54724e1a01118c6e5c15e54e642c42a1ba588ab1f03544ac8c7a"}, + {file = "rpds_py-0.21.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:52c041802a6efa625ea18027a0723676a778869481d16803481ef6cc02ea8cb3"}, + {file = "rpds_py-0.21.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ee1e4fc267b437bb89990b2f2abf6c25765b89b72dd4a11e21934df449e0c976"}, + {file = "rpds_py-0.21.0-cp313-none-win32.whl", hash = "sha256:0c025820b78817db6a76413fff6866790786c38f95ea3f3d3c93dbb73b632202"}, + {file = "rpds_py-0.21.0-cp313-none-win_amd64.whl", hash = "sha256:320c808df533695326610a1b6a0a6e98f033e49de55d7dc36a13c8a30cfa756e"}, + {file = "rpds_py-0.21.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:2c51d99c30091f72a3c5d126fad26236c3f75716b8b5e5cf8effb18889ced928"}, + {file = 
"rpds_py-0.21.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:cbd7504a10b0955ea287114f003b7ad62330c9e65ba012c6223dba646f6ffd05"}, + {file = "rpds_py-0.21.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6dcc4949be728ede49e6244eabd04064336012b37f5c2200e8ec8eb2988b209c"}, + {file = "rpds_py-0.21.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f414da5c51bf350e4b7960644617c130140423882305f7574b6cf65a3081cecb"}, + {file = "rpds_py-0.21.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9afe42102b40007f588666bc7de82451e10c6788f6f70984629db193849dced1"}, + {file = "rpds_py-0.21.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3b929c2bb6e29ab31f12a1117c39f7e6d6450419ab7464a4ea9b0b417174f044"}, + {file = "rpds_py-0.21.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8404b3717da03cbf773a1d275d01fec84ea007754ed380f63dfc24fb76ce4592"}, + {file = "rpds_py-0.21.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e12bb09678f38b7597b8346983d2323a6482dcd59e423d9448108c1be37cac9d"}, + {file = "rpds_py-0.21.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:58a0e345be4b18e6b8501d3b0aa540dad90caeed814c515e5206bb2ec26736fd"}, + {file = "rpds_py-0.21.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:c3761f62fcfccf0864cc4665b6e7c3f0c626f0380b41b8bd1ce322103fa3ef87"}, + {file = "rpds_py-0.21.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:c2b2f71c6ad6c2e4fc9ed9401080badd1469fa9889657ec3abea42a3d6b2e1ed"}, + {file = "rpds_py-0.21.0-cp39-none-win32.whl", hash = "sha256:b21747f79f360e790525e6f6438c7569ddbfb1b3197b9e65043f25c3c9b489d8"}, + {file = "rpds_py-0.21.0-cp39-none-win_amd64.whl", hash = "sha256:0626238a43152918f9e72ede9a3b6ccc9e299adc8ade0d67c5e142d564c9a83d"}, + {file = "rpds_py-0.21.0-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:6b4ef7725386dc0762857097f6b7266a6cdd62bfd209664da6712cb26acef035"}, + {file = "rpds_py-0.21.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:6bc0e697d4d79ab1aacbf20ee5f0df80359ecf55db33ff41481cf3e24f206919"}, + {file = "rpds_py-0.21.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da52d62a96e61c1c444f3998c434e8b263c384f6d68aca8274d2e08d1906325c"}, + {file = "rpds_py-0.21.0-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:98e4fe5db40db87ce1c65031463a760ec7906ab230ad2249b4572c2fc3ef1f9f"}, + {file = "rpds_py-0.21.0-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:30bdc973f10d28e0337f71d202ff29345320f8bc49a31c90e6c257e1ccef4333"}, + {file = "rpds_py-0.21.0-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:faa5e8496c530f9c71f2b4e1c49758b06e5f4055e17144906245c99fa6d45356"}, + {file = "rpds_py-0.21.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:32eb88c30b6a4f0605508023b7141d043a79b14acb3b969aa0b4f99b25bc7d4a"}, + {file = "rpds_py-0.21.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a89a8ce9e4e75aeb7fa5d8ad0f3fecdee813802592f4f46a15754dcb2fd6b061"}, + {file = "rpds_py-0.21.0-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:241e6c125568493f553c3d0fdbb38c74babf54b45cef86439d4cd97ff8feb34d"}, + {file = "rpds_py-0.21.0-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:3b766a9f57663396e4f34f5140b3595b233a7b146e94777b97a8413a1da1be18"}, + {file = 
"rpds_py-0.21.0-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:af4a644bf890f56e41e74be7d34e9511e4954894d544ec6b8efe1e21a1a8da6c"}, + {file = "rpds_py-0.21.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:3e30a69a706e8ea20444b98a49f386c17b26f860aa9245329bab0851ed100677"}, + {file = "rpds_py-0.21.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:031819f906bb146561af051c7cef4ba2003d28cff07efacef59da973ff7969ba"}, + {file = "rpds_py-0.21.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:b876f2bc27ab5954e2fd88890c071bd0ed18b9c50f6ec3de3c50a5ece612f7a6"}, + {file = "rpds_py-0.21.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dc5695c321e518d9f03b7ea6abb5ea3af4567766f9852ad1560f501b17588c7b"}, + {file = "rpds_py-0.21.0-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b4de1da871b5c0fd5537b26a6fc6814c3cc05cabe0c941db6e9044ffbb12f04a"}, + {file = "rpds_py-0.21.0-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:878f6fea96621fda5303a2867887686d7a198d9e0f8a40be100a63f5d60c88c9"}, + {file = "rpds_py-0.21.0-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a8eeec67590e94189f434c6d11c426892e396ae59e4801d17a93ac96b8c02a6c"}, + {file = "rpds_py-0.21.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ff2eba7f6c0cb523d7e9cff0903f2fe1feff8f0b2ceb6bd71c0e20a4dcee271"}, + {file = "rpds_py-0.21.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a429b99337062877d7875e4ff1a51fe788424d522bd64a8c0a20ef3021fdb6ed"}, + {file = "rpds_py-0.21.0-pp39-pypy39_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:d167e4dbbdac48bd58893c7e446684ad5d425b407f9336e04ab52e8b9194e2ed"}, + {file = "rpds_py-0.21.0-pp39-pypy39_pp73-musllinux_1_2_i686.whl", hash = "sha256:4eb2de8a147ffe0626bfdc275fc6563aa7bf4b6db59cf0d44f0ccd6ca625a24e"}, + {file = "rpds_py-0.21.0-pp39-pypy39_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:e78868e98f34f34a88e23ee9ccaeeec460e4eaf6db16d51d7a9b883e5e785a5e"}, + {file = "rpds_py-0.21.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:4991ca61656e3160cdaca4851151fd3f4a92e9eba5c7a530ab030d6aee96ec89"}, + {file = "rpds_py-0.21.0.tar.gz", hash = "sha256:ed6378c9d66d0de903763e7706383d60c33829581f0adff47b6535f1802fa6db"}, ] [[package]] @@ -4432,6 +4456,11 @@ files = [ {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd"}, {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6"}, {file = "scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1"}, + {file = "scikit_learn-1.5.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9a702e2de732bbb20d3bad29ebd77fc05a6b427dc49964300340e4c9328b3f5"}, + {file = "scikit_learn-1.5.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:b0768ad641981f5d3a198430a1d31c3e044ed2e8a6f22166b4d546a5116d7908"}, + {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:178ddd0a5cb0044464fc1bfc4cca5b1833bfc7bb022d70b05db8530da4bb3dd3"}, + {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7284ade780084d94505632241bf78c44ab3b6f1e8ccab3d2af58e0e950f9c12"}, + {file = 
"scikit_learn-1.5.2-cp313-cp313-win_amd64.whl", hash = "sha256:b7b0f9a0b1040830d38c39b91b3a44e1b643f4b36e36567b80b7c6bd2202a27f"}, {file = "scikit_learn-1.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:757c7d514ddb00ae249832fe87100d9c73c6ea91423802872d9e74970a0e40b9"}, {file = "scikit_learn-1.5.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:52788f48b5d8bca5c0736c175fa6bdaab2ef00a8f536cda698db61bd89c551c1"}, {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:643964678f4b5fbdc95cbf8aec638acc7aa70f5f79ee2cdad1eec3df4ba6ead8"}, @@ -5210,4 +5239,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "a53f642c91942635ab0b62af62591c9d146d58f7b91d8f04799f3730dca4f5f0" +content-hash = "f2945afc48a70dc63c06878c69d8fd4e4905c9775c6ce0ce831cfcd9202c1882" diff --git a/pyproject.toml b/pyproject.toml index e92ae626df..f3375b2090 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,9 +76,6 @@ pyyaml = "^6.0.2" pyaml-env = "^1.2.1" python-dotenv = "^1.0.0" -# Network -tenacity = "^9.0.0" - pydantic = "^2.9.2" rich = "^13.6.0" devtools = "^0.12.2" @@ -87,11 +84,10 @@ typing-extensions = "^4.12.2" #Azure azure-storage-blob = "^12.22.0" azure-identity = "^1.17.1" -json-repair = "^0.30.0" future = "^1.0.0" # Needed until graspologic fixes their dependency typer = "^0.12.5" - +fnllm = "^0.0.8" [tool.poetry.group.dev.dependencies] coverage = "^7.6.0" @@ -110,6 +106,7 @@ deptry = "^0.20.0" mkdocs-material = "^9.5.39" mkdocs-jupyter = "^0.25.0" mkdocs-exclude-search = "^0.6.6" +pytest-dotenv = "^0.5.2" mkdocs-typer = "^0.0.3" [build-system] @@ -273,5 +270,4 @@ exclude = ["**/node_modules", "**/__pycache__"] [tool.pytest.ini_options] asyncio_mode = "auto" timeout = 1000 -# log_cli = true -# log_cli_level = "INFO" +env_files = [".env"] diff --git a/tests/unit/indexing/verbs/helpers/mock_llm.py b/tests/unit/indexing/verbs/helpers/mock_llm.py index ba27da9c68..c9464e41fa 100644 --- a/tests/unit/indexing/verbs/helpers/mock_llm.py +++ b/tests/unit/indexing/verbs/helpers/mock_llm.py @@ -1,10 +1,13 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -from graphrag.llm import CompletionLLM, MockChatLLM +from fnllm import ChatLLM +from pydantic import BaseModel + +from graphrag.index.llm.mock_llm import MockChatLLM def create_mock_llm( - responses: list[str], -) -> CompletionLLM: + responses: list[str | BaseModel], +) -> ChatLLM: """Creates a mock LLM that returns the given responses.""" return MockChatLLM(responses) diff --git a/tests/unit/llm/__init__.py b/tests/unit/llm/__init__.py deleted file mode 100644 index 0a3e38adfb..0000000000 --- a/tests/unit/llm/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License diff --git a/tests/unit/llm/base/__init__.py b/tests/unit/llm/base/__init__.py deleted file mode 100644 index 0a3e38adfb..0000000000 --- a/tests/unit/llm/base/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License diff --git a/tests/unit/llm/base/test_caching_llm.py b/tests/unit/llm/base/test_caching_llm.py deleted file mode 100644 index 28e38f98e6..0000000000 --- a/tests/unit/llm/base/test_caching_llm.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
-# Licensed under the MIT License -"""Caching LLM Tests.""" - -import asyncio -from typing import Any, cast - -from graphrag.llm import CompletionLLM, LLMOutput -from graphrag.llm.base.caching_llm import CachingLLM -from graphrag.llm.openai.openai_history_tracking_llm import OpenAIHistoryTrackingLLM -from graphrag.llm.types import LLMCache - - -class TestCache(LLMCache): - def __init__(self): - self.cache = {} - - async def has(self, key: str) -> bool: - return key in self.cache - - async def get(self, key: str) -> dict | None: - entry = self.cache.get(key) - return entry["result"] if entry else None - - async def set( - self, key: str, value: str, debug_data: dict[str, Any] | None = None - ) -> None: - self.cache[key] = {"result": value, **(debug_data or {})} - - -async def mock_responder(input: str, **kwargs: dict) -> LLMOutput: - await asyncio.sleep(0.0001) - return LLMOutput(output=f"response to [{input}]") - - -def throwing_responder(input: str, **kwargs: dict) -> LLMOutput: - raise ValueError - - -mock_responder_llm = cast(CompletionLLM, mock_responder) -throwing_llm = cast(CompletionLLM, throwing_responder) - - -async def test_caching_llm() -> None: - """Test a composite LLM.""" - llm = CachingLLM( - mock_responder_llm, llm_parameters={}, operation="test", cache=TestCache() - ) - response = await llm("input 1") - assert response.output == "response to [input 1]" - llm.set_delegate(throwing_llm) - response = await llm("input 1") - assert response.output == "response to [input 1]" - - -async def test_composite_llm() -> None: - """Test a composite LLM.""" - caching = CachingLLM( - mock_responder_llm, llm_parameters={}, operation="test", cache=TestCache() - ) - llm = OpenAIHistoryTrackingLLM(caching) - - response = await llm("input 1") - history: list[dict] = cast(list[dict], response.history) - assert len(history) == 2 - - response = await llm("input 1") - history: list[dict] = cast(list[dict], response.history) - assert len(history) == 2 - - response = await llm("input 2", history=history) - history: list[dict] = cast(list[dict], response.history) - assert len(history) == 4 diff --git a/tests/unit/llm/openai/__init__.py b/tests/unit/llm/openai/__init__.py deleted file mode 100644 index 0a3e38adfb..0000000000 --- a/tests/unit/llm/openai/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License diff --git a/tests/unit/llm/openai/test_history_tracking_llm.py b/tests/unit/llm/openai/test_history_tracking_llm.py deleted file mode 100644 index 18ae9f5171..0000000000 --- a/tests/unit/llm/openai/test_history_tracking_llm.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
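With the caching and history-tracking decorator tests deleted here (those wrapper layers are removed in favor of the fnllm client), unit tests mock at the ChatLLM boundary instead. A minimal sketch of feeding typed responses to the updated create_mock_llm helper from tests/unit/indexing/verbs/helpers/mock_llm.py; FakeReport is an illustrative model, and the snippet assumes it runs inside the repository so the tests package is importable:

from pydantic import BaseModel

from tests.unit.indexing.verbs.helpers.mock_llm import create_mock_llm


class FakeReport(BaseModel):
    """Illustrative stand-in for a structured report response."""

    title: str
    summary: str


# Responses may now be plain strings or pydantic models (list[str | BaseModel]),
# so JSON-producing verbs can be mocked without json.dumps fixtures.
mock_llm = create_mock_llm([
    "a plain text completion",
    FakeReport(title="Community 1", summary="A short summary."),
])
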
-# Licensed under the MIT License -"""History-tracking LLM Tests.""" - -import asyncio -from typing import cast - -from graphrag.llm import CompletionLLM, LLMOutput -from graphrag.llm.openai.openai_history_tracking_llm import OpenAIHistoryTrackingLLM - - -async def test_history_tracking_llm() -> None: - async def mock_responder(input: str, **kwargs: dict) -> LLMOutput: - await asyncio.sleep(0.0001) - return LLMOutput(output=f"response to [{input}]") - - delegate = cast(CompletionLLM, mock_responder) - llm = OpenAIHistoryTrackingLLM(delegate) - - response = await llm("input 1") - history: list[dict] = cast(list[dict], response.history) - assert len(history) == 2 - assert history[0] == {"role": "user", "content": "input 1"} - assert history[1] == {"role": "assistant", "content": "response to [input 1]"} - - response = await llm("input 2", history=history) - history: list[dict] = cast(list[dict], response.history) - assert len(history) == 4 - assert history[0] == {"role": "user", "content": "input 1"} - assert history[1] == {"role": "assistant", "content": "response to [input 1]"} - assert history[2] == {"role": "user", "content": "input 2"} - assert history[3] == {"role": "assistant", "content": "response to [input 2]"} diff --git a/tests/verbs/test_create_final_community_reports.py b/tests/verbs/test_create_final_community_reports.py index 50ef7abaa3..1a04e9c56f 100644 --- a/tests/verbs/test_create_final_community_reports.py +++ b/tests/verbs/test_create_final_community_reports.py @@ -1,12 +1,15 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -import json import pytest from datashaper.errors import VerbParallelizationError from graphrag.config.enums import LLMType +from graphrag.index.graph.extractors.community_reports.community_reports_extractor import ( + CommunityReportResponse, + FindingModel, +) from graphrag.index.workflows.v1.create_final_community_reports import ( build_steps, workflow_name, @@ -21,25 +24,27 @@ ) MOCK_RESPONSES = [ - json.dumps({ - "title": "", - "summary": "", - "rating": 2, - "rating_explanation": "", - "findings": [ - { - "summary": "", - "explanation": "", - "explanation": "