diff --git a/graphrag/fine_tune/__init__.py b/graphrag/fine_tune/__init__.py index 0a3e38adfb..2384b5793c 100644 --- a/graphrag/fine_tune/__init__.py +++ b/graphrag/fine_tune/__init__.py @@ -1,2 +1,4 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License + +"""Command line interface for the fine_tune module.""" diff --git a/graphrag/fine_tune/cli.py b/graphrag/fine_tune/cli.py index 046f73597e..2ba6686ec1 100644 --- a/graphrag/fine_tune/cli.py +++ b/graphrag/fine_tune/cli.py @@ -7,27 +7,27 @@ from datashaper import NoopVerbCallbacks -from graphrag.fine_tune.generator import generate_entity_relationship_examples -from graphrag.fine_tune.loader import read_config_parameters - - -from graphrag.fine_tune.loader import MIN_CHUNK_SIZE, load_docs_in_chunks -from graphrag.index.llm import load_llm - - -from graphrag.index.progress import PrintProgressReporter +from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.fine_tune.generator import ( - generate_persona, - generate_domain, + MAX_TOKEN_COUNT, + create_community_summarization_prompt, create_entity_extraction_prompt, - generate_entity_types, create_entity_summarization_prompt, generate_community_reporter_role, - create_community_summarization_prompt, - MAX_TOKEN_COUNT, + generate_domain, + generate_entity_relationship_examples, + generate_entity_types, + generate_persona, ) - -reporter = PrintProgressReporter("") +from graphrag.fine_tune.loader import ( + MIN_CHUNK_SIZE, + load_docs_in_chunks, + read_config_parameters, +) +from graphrag.index.llm import load_llm +from graphrag.index.progress import PrintProgressReporter +from graphrag.index.progress.types import ProgressReporter +from graphrag.llm.types.llm_types import CompletionLLM async def fine_tune( @@ -38,9 +38,20 @@ async def fine_tune( max_tokens: int = MAX_TOKEN_COUNT, chunk_size: int = MIN_CHUNK_SIZE, output: str = "prompts", - **kwargs, ): - """Fine tune the model.""" + """Fine tune the model. 
+ + Parameters + ---------- + - root: The root directory. + - domain: The domain to map the input documents to. + - select: The chunk selection method. + - limit: The limit of chunks to load. + - max_tokens: The maximum number of tokens to use on entity extraction prompts. + - chunk_size: The chunk token size to use. + - output: The output folder to store the prompts. + """ + reporter = PrintProgressReporter("") config = read_config_parameters(root, reporter) output_path = Path(config.root_dir) / output @@ -63,13 +74,42 @@ async def fine_tune( config.llm.model_dump(), ) + await generate_indexing_prompts( + llm, config, doc_list, output_path, reporter, domain, max_tokens + ) + + +async def generate_indexing_prompts( + llm: CompletionLLM, + config: GraphRagConfig, + doc_list: list[str], + output_path: Path, + reporter: ProgressReporter, + domain: str | None = None, + max_tokens: int = MAX_TOKEN_COUNT, +): + """Generate indexing prompts. + + Parameters + ---------- + - llm: The LLM model to use. + - config: The GraphRag configuration. + - doc_list: The list of documents to use. + - output_path: The path to store the prompts. + - reporter: The progress reporter. + - domain: The domain to map the input documents to. 
+ - max_tokens: The maximum number of tokens to use on entity extraction prompts + """ if not domain: + reporter.info("Generating domain...") domain = await generate_domain(llm, doc_list) - print(domain) + reporter.info(f"Generated domain: {domain}") + reporter.info("Generating persona...") persona = await generate_persona(llm, domain) - print(persona) + reporter.info(f"Generated persona: {persona}") + reporter.info("Generating entity types") entity_types = await generate_entity_types( llm, domain=domain, @@ -77,42 +117,49 @@ async def fine_tune( docs=doc_list, json_mode=config.llm.model_supports_json or False, ) - print(entity_types) + reporter.info(f"Generated entity types: {entity_types}") + reporter.info("Generating entity relationship examples...") examples = await generate_entity_relationship_examples( llm, persona=persona, entity_types=entity_types, docs=doc_list, - json_mode=config.llm.model_supports_json or False, + json_mode=False, # config.llm.model_supports_json should be used, but these prompts are used in non-JSON mode by the index engine ) - print(examples) + reporter.info("Done generating entity relationship examples") - prompt = create_entity_extraction_prompt( + reporter.info("Generating entity extraction prompt...") + create_entity_extraction_prompt( entity_types=entity_types, docs=doc_list, examples=examples, - json_mode=config.llm.model_supports_json or False, + json_mode=False, # config.llm.model_supports_json should be used, but these prompts are used in non-JSON mode by the index engine model_name=config.llm.model, output_path=output_path, max_token_count=max_tokens, ) + reporter.info(f"Generated entity extraction prompt, stored in folder {output_path}") - print(prompt) - - prompt = create_entity_summarization_prompt( + reporter.info("Generating entity summarization prompt...") + create_entity_summarization_prompt( persona=persona, output_path=output_path, ) + reporter.info( + f"Generated entity summarization prompt, stored in folder {output_path}" + ) 
+ reporter.info("Generating community reporter role...") community_reporter_role = await generate_community_reporter_role( llm, domain=domain, persona=persona, docs=doc_list ) + reporter.info(f"Generated community reporter role: {community_reporter_role}") - print(community_reporter_role) - - prompt = create_community_summarization_prompt( + reporter.info("Generating community summarization prompt...") + create_community_summarization_prompt( persona=persona, role=community_reporter_role, output_path=output_path ) - - print(prompt) + reporter.info( + f"Generated community summarization prompt, stored in folder {output_path}" + ) diff --git a/graphrag/fine_tune/generator/__init__.py b/graphrag/fine_tune/generator/__init__.py index 13cb9e1c6d..e93277b47c 100644 --- a/graphrag/fine_tune/generator/__init__.py +++ b/graphrag/fine_tune/generator/__init__.py @@ -1,24 +1,26 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -from .entity_relationship import generate_entity_relationship_examples -from .entity_types import generate_entity_types -from .persona import generate_persona +"""Prompt generation module.""" + +from .community_report_summarization import create_community_summarization_prompt +from .community_reporter_role import generate_community_reporter_role +from .defaults import MAX_TOKEN_COUNT from .domain import generate_domain from .entity_extraction_prompt import create_entity_extraction_prompt -from .defaults import MAX_TOKEN_COUNT +from .entity_relationship import generate_entity_relationship_examples from .entity_summarization_prompt import create_entity_summarization_prompt -from .community_reporter_role import generate_community_reporter_role -from .community_report_summarization import create_community_summarization_prompt - +from .entity_types import generate_entity_types +from .persona import generate_persona __all__ = [ - "generate_entity_relationship_examples", - "generate_entity_types", - "generate_persona", - 
"generate_domain", + "MAX_TOKEN_COUNT", + "create_community_summarization_prompt", "create_entity_extraction_prompt", "create_entity_summarization_prompt", "generate_community_reporter_role", - "MAX_TOKEN_COUNT", + "generate_domain", + "generate_entity_relationship_examples", + "generate_entity_types", + "generate_persona", ] diff --git a/graphrag/fine_tune/generator/community_report_summarization.py b/graphrag/fine_tune/generator/community_report_summarization.py index 94c5aa1d70..0fe39adf60 100644 --- a/graphrag/fine_tune/generator/community_report_summarization.py +++ b/graphrag/fine_tune/generator/community_report_summarization.py @@ -1,11 +1,13 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License +"""Module for generating prompts for community report summarization.""" + from pathlib import Path -from graphrag.fine_tune.template import COMMUNITY_REPORT_SUMMARIZATION_PROMPT +from graphrag.fine_tune.template import COMMUNITY_REPORT_SUMMARIZATION_PROMPT -COMMUNITY_SUMMARIZATION_FILENAME = "community_report_summarization_prompt.txt" +COMMUNITY_SUMMARIZATION_FILENAME = "community_report.txt" def create_community_summarization_prompt( @@ -13,7 +15,18 @@ def create_community_summarization_prompt( role: str, output_path: Path | None = None, ) -> str: - + """Create a prompt for community summarization. If output_path is provided, write the prompt to a file. + + Parameters + ---------- + - persona (str): The persona to use for the community summarization prompt + - role (str): The role to use for the community summarization prompt + - output_path (Path | None): The path to write the prompt to. Default is None. If None, the prompt is not written to a file. Default is None. 
+ + Returns + ------- + - str: The community summarization prompt + """ prompt = COMMUNITY_REPORT_SUMMARIZATION_PROMPT.format(persona=persona, role=role) if output_path: diff --git a/graphrag/fine_tune/generator/community_reporter_role.py b/graphrag/fine_tune/generator/community_reporter_role.py index b7faf09300..264d3f1b18 100644 --- a/graphrag/fine_tune/generator/community_reporter_role.py +++ b/graphrag/fine_tune/generator/community_reporter_role.py @@ -1,8 +1,9 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License +"""Generate a community reporter role for community summarization.""" -from graphrag.fine_tune.prompt.community_reporter_role import ( +from graphrag.fine_tune.prompt import ( GENERATE_COMMUNITY_REPORTER_ROLE_PROMPT, ) from graphrag.llm.types.llm_types import CompletionLLM @@ -11,8 +12,19 @@ async def generate_community_reporter_role( llm: CompletionLLM, domain: str, persona: str, docs: str | list[str] ) -> str: - """Provided a community reporter role, generate an LLM persona to use for GraphRAG prompts""" + """Generate an LLM persona to use for GraphRAG prompts. + Parameters + ---------- + - llm (CompletionLLM): The LLM to use for generation + - domain (str): The domain to generate a persona for + - persona (str): The persona to generate a role for + - docs (str | list[str]): The input documents to ground the role generation + + Returns + ------- + - str: The generated community reporter role. + """ docs_str = " ".join(docs) if isinstance(docs, list) else docs domain_prompt = GENERATE_COMMUNITY_REPORTER_ROLE_PROMPT.format( domain=domain, persona=persona, input_text=docs_str diff --git a/graphrag/fine_tune/generator/defaults.py b/graphrag/fine_tune/generator/defaults.py index ff699332ce..5b42f81332 100644 --- a/graphrag/fine_tune/generator/defaults.py +++ b/graphrag/fine_tune/generator/defaults.py @@ -1,6 +1,8 @@ # Copyright (c) 2024 Microsoft Corporation. 
# Licensed under the MIT License +"""Default values for the fine-tuning module.""" + DEFAULT_TASK = """ Identify the relations and structure of the community of interest, specifically within the {domain} domain. """ diff --git a/graphrag/fine_tune/generator/domain.py b/graphrag/fine_tune/generator/domain.py index aa3323b35c..f655be8a23 100644 --- a/graphrag/fine_tune/generator/domain.py +++ b/graphrag/fine_tune/generator/domain.py @@ -1,14 +1,24 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License +"""Domain generation for GraphRAG prompts.""" from graphrag.fine_tune.prompt.domain import GENERATE_DOMAIN_PROMPT from graphrag.llm.types.llm_types import CompletionLLM async def generate_domain(llm: CompletionLLM, docs: str | list[str]) -> str: - """Provided a domain and a task, generate an LLM persona to use for GraphRAG prompts""" + """Generate an LLM persona to use for GraphRAG prompts. + Parameters + ---------- + - llm (CompletionLLM): The LLM to use for generation + - docs (str | list[str]): The domain to generate a persona for + + Returns + ------- + - str: The generated domain prompt response. + """ docs_str = " ".join(docs) if isinstance(docs, list) else docs domain_prompt = GENERATE_DOMAIN_PROMPT.format(input_text=docs_str) diff --git a/graphrag/fine_tune/generator/entity_extraction_prompt.py b/graphrag/fine_tune/generator/entity_extraction_prompt.py index ed10f28c37..807c63e4a7 100644 --- a/graphrag/fine_tune/generator/entity_extraction_prompt.py +++ b/graphrag/fine_tune/generator/entity_extraction_prompt.py @@ -1,16 +1,18 @@ # Copyright (c) 2024 Microsoft Corporation. 
# Licensed under the MIT License +"""Entity Extraction prompt generator module.""" + from pathlib import Path + from graphrag.fine_tune.template import ( - GRAPH_EXTRACTION_PROMPT, - GRAPH_EXTRACTION_JSON_PROMPT, EXAMPLE_EXTRACTION_TEMPLATE, + GRAPH_EXTRACTION_JSON_PROMPT, + GRAPH_EXTRACTION_PROMPT, ) - from graphrag.index.utils.tokens import num_tokens_from_string -ENTITY_EXTRACTION_FILENAME = "entity_extraction_prompt.txt" +ENTITY_EXTRACTION_FILENAME = "entity_extraction.txt" def create_entity_extraction_prompt( @@ -22,7 +24,23 @@ def create_entity_extraction_prompt( json_mode: bool = False, output_path: Path | None = None, ) -> str: - + """ + Create a prompt for entity extraction. + + Parameters + ---------- + - entity_types (str | list[str]): The entity types to extract + - docs (list[str]): The list of documents to extract entities from + - examples (list[str]): The list of examples to use for entity extraction + - model_name (str): The name of the model to use for token counting + - max_token_count (int): The maximum number of tokens to use for the prompt + - json_mode (bool): Whether to use JSON mode for the prompt. Default is False + - output_path (Path | None): The path to write the prompt to. Default is None. If None, the prompt is not written to a file. Default is None. 
+ + Returns + ------- + - str: The entity extraction prompt + """ prompt = GRAPH_EXTRACTION_JSON_PROMPT if json_mode else GRAPH_EXTRACTION_PROMPT if isinstance(entity_types, list): entity_types = ", ".join(entity_types) @@ -42,9 +60,6 @@ def create_entity_extraction_prompt( n=i + 1, input_text=input, entity_types=entity_types, output=output ) - print( - f"Input tokens {num_tokens_from_string(input, model_name)}, output tokens {num_tokens_from_string(output, model_name)}" - ) example_tokens = num_tokens_from_string(example_formatted, model=model_name) # Squeeze in at least one example diff --git a/graphrag/fine_tune/generator/entity_relationship.py b/graphrag/fine_tune/generator/entity_relationship.py index fdd8a96d58..48e0c4b63c 100644 --- a/graphrag/fine_tune/generator/entity_relationship.py +++ b/graphrag/fine_tune/generator/entity_relationship.py @@ -1,11 +1,14 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License +"""Entity relationship example generation module.""" + import asyncio import json + from graphrag.fine_tune.prompt import ( - ENTITY_RELATIONSHIPS_GENERATION_PROMPT, ENTITY_RELATIONSHIPS_GENERATION_JSON_PROMPT, + ENTITY_RELATIONSHIPS_GENERATION_PROMPT, ) from graphrag.llm.types.llm_types import CompletionLLM @@ -17,8 +20,8 @@ async def generate_entity_relationship_examples( docs: str | list[str], json_mode: bool = False, ) -> list[str]: - """ - Generates a list of entity/relationships examples for use in generating an entity configuration. + """Generate a list of entity/relationships examples for use in generating an entity configuration. + Will return entity/relationships examples as either JSON or in tuple_delimiter format depending on the json_mode parameter. 
""" @@ -30,23 +33,20 @@ async def generate_entity_relationship_examples( history = [{"role": "system", "content": persona}] - messages = [] - for doc in docs_list: - messages.append( - ( - ENTITY_RELATIONSHIPS_GENERATION_JSON_PROMPT - if json_mode - else ENTITY_RELATIONSHIPS_GENERATION_PROMPT - ).format(entity_types=entity_types_str, input_text=doc) - ) + messages = [ + ( + ENTITY_RELATIONSHIPS_GENERATION_JSON_PROMPT + if json_mode + else ENTITY_RELATIONSHIPS_GENERATION_PROMPT + ).format(entity_types=entity_types_str, input_text=doc) + for doc in docs_list + ] tasks = [llm(message, history=history, json=json_mode) for message in messages] responses = await asyncio.gather(*tasks) - examples = [ - json.dumps((response.json or "")) if json_mode else str(response.output) + return [ + json.dumps(response.json or "") if json_mode else str(response.output) for response in responses ] - - return examples diff --git a/graphrag/fine_tune/generator/entity_summarization_prompt.py b/graphrag/fine_tune/generator/entity_summarization_prompt.py index afa609777b..3d69c04627 100644 --- a/graphrag/fine_tune/generator/entity_summarization_prompt.py +++ b/graphrag/fine_tune/generator/entity_summarization_prompt.py @@ -1,19 +1,26 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License +"""Entity summarization prompt generation module.""" + from pathlib import Path -from graphrag.fine_tune.template import ENTITY_SUMMARIZATION_PROMPT -from graphrag.index.utils.tokens import num_tokens_from_string +from graphrag.fine_tune.template import ENTITY_SUMMARIZATION_PROMPT -ENTITY_SUMMARIZATION_FILENAME = "entity_summarization_prompt.txt" +ENTITY_SUMMARIZATION_FILENAME = "summarize_descriptions.txt" def create_entity_summarization_prompt( persona: str, output_path: Path | None = None, ) -> str: + """Create a prompt for entity summarization. If output_path is provided, write the prompt to a file. 
+ Parameters + ---------- + - persona (str): The persona to use for the entity summarization prompt + - output_path (Path | None): The path to write the prompt to. Default is None. If None, the prompt is not written to a file. Default is None. + """ prompt = ENTITY_SUMMARIZATION_PROMPT.format(persona=persona) if output_path: diff --git a/graphrag/fine_tune/generator/entity_types.py b/graphrag/fine_tune/generator/entity_types.py index 322f410f3e..0ae709d5ef 100644 --- a/graphrag/fine_tune/generator/entity_types.py +++ b/graphrag/fine_tune/generator/entity_types.py @@ -1,14 +1,14 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License +"""Entity type generation module for fine-tuning.""" -import json -from graphrag.llm.types.llm_types import CompletionLLM from graphrag.fine_tune.generator.defaults import DEFAULT_TASK from graphrag.fine_tune.prompt.entity_types import ( ENTITY_TYPE_GENERATION_JSON_PROMPT, ENTITY_TYPE_GENERATION_PROMPT, ) +from graphrag.llm.types.llm_types import CompletionLLM async def generate_entity_types( @@ -20,7 +20,7 @@ async def generate_entity_types( json_mode: bool = False, ) -> str | list[str]: """ - Generates entity type categories from a given set of (small) documents. + Generate entity type categories from a given set of (small) documents. Example Output: "entity_types": ['military unit', 'organization', 'person', 'location', 'event', 'date', 'equipment'] @@ -41,5 +41,5 @@ async def generate_entity_types( if json_mode: return (response.json or {}).get("entity_types", []) - else: - return str(response.output) + + return str(response.output) diff --git a/graphrag/fine_tune/generator/persona.py b/graphrag/fine_tune/generator/persona.py index ac4295bbc4..b3d189579f 100644 --- a/graphrag/fine_tune/generator/persona.py +++ b/graphrag/fine_tune/generator/persona.py @@ -1,16 +1,24 @@ # Copyright (c) 2024 Microsoft Corporation. 
# Licensed under the MIT License -from graphrag.llm.types.llm_types import CompletionLLM +"""Persona generation module for fine-tuning GraphRAG prompts.""" + from graphrag.fine_tune.generator.defaults import DEFAULT_TASK from graphrag.fine_tune.prompt import GENERATE_PERSONA_PROMPT +from graphrag.llm.types.llm_types import CompletionLLM async def generate_persona( llm: CompletionLLM, domain: str, task: str = DEFAULT_TASK ) -> str: - """Provided a domain and a task, generate an LLM persona to use for GraphRAG prompts""" + """Generate an LLM persona to use for GraphRAG prompts. + Parameters + ---------- + - llm (CompletionLLM): The LLM to use for generation + - domain (str): The domain to generate a persona for + - task (str): The task to generate a persona for. Default is DEFAULT_TASK + """ formatted_task = task.format(domain=domain) persona_prompt = GENERATE_PERSONA_PROMPT.format(sample_task=formatted_task) diff --git a/graphrag/fine_tune/loader/__init__.py b/graphrag/fine_tune/loader/__init__.py index e6ca4e18bf..6e431ff347 100644 --- a/graphrag/fine_tune/loader/__init__.py +++ b/graphrag/fine_tune/loader/__init__.py @@ -1,12 +1,14 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License +"""Fine-tuning config and data loader module.""" + from .config import read_config_parameters -from .input import load_docs_in_chunks, MIN_CHUNK_OVERLAP, MIN_CHUNK_SIZE +from .input import MIN_CHUNK_OVERLAP, MIN_CHUNK_SIZE, load_docs_in_chunks __all__ = [ - "read_config_parameters", - "load_docs_in_chunks", "MIN_CHUNK_OVERLAP", "MIN_CHUNK_SIZE", + "load_docs_in_chunks", + "read_config_parameters", ] diff --git a/graphrag/fine_tune/loader/config.py b/graphrag/fine_tune/loader/config.py index 119ab053ae..db451156ca 100644 --- a/graphrag/fine_tune/loader/config.py +++ b/graphrag/fine_tune/loader/config.py @@ -1,6 +1,7 @@ # Copyright (c) 2024 Microsoft Corporation. 
# Licensed under the MIT License +"""Config loading, parsing and handling module.""" from pathlib import Path @@ -9,6 +10,13 @@ def read_config_parameters(root: str, reporter: ProgressReporter): + """Read the configuration parameters from the settings file or environment variables. + + Parameters + ---------- + - root: The root directory where the parameters are. + - reporter: The progress reporter. + """ _root = Path(root) settings_yaml = _root / "settings.yaml" if not settings_yaml.exists(): diff --git a/graphrag/fine_tune/loader/input.py b/graphrag/fine_tune/loader/input.py index 6ca612a099..42f58c8092 100644 --- a/graphrag/fine_tune/loader/input.py +++ b/graphrag/fine_tune/loader/input.py @@ -1,15 +1,18 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License +"""Input loading module.""" + from typing import cast -from datashaper import NoopVerbCallbacks, TableContainer, VerbInput + import pandas as pd +from datashaper import NoopVerbCallbacks, TableContainer, VerbInput + from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.input import load_input from graphrag.index.progress.types import ProgressReporter from graphrag.index.verbs import chunk - MIN_CHUNK_SIZE = 200 MIN_CHUNK_OVERLAP = 0 @@ -22,7 +25,7 @@ async def load_docs_in_chunks( reporter: ProgressReporter, chunk_size: int = MIN_CHUNK_SIZE, ) -> list[str]: - """Load docs for generating prompts.""" + """Load docs into chunks for generating prompts.""" dataset = await load_input(config.input, reporter, root) # covert to text units @@ -44,7 +47,7 @@ async def load_docs_in_chunks( dataset_chunks = cast(pd.DataFrame, dataset_chunks_table_container.table) # Select chunks into a new df and explode it - chunks_df = pd.DataFrame(dataset_chunks["chunks"].explode()) + chunks_df = pd.DataFrame(dataset_chunks["chunks"].explode()) # type: ignore # Depending on the select method, build the dataset if select_method == "top": @@ -53,6 +56,4 @@ async def 
load_docs_in_chunks( chunks_df = chunks_df.sample(n=limit) # Convert the dataset to list form, so we have a list of documents - doc_list = chunks_df["chunks"].tolist() - - return doc_list + return chunks_df["chunks"].tolist() diff --git a/graphrag/fine_tune/prompt/__init__.py b/graphrag/fine_tune/prompt/__init__.py index ee7a3d1094..e774b4ac68 100644 --- a/graphrag/fine_tune/prompt/__init__.py +++ b/graphrag/fine_tune/prompt/__init__.py @@ -1,23 +1,26 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -from .persona import GENERATE_PERSONA_PROMPT -from .entity_types import ( - ENTITY_TYPE_GENERATION_JSON_PROMPT, - ENTITY_TYPE_GENERATION_PROMPT, -) +"""Persona, entity type, relationships and domain generation prompts module.""" + +from .community_reporter_role import GENERATE_COMMUNITY_REPORTER_ROLE_PROMPT +from .domain import GENERATE_DOMAIN_PROMPT from .entity_relationship import ( ENTITY_RELATIONSHIPS_GENERATION_JSON_PROMPT, ENTITY_RELATIONSHIPS_GENERATION_PROMPT, ) - -from .domain import GENERATE_DOMAIN_PROMPT +from .entity_types import ( + ENTITY_TYPE_GENERATION_JSON_PROMPT, + ENTITY_TYPE_GENERATION_PROMPT, +) +from .persona import GENERATE_PERSONA_PROMPT __all__ = [ - "GENERATE_PERSONA_PROMPT", - "ENTITY_TYPE_GENERATION_JSON_PROMPT", - "ENTITY_TYPE_GENERATION_PROMPT", "ENTITY_RELATIONSHIPS_GENERATION_JSON_PROMPT", "ENTITY_RELATIONSHIPS_GENERATION_PROMPT", + "ENTITY_TYPE_GENERATION_JSON_PROMPT", + "ENTITY_TYPE_GENERATION_PROMPT", + "GENERATE_COMMUNITY_REPORTER_ROLE_PROMPT", "GENERATE_DOMAIN_PROMPT", + "GENERATE_PERSONA_PROMPT", ] diff --git a/graphrag/fine_tune/prompt/community_reporter_role.py b/graphrag/fine_tune/prompt/community_reporter_role.py index 6424a04306..b667bc2940 100644 --- a/graphrag/fine_tune/prompt/community_reporter_role.py +++ b/graphrag/fine_tune/prompt/community_reporter_role.py @@ -1,6 +1,8 @@ # Copyright (c) 2024 Microsoft Corporation. 
# Licensed under the MIT License +"""Fine-tuning prompts for community reporter role generation.""" + GENERATE_COMMUNITY_REPORTER_ROLE_PROMPT = """ {persona} Given a sample text, help the user by creating a role definition that will be tasked with community analysis. @@ -8,9 +10,9 @@ Remember, your output should look just like the provided example in structure and content. Example: -A technologist reporter that is analyzing Kevin Scott's "Behind the Tech Podcast", given a list of entities -that belong to the community as well as their relationships and optional associated claims. -The report will be used to inform decision-makers about significant developments associated with the community and their potential impact. +A technologist reporter that is analyzing Kevin Scott's "Behind the Tech Podcast", given a list of entities +that belong to the community as well as their relationships and optional associated claims. +The report will be used to inform decision-makers about significant developments associated with the community and their potential impact. Domain: {domain} diff --git a/graphrag/fine_tune/prompt/domain.py b/graphrag/fine_tune/prompt/domain.py index 3de56b3c18..4b4587f8d8 100644 --- a/graphrag/fine_tune/prompt/domain.py +++ b/graphrag/fine_tune/prompt/domain.py @@ -1,6 +1,8 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License +"""Fine-tuning prompts for domain generation.""" + GENERATE_DOMAIN_PROMPT = """ You are an intelligent assistant that helps a human to analyze the information in a text document. Given a sample text, help the user by assigning a descriptive domain that summarizes what the text is about. diff --git a/graphrag/fine_tune/prompt/entity_relationship.py b/graphrag/fine_tune/prompt/entity_relationship.py index a68a943a95..d7b0692480 100644 --- a/graphrag/fine_tune/prompt/entity_relationship.py +++ b/graphrag/fine_tune/prompt/entity_relationship.py @@ -1,6 +1,8 @@ # Copyright (c) 2024 Microsoft Corporation. 
# Licensed under the MIT License +"""Fine-tuning prompts for entity relationship generation.""" + ENTITY_RELATIONSHIPS_GENERATION_PROMPT = """ -Goal- Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities. diff --git a/graphrag/fine_tune/prompt/entity_types.py b/graphrag/fine_tune/prompt/entity_types.py index cc121dcbb8..99b21db645 100644 --- a/graphrag/fine_tune/prompt/entity_types.py +++ b/graphrag/fine_tune/prompt/entity_types.py @@ -1,6 +1,8 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License +"""Fine-tuning prompts for entity types generation.""" + ENTITY_TYPE_GENERATION_PROMPT = """ The goal is to study the connections and relations between the entity types and their features in order to understand all available information from the text. The user's task is to {task}. diff --git a/graphrag/fine_tune/prompt/persona.py b/graphrag/fine_tune/prompt/persona.py index 8cb67b9b38..58515fd204 100644 --- a/graphrag/fine_tune/prompt/persona.py +++ b/graphrag/fine_tune/prompt/persona.py @@ -1,6 +1,8 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License +"""Fine-tuning prompts for persona generation.""" + GENERATE_PERSONA_PROMPT = """ You are an intelligent assistant that helps a human to analyze the information in a text document. Given a specific type of task and sample text, help the user by generating a 3 to 4 sentence description of an expert who could help solve the problem. diff --git a/graphrag/fine_tune/template/__init__.py b/graphrag/fine_tune/template/__init__.py index 38e9ade4d4..85ce2df8c5 100644 --- a/graphrag/fine_tune/template/__init__.py +++ b/graphrag/fine_tune/template/__init__.py @@ -1,21 +1,20 @@ # Copyright (c) 2024 Microsoft Corporation. 
# Licensed under the MIT License +"""Fine-tuning prompts for entity extraction, entity summarization, and community report summarization.""" + +from .community_report_summarization import COMMUNITY_REPORT_SUMMARIZATION_PROMPT from .entity_extraction import ( EXAMPLE_EXTRACTION_TEMPLATE, - GRAPH_EXTRACTION_PROMPT, GRAPH_EXTRACTION_JSON_PROMPT, + GRAPH_EXTRACTION_PROMPT, ) - from .entity_summarization import ENTITY_SUMMARIZATION_PROMPT -from .community_report_summarization import COMMUNITY_REPORT_SUMMARIZATION_PROMPT - - __all__ = [ + "COMMUNITY_REPORT_SUMMARIZATION_PROMPT", + "ENTITY_SUMMARIZATION_PROMPT", "EXAMPLE_EXTRACTION_TEMPLATE", - "GRAPH_EXTRACTION_PROMPT", "GRAPH_EXTRACTION_JSON_PROMPT", - "ENTITY_SUMMARIZATION_PROMPT", - "COMMUNITY_REPORT_SUMMARIZATION_PROMPT", + "GRAPH_EXTRACTION_PROMPT", ] diff --git a/graphrag/fine_tune/template/community_report_summarization.py b/graphrag/fine_tune/template/community_report_summarization.py index 7ef63fd72a..5d909fef9f 100644 --- a/graphrag/fine_tune/template/community_report_summarization.py +++ b/graphrag/fine_tune/template/community_report_summarization.py @@ -1,6 +1,8 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License +"""Fine-tuning prompts for community report summarization.""" + COMMUNITY_REPORT_SUMMARIZATION_PROMPT = """ {persona} @@ -10,7 +12,7 @@ # Report Structure The report should include the following sections: -- TITLE: community's name that represents its key entities - title should be short but specific. When possible, include representative named entities in the title. +- TITLE: community's name that represents its key entities - title should be short but specific. When possible, include representative named entities in the title. - SUMMARY: An executive summary of the community's overall structure, how its entities are related to each other, and significant threats associated with its entities. 
- THREAT SEVERITY RATING: a float score between 0-10 that represents the potential global impact to humanity as posed by entities within the community. - RATING EXPLANATION: Give a single sentence explanation of the threat severity rating. diff --git a/graphrag/fine_tune/template/entity_extraction.py b/graphrag/fine_tune/template/entity_extraction.py index f9470d2fc0..0ac91d4b9f 100644 --- a/graphrag/fine_tune/template/entity_extraction.py +++ b/graphrag/fine_tune/template/entity_extraction.py @@ -1,6 +1,8 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License +"""Fine-tuning prompts for entity extraction.""" + GRAPH_EXTRACTION_PROMPT = """ -Goal- Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities. @@ -76,7 +78,7 @@ Example {n}: entity_types: [{entity_types}] -text: +text: {input_text} ------------------------ output: diff --git a/graphrag/fine_tune/template/entity_summarization.py b/graphrag/fine_tune/template/entity_summarization.py index f38ab4e45e..be926e1914 100644 --- a/graphrag/fine_tune/template/entity_summarization.py +++ b/graphrag/fine_tune/template/entity_summarization.py @@ -1,6 +1,8 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License +"""Fine-tuning prompts for entity summarization.""" + ENTITY_SUMMARIZATION_PROMPT = """ {persona} Using your expertise, you're asked to generate a comprehensive summary of the data provided below. 
diff --git a/graphrag/index/input/load_input.py b/graphrag/index/input/load_input.py index 188cd24d1f..d00a604e76 100644 --- a/graphrag/index/input/load_input.py +++ b/graphrag/index/input/load_input.py @@ -10,7 +10,7 @@ import pandas as pd -from graphrag.config import StorageType, InputConfig +from graphrag.config import InputConfig, StorageType from graphrag.index.config import PipelineInputConfig from graphrag.index.progress import NullProgressReporter, ProgressReporter from graphrag.index.storage import ( diff --git a/graphrag/index/llm/load_llm.py b/graphrag/index/llm/load_llm.py index 7c45475df1..5787ccbd71 100644 --- a/graphrag/index/llm/load_llm.py +++ b/graphrag/index/llm/load_llm.py @@ -102,18 +102,16 @@ def _load_openai_completion_llm( azure=False, ): return _create_openai_completion_llm( - OpenAIConfiguration( - { - **_get_base_config(config), - "model": config.get("model", "gpt-4-turbo-preview"), - "deployment_name": config.get("deployment_name"), - "temperature": config.get("temperature", 0.0), - "frequency_penalty": config.get("frequency_penalty", 0), - "presence_penalty": config.get("presence_penalty", 0), - "top_p": config.get("top_p", 1), - "max_tokens": config.get("max_tokens"), - } - ), + OpenAIConfiguration({ + **_get_base_config(config), + "model": config.get("model", "gpt-4-turbo-preview"), + "deployment_name": config.get("deployment_name"), + "temperature": config.get("temperature", 0.0), + "frequency_penalty": config.get("frequency_penalty", 0), + "presence_penalty": config.get("presence_penalty", 0), + "top_p": config.get("top_p", 1), + "max_tokens": config.get("max_tokens"), + }), on_error, cache, azure, @@ -127,19 +125,17 @@ def _load_openai_chat_llm( azure=False, ): return _create_openai_chat_llm( - OpenAIConfiguration( - { - # Set default values - **_get_base_config(config), - "model": config.get("model", "gpt-4-turbo-preview"), - "deployment_name": config.get("deployment_name"), - "temperature": config.get("temperature", 0.0), - 
"frequency_penalty": config.get("frequency_penalty", 0), - "presence_penalty": config.get("presence_penalty", 0), - "top_p": config.get("top_p", 1), - "max_tokens": config.get("max_tokens"), - } - ), + OpenAIConfiguration({ + # Set default values + **_get_base_config(config), + "model": config.get("model", "gpt-4-turbo-preview"), + "deployment_name": config.get("deployment_name"), + "temperature": config.get("temperature", 0.0), + "frequency_penalty": config.get("frequency_penalty", 0), + "presence_penalty": config.get("presence_penalty", 0), + "top_p": config.get("top_p", 1), + "max_tokens": config.get("max_tokens"), + }), on_error, cache, azure, @@ -154,15 +150,13 @@ def _load_openai_embeddings_llm( ): # TODO: Inject Cache return _create_openai_embeddings_llm( - OpenAIConfiguration( - { - **_get_base_config(config), - "model": config.get( - "embeddings_model", config.get("model", "text-embedding-3-small") - ), - "deployment_name": config.get("deployment_name"), - } - ), + OpenAIConfiguration({ + **_get_base_config(config), + "model": config.get( + "embeddings_model", config.get("model", "text-embedding-3-small") + ), + "deployment_name": config.get("deployment_name"), + }), on_error, cache, azure,