Remove graphrag.llm, replace with fnllm (#1315)
* add fnllm; remove llm folder

* remove llm unit tests

* update imports

* update imports

* formatting

* enable autosave

* update mockllm

* update community reports extractor

* move most llm usage to fnllm

* update type issues

* fix unit tests

* type updates

* update dictionary

* semver

* update llm construction, get integration tests working

* load from llmparameters model

* move ruff settings to ruff.toml

* add gitattributes file

* ignore ruff.toml spelling

* update .gitattributes

* update gitignore

* update config construction

* update prompt var usage

* add cache adapter

* use cache adapter in embeddings calls

* update embedding strategy

* add fnllm

* add pytest-dotenv

* fix some verb tests

* get verb tests running

* update ruff.toml for vscode

* enable ruff native server in vscode

* update artifact inspecting code

* remove local-test update

* use string.replace instead of string.format in community reports extractor

* bump timeout

* revert ruff.toml, vscode settings for another pr

* revert cspell config

* revert gitignore

* remove json-repair, update fnllm

* use fnllm generic type interfaces

* update load_llm to use target models

* consolidate chat parameters

* add 'extra_attributes' prop to community report response

* formatting

* update fnllm

* formatting

* formatting

* Add defaults to some llm params to avoid null on params hash

* Formatting

---------

Co-authored-by: Alonso Guevara <[email protected]>
Co-authored-by: Josh Bradley <[email protected]>
3 people authored Dec 6, 2024
1 parent d43124e commit 5ff2d3c
Showing 77 changed files with 670 additions and 2,747 deletions.
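
At its core, the migration swaps the old `graphrag.llm` `CompletionLLM` protocol for fnllm's `ChatLLM`: callers now format prompts themselves instead of passing `variables=`, and read `response.output.content` rather than `response.output`. A minimal sketch of the new call pattern, distilled from the hunks below (the prompt template and the `extract` helper are illustrative, not part of the commit):

```python
from fnllm import ChatLLM

EXTRACTION_PROMPT = "Extract entities from the text below:\n{input_text}"  # illustrative template


async def extract(llm: ChatLLM, doc: str) -> str:
    # After the migration: format the prompt up front and read .output.content;
    # the old CompletionLLM wrapper took variables= and exposed .output directly.
    response = await llm(EXTRACTION_PROMPT.format(input_text=doc))
    return response.output.content or ""
```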
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20241024210728482023.json
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "replace llm package with fnllm"
}
1 change: 1 addition & 0 deletions dictionary.txt
@@ -68,6 +68,7 @@ pypi
nbformat
semversioner
mkdocs
fnllm
typer

# Library Methods
7 changes: 3 additions & 4 deletions graphrag/api/prompt_tune.py
@@ -97,10 +97,9 @@ async def generate_indexing_prompts(
# Create LLM from config
llm = load_llm(
"prompt_tuning",
config.llm.type,
NoopVerbCallbacks(),
None,
config.llm.model_dump(),
config.llm,
cache=None,
callbacks=NoopVerbCallbacks(),
)

if not domain:
4 changes: 1 addition & 3 deletions graphrag/config/create_graphrag_config.py
@@ -702,9 +702,7 @@ class Section(str, Enum):

def _is_azure(llm_type: LLMType | None) -> bool:
return (
llm_type == LLMType.AzureOpenAI
or llm_type == LLMType.AzureOpenAIChat
or llm_type == LLMType.AzureOpenAIEmbedding
llm_type == LLMType.AzureOpenAIChat or llm_type == LLMType.AzureOpenAIEmbedding
)


3 changes: 3 additions & 0 deletions graphrag/config/defaults.py
@@ -20,9 +20,11 @@

ASYNC_MODE = AsyncType.Threaded
ENCODING_MODEL = "cl100k_base"
AZURE_AUDIENCE = "https://cognitiveservices.azure.com/.default"
#
# LLM Parameters
#
LLM_FREQUENCY_PENALTY = 0.0
LLM_TYPE = LLMType.OpenAIChat
LLM_MODEL = "gpt-4-turbo-preview"
LLM_MAX_TOKENS = 4000
@@ -34,6 +36,7 @@
LLM_REQUESTS_PER_MINUTE = 0
LLM_MAX_RETRIES = 10
LLM_MAX_RETRY_WAIT = 10.0
LLM_PRESENCE_PENALTY = 0.0
LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION = True
LLM_CONCURRENT_REQUESTS = 25

4 changes: 0 additions & 4 deletions graphrag/config/enums.py
@@ -100,10 +100,6 @@ class LLMType(str, Enum):
OpenAIEmbedding = "openai_embedding"
AzureOpenAIEmbedding = "azure_openai_embedding"

# Raw Completion
OpenAI = "openai"
AzureOpenAI = "azure_openai"

# Chat Completion
OpenAIChat = "openai_chat"
AzureOpenAIChat = "azure_openai_chat"
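
With the raw-completion members removed, only the chat and embedding variants remain, so a configuration still pointing at `openai` or `azure_openai` would have to move to the corresponding `*_chat` type. A quick sanity check under that assumption:

```python
from graphrag.config.enums import LLMType

# Only chat and embedding variants remain after this commit.
assert LLMType("azure_openai_chat") is LLMType.AzureOpenAIChat
assert "azure_openai" not in {member.value for member in LLMType}
```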
17 changes: 17 additions & 0 deletions graphrag/config/models/llm_parameters.py
@@ -20,7 +20,13 @@ class LLMParameters(BaseModel):
type: LLMType = Field(
description="The type of LLM model to use.", default=defs.LLM_TYPE
)
encoding_model: str | None = Field(
description="The encoding model to use", default=defs.ENCODING_MODEL
)
model: str = Field(description="The LLM model to use.", default=defs.LLM_MODEL)
embeddings_model: str | None = Field(
description="The embeddings model to use.", default=defs.EMBEDDING_MODEL
)
max_tokens: int | None = Field(
description="The maximum number of tokens to generate.",
default=defs.LLM_MAX_TOKENS,
@@ -37,6 +43,14 @@
description="The number of completions to generate.",
default=defs.LLM_N,
)
frequency_penalty: float | None = Field(
description="The frequency penalty to use for token generation.",
default=defs.LLM_FREQUENCY_PENALTY,
)
presence_penalty: float | None = Field(
description="The presence penalty to use for token generation.",
default=defs.LLM_PRESENCE_PENALTY,
)
request_timeout: float = Field(
description="The request timeout to use.", default=defs.LLM_REQUEST_TIMEOUT
)
@@ -86,3 +100,6 @@ class LLMParameters(BaseModel):
description="Whether to use concurrent requests for the LLM service.",
default=defs.LLM_CONCURRENT_REQUESTS,
)
responses: list[str | BaseModel] | None = Field(
default=None, description="Static responses to use in mock mode."
)
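
The new `frequency_penalty` and `presence_penalty` fields default to `0.0` rather than `None`, which appears to be what the "Add defaults to some llm params to avoid null on params hash" commit note refers to: a config that omits them serializes the same as one that spells them out, so parameter-derived cache keys stay stable. A small sketch (field values are illustrative):

```python
from graphrag.config.models.llm_parameters import LLMParameters

a = LLMParameters(model="gpt-4-turbo-preview")
b = LLMParameters(
    model="gpt-4-turbo-preview",
    frequency_penalty=0.0,
    presence_penalty=0.0,
)

# Both dumps carry 0.0 for the penalties, so a hash over the dump no longer
# differs between "omitted" and "explicitly zero".
assert a.model_dump()["frequency_penalty"] == b.model_dump()["frequency_penalty"] == 0.0
```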
24 changes: 12 additions & 12 deletions graphrag/index/graph/extractors/claims/claim_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
from typing import Any

import tiktoken
from fnllm import ChatLLM

import graphrag.config.defaults as defs
from graphrag.index.typing import ErrorHandlerFn
from graphrag.llm import CompletionLLM
from graphrag.prompts.index.claim_extraction import (
CLAIM_EXTRACTION_PROMPT,
CONTINUE_PROMPT,
@@ -36,7 +36,7 @@ class ClaimExtractorResult:
class ClaimExtractor:
"""Claim extractor class definition."""

_llm: CompletionLLM
_llm: ChatLLM
_extraction_prompt: str
_summary_prompt: str
_output_formatter_prompt: str
@@ -48,10 +48,11 @@ class ClaimExtractor:
_completion_delimiter_key: str
_max_gleanings: int
_on_error: ErrorHandlerFn
_loop_args: dict[str, Any]

def __init__(
self,
llm_invoker: CompletionLLM,
llm_invoker: ChatLLM,
extraction_prompt: str | None = None,
input_text_key: str | None = None,
input_entity_spec_key: str | None = None,
@@ -87,9 +88,9 @@ def __init__(

# Construct the looping arguments
encoding = tiktoken.get_encoding(encoding_model or "cl100k_base")
yes = encoding.encode("YES")
no = encoding.encode("NO")
self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}
yes = f"{encoding.encode('YES')[0]}"
no = f"{encoding.encode('NO')[0]}"
self._loop_args = {"logit_bias": {yes: 100, no: 100}, "max_tokens": 1}

async def __call__(
self, inputs: dict[str, Any], prompt_variables: dict | None = None
@@ -164,13 +165,12 @@ async def _process_document(
)

response = await self._llm(
self._extraction_prompt,
variables={
self._extraction_prompt.format(**{
self._input_text_key: doc,
**prompt_args,
},
})
)
results = response.output or ""
results = response.output.content or ""
claims = results.strip().removesuffix(completion_delimiter)

# Repeat to ensure we maximize entity count
@@ -180,7 +180,7 @@ async def _process_document(
name=f"extract-continuation-{i}",
history=response.history,
)
extension = response.output or ""
extension = response.output.content or ""
claims += record_delimiter + extension.strip().removesuffix(
completion_delimiter
)
@@ -195,7 +195,7 @@ async def _process_document(
history=response.history,
model_parameters=self._loop_args,
)
if response.output != "YES":
if response.output.content != "YES":
break

return self._parse_claim_tuples(results, prompt_args)
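
One easy-to-miss detail in both this extractor and the graph extractor below: the `logit_bias` keys are now the token ids rendered as strings rather than raw integers (presumably to match the JSON body the chat API expects; the diff itself only shows the stringification). A standalone sketch of the updated looping arguments:

```python
import tiktoken

encoding = tiktoken.get_encoding("cl100k_base")
yes = f"{encoding.encode('YES')[0]}"  # token id rendered as a string key
no = f"{encoding.encode('NO')[0]}"
loop_args = {"logit_bias": {yes: 100, no: 100}, "max_tokens": 1}
print(loop_args)
```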
(community reports extractor)
@@ -8,26 +8,50 @@
from dataclasses import dataclass
from typing import Any

from fnllm import ChatLLM
from pydantic import BaseModel, Field

from graphrag.index.typing import ErrorHandlerFn
from graphrag.index.utils.dicts import dict_has_keys_with_types
from graphrag.llm import CompletionLLM
from graphrag.prompts.index.community_report import COMMUNITY_REPORT_PROMPT

log = logging.getLogger(__name__)


class FindingModel(BaseModel):
"""A model for the expected LLM response shape."""

summary: str = Field(description="The summary of the finding.")
explanation: str = Field(description="An explanation of the finding.")


class CommunityReportResponse(BaseModel):
"""A model for the expected LLM response shape."""

title: str = Field(description="The title of the report.")
summary: str = Field(description="A summary of the report.")
findings: list[FindingModel] = Field(
description="A list of findings in the report."
)
rating: float = Field(description="The rating of the report.")
rating_explanation: str = Field(description="An explanation of the rating.")

extra_attributes: dict[str, Any] = Field(
default_factory=dict, description="Extra attributes."
)


@dataclass
class CommunityReportsResult:
"""Community reports result class definition."""

output: str
structured_output: dict
structured_output: CommunityReportResponse | None


class CommunityReportsExtractor:
"""Community reports extractor class definition."""

_llm: CompletionLLM
_llm: ChatLLM
_input_text_key: str
_extraction_prompt: str
_output_formatter_prompt: str
@@ -36,7 +60,7 @@ class CommunityReportsExtractor:

def __init__(
self,
llm_invoker: CompletionLLM,
llm_invoker: ChatLLM,
input_text_key: str | None = None,
extraction_prompt: str | None = None,
on_error: ErrorHandlerFn | None = None,
@@ -53,55 +77,30 @@ async def __call__(self, inputs: dict[str, Any]):
"""Call method definition."""
output = None
try:
response = (
await self._llm(
self._extraction_prompt,
json=True,
name="create_community_report",
variables={self._input_text_key: inputs[self._input_text_key]},
is_response_valid=lambda x: dict_has_keys_with_types(
x,
[
("title", str),
("summary", str),
("findings", list),
("rating", float),
("rating_explanation", str),
],
inplace=True,
),
model_parameters={"max_tokens": self._max_report_length},
)
or {}
input_text = inputs[self._input_text_key]
prompt = self._extraction_prompt.replace(
"{" + self._input_text_key + "}", input_text
)
response = await self._llm(
prompt,
json=True,
name="create_community_report",
json_model=CommunityReportResponse,
model_parameters={"max_tokens": self._max_report_length},
)
output = response.json or {}
output = response.parsed_json
except Exception as e:
log.exception("error generating community report")
self._on_error(e, traceback.format_exc(), None)
output = {}

text_output = self._get_text_output(output)
text_output = self._get_text_output(output) if output else ""
return CommunityReportsResult(
structured_output=output,
output=text_output,
)

def _get_text_output(self, parsed_output: dict) -> str:
title = parsed_output.get("title", "Report")
summary = parsed_output.get("summary", "")
findings = parsed_output.get("findings", [])

def finding_summary(finding: dict):
if isinstance(finding, str):
return finding
return finding.get("summary")

def finding_explanation(finding: dict):
if isinstance(finding, str):
return ""
return finding.get("explanation")

def _get_text_output(self, report: CommunityReportResponse) -> str:
report_sections = "\n\n".join(
f"## {finding_summary(f)}\n\n{finding_explanation(f)}" for f in findings
f"## {f.summary}\n\n{f.explanation}" for f in report.findings
)
return f"# {title}\n\n{summary}\n\n{report_sections}"
return f"# {report.title}\n\n{report.summary}\n\n{report_sections}"
21 changes: 10 additions & 11 deletions graphrag/index/graph/extractors/graph/graph_extractor.py
@@ -12,11 +12,11 @@

import networkx as nx
import tiktoken
from fnllm import ChatLLM

import graphrag.config.defaults as defs
from graphrag.index.typing import ErrorHandlerFn
from graphrag.index.utils.string import clean_str
from graphrag.llm import CompletionLLM
from graphrag.prompts.index.entity_extraction import (
CONTINUE_PROMPT,
GRAPH_EXTRACTION_PROMPT,
@@ -40,7 +40,7 @@ class GraphExtractionResult:
class GraphExtractor:
"""Unipartite graph extractor class definition."""

_llm: CompletionLLM
_llm: ChatLLM
_join_descriptions: bool
_tuple_delimiter_key: str
_record_delimiter_key: str
@@ -57,7 +57,7 @@ class GraphExtractor:

def __init__(
self,
llm_invoker: CompletionLLM,
llm_invoker: ChatLLM,
tuple_delimiter_key: str | None = None,
record_delimiter_key: str | None = None,
input_text_key: str | None = None,
@@ -90,9 +90,9 @@ def __init__(

# Construct the looping arguments
encoding = tiktoken.get_encoding(encoding_model or "cl100k_base")
yes = encoding.encode("YES")
no = encoding.encode("NO")
self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}
yes = f"{encoding.encode('YES')[0]}"
no = f"{encoding.encode('NO')[0]}"
self._loop_args = {"logit_bias": {yes: 100, no: 100}, "max_tokens": 1}

async def __call__(
self, texts: list[str], prompt_variables: dict[str, Any] | None = None
@@ -151,13 +151,12 @@ async def _process_document(
self, text: str, prompt_variables: dict[str, str]
) -> str:
response = await self._llm(
self._extraction_prompt,
variables={
self._extraction_prompt.format(**{
**prompt_variables,
self._input_text_key: text,
},
}),
)
results = response.output or ""
results = response.output.content or ""

# Repeat to ensure we maximize entity count
for i in range(self._max_gleanings):
@@ -166,7 +165,7 @@ async def _process_document(
name=f"extract-continuation-{i}",
history=response.history,
)
results += response.output or ""
results += response.output.content or ""

# if this is the final glean, don't bother updating the continuation flag
if i >= self._max_gleanings - 1:
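
The gleanings loop itself is unchanged in spirit; what changed is the fnllm call shape, with the conversation carried via `response.history` and per-call `name=`/`model_parameters=` keywords. Pulled out of class context, it looks roughly like this (the prompt texts and the `glean` helper are illustrative):

```python
from fnllm import ChatLLM

CONTINUE_PROMPT = "Some entities were missed. Add them below using the same format."  # illustrative
LOOP_PROMPT = "Were any entities missed? Answer YES | NO."  # illustrative


async def glean(llm: ChatLLM, first_prompt: str, max_gleanings: int, loop_args: dict) -> str:
    response = await llm(first_prompt)
    results = response.output.content or ""
    for i in range(max_gleanings):
        # Carry the conversation forward so each continuation sees prior turns.
        response = await llm(
            CONTINUE_PROMPT, name=f"extract-continuation-{i}", history=response.history
        )
        results += response.output.content or ""
        if i >= max_gleanings - 1:
            break
        # Ask the model whether to keep going, constrained to YES/NO via logit_bias.
        response = await llm(
            LOOP_PROMPT, history=response.history, model_parameters=loop_args
        )
        if response.output.content != "YES":
            break
    return results
```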
