Prompt Fine Tuning part 1: Indexing prompt generation #254

Merged
12 commits merged on Jun 3, 2024
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20240601003136314078.json
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Indexing prompt fine tuning"
}
3 changes: 3 additions & 0 deletions dictionary.txt
@@ -142,6 +142,9 @@ Tahbaz
payus
dulce
Asadi
+ABILA
+Abila
+POKRALLY

# English
skippable
4 changes: 2 additions & 2 deletions graphrag/config/models/chunking_config.py
@@ -30,7 +30,7 @@ def resolved_strategy(self) -> dict:

return self.strategy or {
"type": ChunkStrategyType.tokens,
"size": self.size,
"overlap": self.overlap,
"chunk_size": self.size,
"chunk_overlap": self.overlap,
"group_by_columns": self.group_by_columns,
}
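
The key rename matters downstream: the token chunker reads chunk_size and chunk_overlap from the resolved strategy. A minimal sketch of the resolved dict after this change (the default values shown are illustrative, not confirmed):

from graphrag.config.models.chunking_config import ChunkingConfig

config = ChunkingConfig()  # library defaults
strategy = config.resolved_strategy()
# Expected shape after the rename, e.g.:
# {
#     "type": ChunkStrategyType.tokens,
#     "chunk_size": 1200,         # illustrative default
#     "chunk_overlap": 100,       # illustrative default
#     "group_by_columns": ["id"],
# }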
3 changes: 3 additions & 0 deletions graphrag/config/models/input_config.py
@@ -37,6 +37,9 @@ class InputConfig(BaseModel):
file_pattern: str = Field(
description="The input file pattern to use.", default=defs.INPUT_CSV_PATTERN
)
+file_filter: dict[str, str] | None = Field(
+    description="The optional file filter for the input files.", default=None
+)
source_column: str | None = Field(
description="The input source column to use.", default=None
)
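
The new file_filter field takes an optional mapping, presumably of input attributes to match values; the exact matching semantics live in the input loader. A minimal sketch of the config shape (the "source" key is a hypothetical example):

from graphrag.config.models.input_config import InputConfig

# Hypothetical filter: only ingest inputs whose "source" attribute matches "news".
config = InputConfig(file_filter={"source": "news"})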
4 changes: 4 additions & 0 deletions graphrag/fine_tune/__init__.py
@@ -0,0 +1,4 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Command line interface for the fine_tune module."""
101 changes: 101 additions & 0 deletions graphrag/fine_tune/__main__.py
@@ -0,0 +1,101 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""The Query Engine package root."""

import argparse
import asyncio
from enum import Enum

from graphrag.fine_tune.generator import MAX_TOKEN_COUNT
from graphrag.fine_tune.loader import MIN_CHUNK_SIZE

from .cli import fine_tune


class DocSelectionType(Enum):
"""The type of document selection to use."""

ALL = "all"
RANDOM = "random"
TOP = "top"

def __str__(self):
"""Return the string representation of the enum value."""
return self.value


if __name__ == "__main__":
parser = argparse.ArgumentParser()

parser.add_argument(
"--root",
help="The data project root.",
required=False,
type=str,
)

parser.add_argument(
"--domain",
help="The domain your input data is related to. For example 'space science', 'microbiology', 'environmental news'.",
required=False,
default="",
type=str,
)

parser.add_argument(
"--method",
help="The method to select documents, one of: all, random or top",
required=True,
type=DocSelectionType,
choices=list(DocSelectionType),
default=DocSelectionType.TOP,
)

parser.add_argument(
"--limit",
help="The limit of files to load when doing random or top selection",
type=int,
required=False,
default=5,
)

parser.add_argument(
"--max_tokens",
help="Max token count for prompt generation",
type=int,
required=False,
default=MAX_TOKEN_COUNT,
)

parser.add_argument(
"--chunk_size",
help="Max token count for prompt generation",
type=int,
required=False,
default=MIN_CHUNK_SIZE,
)

parser.add_argument(
"--output",
help="Folder to save the generated prompts to",
type=str,
required=False,
default="prompts",
)

args = parser.parse_args()

loop = asyncio.get_event_loop()

loop.run_until_complete(
fine_tune(
args.root,
args.domain,
str(args.method),
args.limit,
args.max_tokens,
args.chunk_size,
args.output,
)
)
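
Putting the arguments together, a typical invocation of this entry point might look like the following (the root path and domain are illustrative):

python -m graphrag.fine_tune --root ./ragtest --domain "environmental news" --method top --limit 5 --output prompts

Omitting --domain leaves it empty, which lets the pipeline infer one from the sampled documents (see generate_indexing_prompts in cli.py).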
165 changes: 165 additions & 0 deletions graphrag/fine_tune/cli.py
@@ -0,0 +1,165 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Command line interface for the fine_tune module."""

from pathlib import Path

from datashaper import NoopVerbCallbacks

from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.fine_tune.generator import (
MAX_TOKEN_COUNT,
create_community_summarization_prompt,
create_entity_extraction_prompt,
create_entity_summarization_prompt,
generate_community_reporter_role,
generate_domain,
generate_entity_relationship_examples,
generate_entity_types,
generate_persona,
)
from graphrag.fine_tune.loader import (
MIN_CHUNK_SIZE,
load_docs_in_chunks,
read_config_parameters,
)
from graphrag.index.llm import load_llm
from graphrag.index.progress import PrintProgressReporter
from graphrag.index.progress.types import ProgressReporter
from graphrag.llm.types.llm_types import CompletionLLM


async def fine_tune(
root: str,
domain: str,
select: str = "top",
limit: int = 5,
max_tokens: int = MAX_TOKEN_COUNT,
chunk_size: int = MIN_CHUNK_SIZE,
output: str = "prompts",
):
"""Fine tune the model.

Parameters
----------
- root: The root directory.
- domain: The domain to map the input documents to.
- select: The chunk selection method.
- limit: The limit of chunks to load.
- max_tokens: The maximum number of tokens to use on entity extraction prompts.
- chunk_size: The chunk token size to use.
- output: The output folder to store the prompts.
"""
reporter = PrintProgressReporter("")
config = read_config_parameters(root, reporter)

output_path = Path(config.root_dir) / output

doc_list = await load_docs_in_chunks(
root=root,
config=config,
limit=limit,
select_method=select,
reporter=reporter,
chunk_size=chunk_size,
)

# Create LLM from config
llm = load_llm(
"fine_tuning",
config.llm.type,
NoopVerbCallbacks(),
None,
config.llm.model_dump(),
)

await generate_indexing_prompts(
llm, config, doc_list, output_path, reporter, domain, max_tokens
)


async def generate_indexing_prompts(
llm: CompletionLLM,
config: GraphRagConfig,
doc_list: list[str],
output_path: Path,
reporter: ProgressReporter,
domain: str | None = None,
max_tokens: int = MAX_TOKEN_COUNT,
):
"""Generate indexing prompts.

Parameters
----------
- llm: The LLM model to use.
- config: The GraphRag configuration.
- doc_list: The list of documents to use.
- output_path: The path to store the prompts.
- reporter: The progress reporter.
- domain: The domain to map the input documents to.
- max_tokens: The maximum number of tokens to use on entity extraction prompts
"""
if not domain:
reporter.info("Generating domain...")
domain = await generate_domain(llm, doc_list)
reporter.info(f"Generated domain: {domain}")

reporter.info("Generating persona...")
persona = await generate_persona(llm, domain)
reporter.info(f"Generated persona: {persona}")

reporter.info("Generating entity types")
entity_types = await generate_entity_types(
llm,
domain=domain,
persona=persona,
docs=doc_list,
json_mode=config.llm.model_supports_json or False,
)
reporter.info(f"Generated entity types: {entity_types}")

reporter.info("Generating entity relationship examples...")
examples = await generate_entity_relationship_examples(
llm,
persona=persona,
entity_types=entity_types,
docs=doc_list,
json_mode=False,  # config.llm.model_supports_json should be used here, but these prompts are consumed in non-JSON mode by the index engine
)
reporter.info("Done generating entity relationship examples")

reporter.info("Generating entity extraction prompt...")
create_entity_extraction_prompt(
entity_types=entity_types,
docs=doc_list,
examples=examples,
json_mode=False,  # config.llm.model_supports_json should be used here, but these prompts are consumed in non-JSON mode by the index engine
model_name=config.llm.model,
output_path=output_path,
max_token_count=max_tokens,
)
reporter.info(f"Generated entity extraction prompt, stored in folder {output_path}")

reporter.info("Generating entity summarization prompt...")
create_entity_summarization_prompt(
persona=persona,
output_path=output_path,
)
reporter.info(
f"Generated entity summarization prompt, stored in folder {output_path}"
)

reporter.info("Generating community reporter role...")
community_reporter_role = await generate_community_reporter_role(
llm, domain=domain, persona=persona, docs=doc_list
)
reporter.info(f"Generated community reporter role: {community_reporter_role}")

reporter.info("Generating community summarization prompt...")
create_community_summarization_prompt(
persona=persona, role=community_reporter_role, output_path=output_path
)
reporter.info(
f"Generated community summarization prompt, stored in folder {output_path}"
)
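
Since fine_tune is a plain coroutine, it can also be driven programmatically; a minimal sketch under the signature above (the root path is illustrative):

import asyncio

from graphrag.fine_tune.cli import fine_tune

asyncio.run(
    fine_tune(
        root="./ragtest",  # illustrative project root
        domain="",  # empty string triggers domain generation
        select="top",
        limit=5,
        output="prompts",
    )
)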
26 changes: 26 additions & 0 deletions graphrag/fine_tune/generator/__init__.py
@@ -0,0 +1,26 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Prompt generation module."""

from .community_report_summarization import create_community_summarization_prompt
from .community_reporter_role import generate_community_reporter_role
from .defaults import MAX_TOKEN_COUNT
from .domain import generate_domain
from .entity_extraction_prompt import create_entity_extraction_prompt
from .entity_relationship import generate_entity_relationship_examples
from .entity_summarization_prompt import create_entity_summarization_prompt
from .entity_types import generate_entity_types
from .persona import generate_persona

__all__ = [
"MAX_TOKEN_COUNT",
"create_community_summarization_prompt",
"create_entity_extraction_prompt",
"create_entity_summarization_prompt",
"generate_community_reporter_role",
"generate_domain",
"generate_entity_relationship_examples",
"generate_entity_types",
"generate_persona",
]
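
These exports compose in the order cli.py calls them; a minimal sketch of the first few steps, assuming an llm (CompletionLLM) and a doc_list already loaded via load_docs_in_chunks:

from graphrag.fine_tune.generator import (
    generate_domain,
    generate_entity_types,
    generate_persona,
)


async def first_steps(llm, doc_list: list[str]):
    # Mirrors the opening of generate_indexing_prompts in cli.py.
    domain = await generate_domain(llm, doc_list)
    persona = await generate_persona(llm, domain)
    entity_types = await generate_entity_types(
        llm, domain=domain, persona=persona, docs=doc_list, json_mode=False
    )
    return domain, persona, entity_types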
40 changes: 40 additions & 0 deletions graphrag/fine_tune/generator/community_report_summarization.py
@@ -0,0 +1,40 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Module for generating prompts for community report summarization."""

from pathlib import Path

from graphrag.fine_tune.template import COMMUNITY_REPORT_SUMMARIZATION_PROMPT

COMMUNITY_SUMMARIZATION_FILENAME = "community_report.txt"


def create_community_summarization_prompt(
persona: str,
role: str,
output_path: Path | None = None,
) -> str:
"""Create a prompt for community summarization. If output_path is provided, write the prompt to a file.

Parameters
----------
- persona (str): The persona to use for the community summarization prompt
- role (str): The role to use for the community summarization prompt
- output_path (Path | None): The path to write the prompt to. If None, the prompt is not written to a file. Default is None.

Returns
-------
- str: The community summarization prompt
"""
prompt = COMMUNITY_REPORT_SUMMARIZATION_PROMPT.format(persona=persona, role=role)

if output_path:
output_path.mkdir(parents=True, exist_ok=True)

output_path = output_path / COMMUNITY_SUMMARIZATION_FILENAME
# Write file to output path
with output_path.open("w") as file:
file.write(prompt)

return prompt
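
A usage sketch; the persona and role strings here are illustrative, while in the pipeline they come from generate_persona and generate_community_reporter_role:

from pathlib import Path

from graphrag.fine_tune.generator.community_report_summarization import (
    create_community_summarization_prompt,
)

prompt = create_community_summarization_prompt(
    persona="You are an analyst specializing in environmental news.",
    role="A community reporter summarizing entities and their relationships.",
    output_path=Path("prompts"),  # also writes prompts/community_report.txt
)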