Skip to content

Commit

Permalink
Add formatting and ruff fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
AlonsoGuevara committed Jun 1, 2024
1 parent ea179ce commit 973a028
Show file tree
Hide file tree
Showing 27 changed files with 294 additions and 153 deletions.
2 changes: 2 additions & 0 deletions graphrag/fine_tune/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Command line interface for the fine_tune module."""
113 changes: 80 additions & 33 deletions graphrag/fine_tune/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,27 @@

from datashaper import NoopVerbCallbacks

from graphrag.fine_tune.generator import generate_entity_relationship_examples
from graphrag.fine_tune.loader import read_config_parameters


from graphrag.fine_tune.loader import MIN_CHUNK_SIZE, load_docs_in_chunks
from graphrag.index.llm import load_llm


from graphrag.index.progress import PrintProgressReporter
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.fine_tune.generator import (
generate_persona,
generate_domain,
MAX_TOKEN_COUNT,
create_community_summarization_prompt,
create_entity_extraction_prompt,
generate_entity_types,
create_entity_summarization_prompt,
generate_community_reporter_role,
create_community_summarization_prompt,
MAX_TOKEN_COUNT,
generate_domain,
generate_entity_relationship_examples,
generate_entity_types,
generate_persona,
)

reporter = PrintProgressReporter("")
from graphrag.fine_tune.loader import (
MIN_CHUNK_SIZE,
load_docs_in_chunks,
read_config_parameters,
)
from graphrag.index.llm import load_llm
from graphrag.index.progress import PrintProgressReporter
from graphrag.index.progress.types import ProgressReporter
from graphrag.llm.types.llm_types import CompletionLLM


async def fine_tune(
Expand All @@ -38,9 +38,20 @@ async def fine_tune(
max_tokens: int = MAX_TOKEN_COUNT,
chunk_size: int = MIN_CHUNK_SIZE,
output: str = "prompts",
**kwargs,
):
"""Fine tune the model."""
"""Fine tune the model.
Parameters
----------
- root: The root directory.
- domain: The domain to map the input documents to.
- select: The chunk selection method.
- limit: The limit of chunks to load.
- max_tokens: The maximum number of tokens to use on entity extraction prompts.
- chunk_size: The chunk token size to use.
- output: The output folder to store the prompts.
"""
reporter = PrintProgressReporter("")
config = read_config_parameters(root, reporter)

output_path = Path(config.root_dir) / output
Expand All @@ -63,56 +74,92 @@ async def fine_tune(
config.llm.model_dump(),
)

await generate_indexing_prompts(
llm, config, doc_list, output_path, reporter, domain, max_tokens
)


async def generate_indexing_prompts(
llm: CompletionLLM,
config: GraphRagConfig,
doc_list: list[str],
output_path: Path,
reporter: ProgressReporter,
domain: str | None = None,
max_tokens: int = MAX_TOKEN_COUNT,
):
"""Generate indexing prompts.
Parameters
----------
- llm: The LLM model to use.
- config: The GraphRag configuration.
- doc_list: The list of documents to use.
- output_path: The path to store the prompts.
- reporter: The progress reporter.
- domain: The domain to map the input documents to.
- max_tokens: The maximum number of tokens to use on entity extraction prompts.
"""
if not domain:
reporter.info("Generating domain...")
domain = await generate_domain(llm, doc_list)
print(domain)
reporter.info(f"Generated domain: {domain}")

reporter.info("Generating persona...")
persona = await generate_persona(llm, domain)
print(persona)
reporter.info(f"Generated persona: {persona}")

reporter.info("Generating entity types")
entity_types = await generate_entity_types(
llm,
domain=domain,
persona=persona,
docs=doc_list,
json_mode=config.llm.model_supports_json or False,
)
print(entity_types)
reporter.info(f"Generated entity types: {entity_types}")

reporter.info("Generating entity relationship examples...")
examples = await generate_entity_relationship_examples(
llm,
persona=persona,
entity_types=entity_types,
docs=doc_list,
json_mode=config.llm.model_supports_json or False,
json_mode=False,  # config.llm.model_supports_json should be used, but these prompts are used in non-JSON mode by the index engine
)
print(examples)
reporter.info("Done generating entity relationship examples")

prompt = create_entity_extraction_prompt(
reporter.info("Generating entity extraction prompt...")
create_entity_extraction_prompt(
entity_types=entity_types,
docs=doc_list,
examples=examples,
json_mode=config.llm.model_supports_json or False,
json_mode=False,  # config.llm.model_supports_json should be used, but these prompts are used in non-JSON mode by the index engine
model_name=config.llm.model,
output_path=output_path,
max_token_count=max_tokens,
)
reporter.info(f"Generated entity extraction prompt, stored in folder {output_path}")

print(prompt)

prompt = create_entity_summarization_prompt(
reporter.info("Generating entity summarization prompt...")
create_entity_summarization_prompt(
persona=persona,
output_path=output_path,
)
reporter.info(
f"Generated entity summarization prompt, stored in folder {output_path}"
)

reporter.info("Generating community reporter role...")
community_reporter_role = await generate_community_reporter_role(
llm, domain=domain, persona=persona, docs=doc_list
)
reporter.info(f"Generated community reporter role: {community_reporter_role}")

print(community_reporter_role)

prompt = create_community_summarization_prompt(
reporter.info("Generating community summarization prompt...")
create_community_summarization_prompt(
persona=persona, role=community_reporter_role, output_path=output_path
)

print(prompt)
reporter.info(
f"Generated community summarization prompt, stored in folder {output_path}"
)
26 changes: 14 additions & 12 deletions graphrag/fine_tune/generator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,26 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

from .entity_relationship import generate_entity_relationship_examples
from .entity_types import generate_entity_types
from .persona import generate_persona
"""Prompt generation module."""

from .community_report_summarization import create_community_summarization_prompt
from .community_reporter_role import generate_community_reporter_role
from .defaults import MAX_TOKEN_COUNT
from .domain import generate_domain
from .entity_extraction_prompt import create_entity_extraction_prompt
from .defaults import MAX_TOKEN_COUNT
from .entity_relationship import generate_entity_relationship_examples
from .entity_summarization_prompt import create_entity_summarization_prompt
from .community_reporter_role import generate_community_reporter_role
from .community_report_summarization import create_community_summarization_prompt

from .entity_types import generate_entity_types
from .persona import generate_persona

__all__ = [
"generate_entity_relationship_examples",
"generate_entity_types",
"generate_persona",
"generate_domain",
"MAX_TOKEN_COUNT",
"create_community_summarization_prompt",
"create_entity_extraction_prompt",
"create_entity_summarization_prompt",
"generate_community_reporter_role",
"MAX_TOKEN_COUNT",
"generate_domain",
"generate_entity_relationship_examples",
"generate_entity_types",
"generate_persona",
]
19 changes: 16 additions & 3 deletions graphrag/fine_tune/generator/community_report_summarization.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,32 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Module for generating prompts for community report summarization."""

from pathlib import Path
from graphrag.fine_tune.template import COMMUNITY_REPORT_SUMMARIZATION_PROMPT

from graphrag.fine_tune.template import COMMUNITY_REPORT_SUMMARIZATION_PROMPT

COMMUNITY_SUMMARIZATION_FILENAME = "community_report_summarization_prompt.txt"
COMMUNITY_SUMMARIZATION_FILENAME = "community_report.txt"


def create_community_summarization_prompt(
persona: str,
role: str,
output_path: Path | None = None,
) -> str:

"""Create a prompt for community summarization. If output_path is provided, write the prompt to a file.
Parameters
----------
- persona (str): The persona to use for the community summarization prompt
- role (str): The role to use for the community summarization prompt
- output_path (Path | None): The path to write the prompt to. If None, the prompt is not written to a file. Default is None.
Returns
-------
- str: The community summarization prompt
"""
prompt = COMMUNITY_REPORT_SUMMARIZATION_PROMPT.format(persona=persona, role=role)

if output_path:
Expand Down
16 changes: 14 additions & 2 deletions graphrag/fine_tune/generator/community_reporter_role.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Generate a community reporter role for community summarization."""

from graphrag.fine_tune.prompt.community_reporter_role import (
from graphrag.fine_tune.prompt import (
GENERATE_COMMUNITY_REPORTER_ROLE_PROMPT,
)
from graphrag.llm.types.llm_types import CompletionLLM
Expand All @@ -11,8 +12,19 @@
async def generate_community_reporter_role(
llm: CompletionLLM, domain: str, persona: str, docs: str | list[str]
) -> str:
"""Provided a community reporter role, generate an LLM persona to use for GraphRAG prompts"""
"""Generate an LLM persona to use for GraphRAG prompts.
Parameters
----------
- llm (CompletionLLM): The LLM to use for generation
- domain (str): The domain to generate a persona for
- persona (str): The persona to generate a role for
- docs (str | list[str]): The input documents to ground the role generation
Returns
-------
- str: The generated community reporter role.
"""
docs_str = " ".join(docs) if isinstance(docs, list) else docs
domain_prompt = GENERATE_COMMUNITY_REPORTER_ROLE_PROMPT.format(
domain=domain, persona=persona, input_text=docs_str
Expand Down
2 changes: 2 additions & 0 deletions graphrag/fine_tune/generator/defaults.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Default values for the fine-tuning module."""

DEFAULT_TASK = """
Identify the relations and structure of the community of interest, specifically within the {domain} domain.
"""
Expand Down
12 changes: 11 additions & 1 deletion graphrag/fine_tune/generator/domain.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,24 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Domain generation for GraphRAG prompts."""

from graphrag.fine_tune.prompt.domain import GENERATE_DOMAIN_PROMPT
from graphrag.llm.types.llm_types import CompletionLLM


async def generate_domain(llm: CompletionLLM, docs: str | list[str]) -> str:
"""Provided a domain and a task, generate an LLM persona to use for GraphRAG prompts"""
"""Generate an LLM persona to use for GraphRAG prompts.
Parameters
----------
- llm (CompletionLLM): The LLM to use for generation
- docs (str | list[str]): The input documents to infer the domain from
Returns
-------
- str: The generated domain prompt response.
"""
docs_str = " ".join(docs) if isinstance(docs, list) else docs
domain_prompt = GENERATE_DOMAIN_PROMPT.format(input_text=docs_str)

Expand Down
31 changes: 23 additions & 8 deletions graphrag/fine_tune/generator/entity_extraction_prompt.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Entity Extraction prompt generator module."""

from pathlib import Path

from graphrag.fine_tune.template import (
GRAPH_EXTRACTION_PROMPT,
GRAPH_EXTRACTION_JSON_PROMPT,
EXAMPLE_EXTRACTION_TEMPLATE,
GRAPH_EXTRACTION_JSON_PROMPT,
GRAPH_EXTRACTION_PROMPT,
)

from graphrag.index.utils.tokens import num_tokens_from_string

ENTITY_EXTRACTION_FILENAME = "entity_extraction_prompt.txt"
ENTITY_EXTRACTION_FILENAME = "entity_extraction.txt"


def create_entity_extraction_prompt(
Expand All @@ -22,7 +24,23 @@ def create_entity_extraction_prompt(
json_mode: bool = False,
output_path: Path | None = None,
) -> str:

"""
Create a prompt for entity extraction.
Parameters
----------
- entity_types (str | list[str]): The entity types to extract
- docs (list[str]): The list of documents to extract entities from
- examples (list[str]): The list of examples to use for entity extraction
- model_name (str): The name of the model to use for token counting
- max_token_count (int): The maximum number of tokens to use for the prompt
- json_mode (bool): Whether to use JSON mode for the prompt. Default is False
- output_path (Path | None): The path to write the prompt to. If None, the prompt is not written to a file. Default is None.
Returns
-------
- str: The entity extraction prompt
"""
prompt = GRAPH_EXTRACTION_JSON_PROMPT if json_mode else GRAPH_EXTRACTION_PROMPT
if isinstance(entity_types, list):
entity_types = ", ".join(entity_types)
Expand All @@ -42,9 +60,6 @@ def create_entity_extraction_prompt(
n=i + 1, input_text=input, entity_types=entity_types, output=output
)

print(
f"Input tokens {num_tokens_from_string(input, model_name)}, output tokens {num_tokens_from_string(output, model_name)}"
)
example_tokens = num_tokens_from_string(example_formatted, model=model_name)

# Squeeze in at least one example
Expand Down
Loading

0 comments on commit 973a028

Please sign in to comment.