Skip to content

Commit

Permalink
Add encoding model to entity/claim extraction config sections (#740)
Browse files Browse the repository at this point in the history
* Add encoding-model configuration to entity & claim extraction

* add change note

* pr updates

* test fix

* disable GH-based smoke tests
  • Loading branch information
darthtrevino authored Jul 26, 2024
1 parent 8565cd6 commit 9d99f32
Show file tree
Hide file tree
Showing 11 changed files with 49 additions and 25 deletions.
30 changes: 15 additions & 15 deletions .github/workflows/python-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -108,18 +108,18 @@ jobs:
run: |
poetry run poe test_integration
- name: Smoke Test
if: steps.changes.outputs.python == 'true'
run: |
poetry run poe test_smoke
- uses: actions/upload-artifact@v4
if: always()
with:
name: smoke-test-artifacts-${{ matrix.python-version }}-${{ matrix.poetry-version }}-${{ runner.os }}
path: tests/fixtures/*/output

- name: E2E Test
if: steps.changes.outputs.python == 'true'
run: |
./scripts/e2e-test.sh
# - name: Smoke Test
# if: steps.changes.outputs.python == 'true'
# run: |
# poetry run poe test_smoke

# - uses: actions/upload-artifact@v4
# if: always()
# with:
# name: smoke-test-artifacts-${{ matrix.python-version }}-${{ matrix.poetry-version }}-${{ runner.os }}
# path: tests/fixtures/*/output

# - name: E2E Test
# if: steps.changes.outputs.python == 'true'
# run: |
# ./scripts/e2e-test.sh
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20240726181256417715.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "add encoding-model to entity/claim extraction config"
}
14 changes: 8 additions & 6 deletions docsite/posts/config/env_vars.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,12 @@ These settings control the data input used by the pipeline. Any settings with a

## Data Chunking

| Parameter | Description | Type | Required or Optional | Default |
| --------------------------- | ------------------------------------------------------------------------------------------- | ----- | -------------------- | ------- |
| `GRAPHRAG_CHUNK_SIZE` | The chunk size in tokens for text-chunk analysis windows. | `str` | optional | 1200 |
| `GRAPHRAG_CHUNK_OVERLAP` | The chunk overlap in tokens for text-chunk analysis windows. | `str` | optional | 100 |
| `GRAPHRAG_CHUNK_BY_COLUMNS` | A comma-separated list of document attributes to groupby when performing TextUnit chunking. | `str` | optional | `id` |
| `GRAPHRAG_CHUNK_ENCODING_MODEL` | The encoding model to use for chunking. | `str` | optional | `None` |
| Parameter | Description | Type | Required or Optional | Default |
| ------------------------------- | ------------------------------------------------------------------------------------------- | ----- | -------------------- | ----------------------------- |
| `GRAPHRAG_CHUNK_SIZE` | The chunk size in tokens for text-chunk analysis windows. | `str` | optional | 1200 |
| `GRAPHRAG_CHUNK_OVERLAP` | The chunk overlap in tokens for text-chunk analysis windows. | `str` | optional | 100 |
| `GRAPHRAG_CHUNK_BY_COLUMNS` | A comma-separated list of document attributes to groupby when performing TextUnit chunking. | `str` | optional | `id` |
| `GRAPHRAG_CHUNK_ENCODING_MODEL` | The encoding model to use for chunking. | `str` | optional | The top-level encoding model. |

## Prompting Overrides

Expand All @@ -146,12 +146,14 @@ These settings control the data input used by the pipeline. Any settings with a
| `GRAPHRAG_ENTITY_EXTRACTION_PROMPT_FILE` | The path (relative to the root) of an entity extraction prompt template text file. | `str` | optional | `None` |
| `GRAPHRAG_ENTITY_EXTRACTION_MAX_GLEANINGS` | The maximum number of redrives (gleanings) to invoke when extracting entities in a loop. | `int` | optional | 1 |
| `GRAPHRAG_ENTITY_EXTRACTION_ENTITY_TYPES` | A comma-separated list of entity types to extract. | `str` | optional | `organization,person,event,geo` |
| `GRAPHRAG_ENTITY_EXTRACTION_ENCODING_MODEL` | The encoding model to use for entity extraction. | `str` | optional | The top-level encoding model. |
| `GRAPHRAG_SUMMARIZE_DESCRIPTIONS_PROMPT_FILE` | The path (relative to the root) of an description summarization prompt template text file. | `str` | optional | `None` |
| `GRAPHRAG_SUMMARIZE_DESCRIPTIONS_MAX_LENGTH` | The maximum number of tokens to generate per description summarization. | `int` | optional | 500 |
| `GRAPHRAG_CLAIM_EXTRACTION_ENABLED` | Whether claim extraction is enabled for this pipeline. | `bool` | optional | `False` |
| `GRAPHRAG_CLAIM_EXTRACTION_DESCRIPTION` | The claim_description prompting argument to utilize. | `string` | optional | "Any claims or facts that could be relevant to threat analysis." |
| `GRAPHRAG_CLAIM_EXTRACTION_PROMPT_FILE` | The claim extraction prompt to utilize. | `string` | optional | `None` |
| `GRAPHRAG_CLAIM_EXTRACTION_MAX_GLEANINGS` | The maximum number of redrives (gleanings) to invoke when extracting claims in a loop. | `int` | optional | 1 |
| `GRAPHRAG_CLAIM_EXTRACTION_ENCODING_MODEL` | The encoding model to use for claim extraction. | `str` | optional | The top-level encoding model |
| `GRAPHRAG_COMMUNITY_REPORTS_PROMPT_FILE` | The community reports extraction prompt to utilize. | `string` | optional | `None` |
| `GRAPHRAG_COMMUNITY_REPORTS_MAX_LENGTH` | The maximum number of tokens to generate per community reports. | `int` | optional | 1500 |

Expand Down
2 changes: 2 additions & 0 deletions docsite/posts/config/json_yaml.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ This is the base LLM configuration section. Other steps may override this config
- `prompt` **str** - The prompt file to use.
- `entity_types` **list[str]** - The entity types to identify.
- `max_gleanings` **int** - The maximum number of gleaning cycles to use.
- `encoding_model` **str** - The text encoding model to use. By default, this will use the top-level encoding model.
- `strategy` **dict** - Fully override the entity extraction strategy.

## summarize_descriptions
Expand All @@ -169,6 +170,7 @@ This is the base LLM configuration section. Other steps may override this config
- `prompt` **str** - The prompt file to use.
- `description` **str** - Describes the types of claims we want to extract.
- `max_gleanings` **int** - The maximum number of gleaning cycles to use.
- `encoding_model` **str** - The text encoding model to use. By default, this will use the top-level encoding model.
- `strategy` **dict** - Fully override the claim extraction strategy.

## community_reports
Expand Down
3 changes: 3 additions & 0 deletions graphrag/config/create_graphrag_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ def hydrate_parallelization_params(
size=reader.int("size") or defs.CHUNK_SIZE,
overlap=reader.int("overlap") or defs.CHUNK_OVERLAP,
group_by_columns=group_by_columns,
encoding_model=reader.str(Fragment.encoding_model),
)
with (
reader.envvar_prefix(Section.snapshot),
Expand Down Expand Up @@ -428,6 +429,7 @@ def hydrate_parallelization_params(
or defs.ENTITY_EXTRACTION_ENTITY_TYPES,
max_gleanings=max_gleanings,
prompt=reader.str("prompt", Fragment.prompt_file),
encoding_model=reader.str(Fragment.encoding_model),
)

claim_extraction_config = values.get("claim_extraction") or {}
Expand All @@ -449,6 +451,7 @@ def hydrate_parallelization_params(
description=reader.str("description") or defs.CLAIM_DESCRIPTION,
prompt=reader.str("prompt", Fragment.prompt_file),
max_gleanings=max_gleanings,
encoding_model=reader.str(Fragment.encoding_model),
)

community_report_config = values.get("community_reports") or {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ class ClaimExtractionConfigInput(LLMConfigInput):
description: NotRequired[str | None]
max_gleanings: NotRequired[int | str | None]
strategy: NotRequired[dict | None]
encoding_model: NotRequired[str | None]
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ class EntityExtractionConfigInput(LLMConfigInput):
entity_types: NotRequired[list[str] | str | None]
max_gleanings: NotRequired[int | str | None]
strategy: NotRequired[dict | None]
encoding_model: NotRequired[str | None]
6 changes: 5 additions & 1 deletion graphrag/config/models/claim_extraction_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@ class ClaimExtractionConfig(LLMConfig):
strategy: dict | None = Field(
description="The override strategy to use.", default=None
)
encoding_model: str | None = Field(
default=None, description="The encoding model to use."
)

def resolved_strategy(self, root_dir: str) -> dict:
def resolved_strategy(self, root_dir: str, encoding_model: str) -> dict:
"""Get the resolved claim extraction strategy."""
from graphrag.index.verbs.covariates.extract_covariates import (
ExtractClaimsStrategyType,
Expand All @@ -50,4 +53,5 @@ def resolved_strategy(self, root_dir: str) -> dict:
else None,
"claim_description": self.description,
"max_gleanings": self.max_gleanings,
"encoding_name": self.encoding_model or encoding_model,
}
5 changes: 4 additions & 1 deletion graphrag/config/models/entity_extraction_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ class EntityExtractionConfig(LLMConfig):
strategy: dict | None = Field(
description="Override the default entity extraction strategy", default=None
)
encoding_model: str | None = Field(
default=None, description="The encoding model to use."
)

def resolved_strategy(self, root_dir: str, encoding_model: str) -> dict:
"""Get the resolved entity extraction strategy."""
Expand All @@ -45,6 +48,6 @@ def resolved_strategy(self, root_dir: str, encoding_model: str) -> dict:
else None,
"max_gleanings": self.max_gleanings,
# It's prechunked in create_base_text_units
"encoding_name": encoding_model,
"encoding_name": self.encoding_model or encoding_model,
"prechunked": True,
}
2 changes: 1 addition & 1 deletion graphrag/index/create_pipeline_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def _covariate_workflows(
"claim_extract": {
**settings.claim_extraction.parallelization.model_dump(),
"strategy": settings.claim_extraction.resolved_strategy(
settings.root_dir
settings.root_dir, settings.encoding_model
),
},
},
Expand Down
6 changes: 5 additions & 1 deletion tests/unit/config/test_default_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
"GRAPHRAG_CLAIM_EXTRACTION_DESCRIPTION": "test 123",
"GRAPHRAG_CLAIM_EXTRACTION_MAX_GLEANINGS": "5000",
"GRAPHRAG_CLAIM_EXTRACTION_PROMPT_FILE": "tests/unit/config/prompt-a.txt",
"GRAPHRAG_CLAIM_EXTRACTION_ENCODING_MODEL": "encoding_a",
"GRAPHRAG_COMMUNITY_REPORTS_MAX_LENGTH": "23456",
"GRAPHRAG_COMMUNITY_REPORTS_PROMPT_FILE": "tests/unit/config/prompt-b.txt",
"GRAPHRAG_EMBEDDING_BATCH_MAX_TOKENS": "17",
Expand All @@ -115,6 +116,7 @@
"GRAPHRAG_ENTITY_EXTRACTION_ENTITY_TYPES": "cat,dog,elephant",
"GRAPHRAG_ENTITY_EXTRACTION_MAX_GLEANINGS": "112",
"GRAPHRAG_ENTITY_EXTRACTION_PROMPT_FILE": "tests/unit/config/prompt-c.txt",
"GRAPHRAG_ENTITY_EXTRACTION_ENCODING_MODEL": "encoding_b",
"GRAPHRAG_INPUT_BASE_DIR": "/some/input/dir",
"GRAPHRAG_INPUT_CONNECTION_STRING": "input_cs",
"GRAPHRAG_INPUT_CONTAINER_NAME": "input_cn",
Expand Down Expand Up @@ -543,6 +545,7 @@ def test_create_parameters_from_env_vars(self) -> None:
assert parameters.claim_extraction.description == "test 123"
assert parameters.claim_extraction.max_gleanings == 5000
assert parameters.claim_extraction.prompt == "tests/unit/config/prompt-a.txt"
assert parameters.claim_extraction.encoding_model == "encoding_a"
assert parameters.cluster_graph.max_cluster_size == 123
assert parameters.community_reports.max_length == 23456
assert parameters.community_reports.prompt == "tests/unit/config/prompt-b.txt"
Expand Down Expand Up @@ -572,6 +575,7 @@ def test_create_parameters_from_env_vars(self) -> None:
assert parameters.entity_extraction.llm.api_base == "http://some/base"
assert parameters.entity_extraction.max_gleanings == 112
assert parameters.entity_extraction.prompt == "tests/unit/config/prompt-c.txt"
assert parameters.entity_extraction.encoding_model == "encoding_b"
assert parameters.input.storage_account_blob_url == "input_account_blob_url"
assert parameters.input.base_dir == "/some/input/dir"
assert parameters.input.connection_string == "input_cs"
Expand Down Expand Up @@ -910,7 +914,7 @@ def test_prompt_file_reading(self):
assert strategy["extraction_prompt"] == "Hello, World! A"
assert strategy["encoding_name"] == "abc123"

strategy = config.claim_extraction.resolved_strategy(".")
strategy = config.claim_extraction.resolved_strategy(".", "encoding_b")
assert strategy["extraction_prompt"] == "Hello, World! B"

strategy = config.community_reports.resolved_strategy(".")
Expand Down

0 comments on commit 9d99f32

Please sign in to comment.