use GRAPHRAG_BASE_ prefix for common vars; trim down common varset
darthtrevino committed Apr 2, 2024
1 parent 2b162b0 commit 85baf5d
Showing 4 changed files with 94 additions and 71 deletions.
27 changes: 9 additions & 18 deletions javascript/docsite/_posts/_config/env_vars.md
@@ -51,24 +51,15 @@ Our pipeline can ingest .csv or .txt data from an input folder. These files can

## Base LLM Settings

These settings control the base LLM arguments used by the pipeline, for both text-generation and embedding tasks. This is useful for configuring common connection, parallelization, and retry settings.

| Parameter | Description | Type | Required or Optional | Default Value |
| ------------------------------ | -------------------------------------------------------------------------------------- | ------- | -------------------- | ------------- |
| `GRAPHRAG_API_KEY` | The API key. (Note: `OPENAI_API_KEY` is also used as a fallback.) | `str` | required | `None` |
| `GRAPHRAG_API_BASE` | The API base URL. | `str` | required for AOAI | `None` |
| `GRAPHRAG_API_VERSION` | The AOAI API version. | `str` | required for AOAI | `None` |
| `GRAPHRAG_ORGANIZATION` | The AOAI organization. | `str` | optional for AOAI | `None` |
| `GRAPHRAG_PROXY` | The AOAI proxy. | `str` | optional for AOAI | `None` |
| `GRAPHRAG_REQUEST_TIMEOUT` | The maximum number of seconds to wait for a response from the chat client. | `int` | optional | `180` |
| `GRAPHRAG_THREAD_COUNT` | The number of threads to use for LLM parallelization. | `int` | optional | 50 |
| `GRAPHRAG_THREAD_STAGGER` | The time to wait (in seconds) between starting each thread. | `float` | optional | 0.3 |
| `GRAPHRAG_CONCURRENT_REQUESTS` | The number of concurrent requests to allow for the embedding client. | `int` | optional | 25 |
| `GRAPHRAG_TPM` | The number of tokens per minute to allow for the LLM client. 0 = Bypass | `int` | optional | 0 |
| `GRAPHRAG_RPM` | The number of requests per minute to allow for the LLM client. 0 = Bypass | `int` | optional | 0 |
| `GRAPHRAG_MAX_RETRIES` | The maximum number of retries to attempt when a request fails. | `int` | optional | 10 |
| `GRAPHRAG_MAX_RETRY_WAIT` | The maximum number of seconds to wait between retries. | `int` | optional | 10 |
| `GRAPHRAG_SLEEP_ON_RATE_LIMIT_RECOMMENDATION` | Whether to sleep on rate limit recommendation. (Azure Only) | `bool` | optional | `True` |
These settings control the base LLM arguments used by the pipeline, covering the API connection parameters shared by the text-generation and embedding clients.

| Parameter | Description | Type | Required or Optional | Default Value |
| ----------------------------------- | -------------------------------------------------------------------------------------- | ------- | -------------------- | ------------- |
| `GRAPHRAG_BASE_API_KEY` | The API key. (Note: `OPENAI_API_KEY` is also used as a fallback.) | `str` | required | `None` |
| `GRAPHRAG_BASE_API_BASE` | The API base URL. | `str` | required for AOAI | `None` |
| `GRAPHRAG_BASE_API_VERSION` | The AOAI API version. | `str` | required for AOAI | `None` |
| `GRAPHRAG_BASE_ORGANIZATION` | The AOAI organization. | `str` | optional for AOAI | `None` |
| `GRAPHRAG_BASE_PROXY` | The AOAI proxy. | `str` | optional for AOAI | `None` |
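
For illustration, a minimal sketch of the fallback order described above; `resolve_api_key` is a hypothetical helper, not part of the library:

```python
import os

def resolve_api_key() -> str | None:
    # GRAPHRAG_BASE_API_KEY takes precedence; OPENAI_API_KEY is the documented fallback.
    return os.environ.get("GRAPHRAG_BASE_API_KEY") or os.environ.get("OPENAI_API_KEY")
```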


## Text Generation Settings
@@ -10,11 +10,12 @@
PipelineStorageType,
)
from graphrag.index.default_config.parameters.models import TextEmbeddingTarget
from graphrag.index.llm.types import LLMType

#
# LLM Parameters
#
DEFAULT_LLM_TYPE = "openai_chat"
DEFAULT_LLM_TYPE = LLMType.OpenAIChat
DEFAULT_LLM_MODEL = "gpt-4-turbo-preview"
DEFAULT_LLM_MAX_TOKENS = 4000
DEFAULT_LLM_REQUEST_TIMEOUT = 180.0
@@ -28,7 +29,7 @@
#
# Text Embedding Parameters
#
DEFAULT_EMBEDDING_TYPE = "openai_embedding"
DEFAULT_EMBEDDING_TYPE = LLMType.OpenAIEmbedding
DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small"
DEFAULT_EMBEDDING_TOKENS_PER_MINUTE = 0
DEFAULT_EMBEDDING_REQUESTS_PER_MINUTE = 0
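
For context, a minimal sketch of the `LLMType` enum these defaults assume. The member names come from references elsewhere in this commit; the Azure string values are guesses, not confirmed by the diff:

```python
from enum import Enum

class LLMType(str, Enum):
    # "openai_chat" and "openai_embedding" match the string defaults this commit replaces;
    # the Azure values below are assumed for illustration.
    OpenAIChat = "openai_chat"
    OpenAIEmbedding = "openai_embedding"
    AzureOpenAI = "azure_openai"
    AzureOpenAIChat = "azure_openai_chat"
    AzureOpenAIEmbedding = "azure_openai_embedding"
```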
127 changes: 77 additions & 50 deletions python/graphrag/graphrag/index/default_config/parameters/factories.py
@@ -17,6 +17,7 @@
)
from graphrag.index.default_config.parameters.models import TextEmbeddingTarget
from graphrag.index.default_config.parameters.read_dotenv import read_dotenv
from graphrag.index.llm.types import LLMType

from .default_config_parameters import DefaultConfigParametersDict
from .default_config_parameters_model import DefaultConfigParametersModel
@@ -89,6 +90,7 @@ class Fragment(str, Enum):
class Section(str, Enum):
"""Configuration Sections."""

base = "BASE"
cache = "CACHE"
chunk = "CHUNK"
claim_extraction = "CLAIM_EXTRACTION"
@@ -106,6 +108,28 @@ class Section(str, Enum):
umap = "UMAP"


LLM_KEY_REQUIRED = "API Key is required for Completion API. Please set either the OPENAI_API_KEY, GRAPHRAG_BASE_API_KEY or GRAPHRAG_LLM_API_KEY environment variable."
EMBEDDING_KEY_REQUIRED = "API Key is required for Embedding API. Please set either the OPENAI_API_KEY, GRAPHRAG_BASE_API_KEY or GRAPHRAG_EMBEDDING_API_KEY environment variable."
AZURE_LLM_DEPLOYMENT_NAME_REQUIRED = (
"GRAPHRAG_LLM_DEPLOYMENT_NAME is required for Azure OpenAI."
)
AZURE_LLM_API_BASE_REQUIRED = (
"GRAPHRAG_BASE_API_BASE or GRAPHRAG_LLM_API_BASE is required for Azure OpenAI."
)
AZURE_EMBEDDING_DEPLOYMENT_NAME_REQUIRED = (
"GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME is required for Azure OpenAI."
)
AZURE_EMBEDDING_API_BASE_REQUIRED = "GRAPHRAG_BASE_API_BASE or GRAPHRAG_EMBEDDING_API_BASE is required for Azure OpenAI."


def _is_azure(llm_type: LLMType) -> bool:
return (
llm_type == LLMType.AzureOpenAI
or llm_type == LLMType.AzureOpenAIChat
or llm_type == LLMType.AzureOpenAIEmbedding
)
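
# A sketch of an equivalent, more compact membership form of the check above
# (assuming the same LLMType members; not part of the committed code):
#
#     def _is_azure(llm_type: LLMType) -> bool:
#         return llm_type in {
#             LLMType.AzureOpenAI,
#             LLMType.AzureOpenAIChat,
#             LLMType.AzureOpenAIEmbedding,
#         }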


def default_config_parameters_from_env_vars(
root_dir: str | None, resume_from: str | None = None
):
@@ -137,72 +161,77 @@ def section(key: Section):
fallback_oai_version = _str("OPENAI_API_VERSION")

with section(Section.graphrag):
_api_key = _str(Fragment.api_key, fallback_oai_key)
_api_base = _str(Fragment.api_base, fallback_oai_url)
_api_version = _str(Fragment.api_version, fallback_oai_version)
_organization = _str(Fragment.organization, fallback_oai_org)
_proxy = _str(Fragment.proxy)
_tpm = _int(Fragment.tpm)
_rpm = _int(Fragment.rpm)
_request_timeout = _float(Fragment.request_timeout)
_max_retries = _int(Fragment.max_retries)
_max_retry_wait = _float(Fragment.max_retry_wait)
_sleep_recommendation = _bool(Fragment.sleep_recommendation)
_concurrent_requests = _int(Fragment.concurrent_requests)
_async_mode = _str(Fragment.async_mode)
_stagger = _float(Fragment.thread_stagger)
_thread_count = _int(Fragment.thread_count)
with section(Section.base):
_api_key = _str(Fragment.api_key, fallback_oai_key)
_api_base = _str(Fragment.api_base, fallback_oai_url)
_api_version = _str(Fragment.api_version, fallback_oai_version)
_organization = _str(Fragment.organization, fallback_oai_org)
_proxy = _str(Fragment.proxy)

with section(Section.llm):
api_key = _str(Fragment.api_key, _api_key or fallback_oai_key)
if api_key is None:
msg = "API Key is required for Completion API. Please set either the OPENAI_API_KEY, GRAPHRAG_API_KEY or GRAPHRAG_LLM_API_KEY environment variable."
raise ValueError(msg)
raise ValueError(LLM_KEY_REQUIRED)
llm_type = _str(Fragment.type)
llm_type = LLMType(llm_type) if llm_type else None
deployment_name = _str(Fragment.deployment_name)
is_azure = _is_azure(llm_type)
api_base = _str(Fragment.api_base, _api_base)
if is_azure and deployment_name is None:
raise ValueError(AZURE_LLM_DEPLOYMENT_NAME_REQUIRED)
if is_azure and api_base is None:
raise ValueError(AZURE_LLM_API_BASE_REQUIRED)

llm_parameters = LLMParametersModel(
api_key=api_key,
type=_str(Fragment.type),
type=llm_type,
model=_str(Fragment.model),
max_tokens=_int(Fragment.max_tokens),
model_supports_json=_bool(Fragment.model_supports_json),
request_timeout=_float(Fragment.request_timeout, _request_timeout),
api_base=_str(Fragment.api_base, _api_base),
request_timeout=_float(Fragment.request_timeout),
api_base=api_base,
api_version=_str(Fragment.api_version, _api_version),
organization=_str(Fragment.organization, _organization),
proxy=_str(Fragment.proxy, _proxy),
deployment_name=_str(Fragment.deployment_name),
tokens_per_minute=_int(Fragment.tpm, _tpm),
requests_per_minute=_int(Fragment.rpm, _rpm),
max_retries=_int(Fragment.max_retries, _max_retries),
max_retry_wait=_float(Fragment.max_retry_wait, _max_retry_wait),
sleep_on_rate_limit_recommendation=_bool(
Fragment.sleep_recommendation, _sleep_recommendation
),
concurrent_requests=_int(
Fragment.concurrent_requests, _concurrent_requests
),
deployment_name=deployment_name,
tokens_per_minute=_int(Fragment.tpm),
requests_per_minute=_int(Fragment.rpm),
max_retries=_int(Fragment.max_retries),
max_retry_wait=_float(Fragment.max_retry_wait),
sleep_on_rate_limit_recommendation=_bool(Fragment.sleep_recommendation),
concurrent_requests=_int(Fragment.concurrent_requests),
)
llm_parallelization = ParallelizationParametersModel(
stagger=_float(Fragment.thread_stagger, _stagger),
num_threads=_int(Fragment.thread_count, _thread_count),
stagger=_float(Fragment.thread_stagger),
num_threads=_int(Fragment.thread_count),
)

with section(Section.embedding):
api_key = _str(Fragment.api_key, _api_key)
if api_key is None:
msg = "API Key is required for Embedding API. Please set either the OPENAI_API_KEY, GRAPHRAG_API_KEY or GRAPHRAG_EMBEDDING_API_KEY environment variable."
raise ValueError(msg)
raise ValueError(EMBEDDING_KEY_REQUIRED)

embedding_target = _str("TARGET")
embedding_target = (
TextEmbeddingTarget(embedding_target) if embedding_target else None
)
async_mode = _str(Fragment.async_mode, _async_mode)
async_mode = _str(Fragment.async_mode)
async_mode_enum = AsyncType(async_mode) if async_mode else None
deployment_name = _str(Fragment.deployment_name)
llm_type = _str(Fragment.type)
llm_type = LLMType(llm_type) if llm_type else None
is_azure = _is_azure(llm_type)
api_base = _str(Fragment.api_base, _api_base)

if is_azure and deployment_name is None:
raise ValueError(AZURE_EMBEDDING_DEPLOYMENT_NAME_REQUIRED)
if is_azure and api_base is None:
raise ValueError(AZURE_EMBEDDING_API_BASE_REQUIRED)

text_embeddings = TextEmbeddingConfigModel(
parallelization=ParallelizationParametersModel(
stagger=_float(Fragment.thread_stagger, _stagger),
num_threads=_int(Fragment.thread_count, _thread_count),
stagger=_float(Fragment.thread_stagger),
num_threads=_int(Fragment.thread_count),
),
async_mode=async_mode_enum,
target=embedding_target,
@@ -211,24 +240,22 @@ def section(key: Section):
skip=_array_string("SKIP"),
llm=LLMParametersModel(
api_key=_str(Fragment.api_key, _api_key),
type=_str(Fragment.type),
type=llm_type,
model=_str(Fragment.model),
request_timeout=_float(Fragment.request_timeout, _request_timeout),
api_base=_str(Fragment.api_base, _api_base),
request_timeout=_float(Fragment.request_timeout),
api_base=api_base,
api_version=_str(Fragment.api_version, _api_version),
organization=_str(Fragment.organization, _organization),
proxy=_str(Fragment.proxy, _proxy),
deployment_name=_str(Fragment.deployment_name),
tokens_per_minute=_int(Fragment.tpm, _tpm),
requests_per_minute=_int(Fragment.rpm, _rpm),
max_retries=_int(Fragment.max_retries, _max_retries),
max_retry_wait=_float(Fragment.max_retry_wait, _max_retry_wait),
deployment_name=deployment_name,
tokens_per_minute=_int(Fragment.tpm),
requests_per_minute=_int(Fragment.rpm),
max_retries=_int(Fragment.max_retries),
max_retry_wait=_float(Fragment.max_retry_wait),
sleep_on_rate_limit_recommendation=_bool(
Fragment.sleep_recommendation, _sleep_recommendation
),
concurrent_requests=_int(
Fragment.concurrent_requests, _concurrent_requests
),
sleep_on_rate_limit_recommendation=_bool(Fragment.sleep_recommendation),
concurrent_requests=_int(Fragment.concurrent_requests),
),
)

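
The factory code above relies on env-reader helpers (`_str`, `_int`, `_float`, `_bool`, `_array_string`) and a `section` context manager that are defined earlier in `factories.py` and elided from this diff. A rough sketch of the contract they appear to follow, assuming section prefixes nest so that `Section.base` inside `Section.graphrag` yields keys like `GRAPHRAG_BASE_API_KEY`:

```python
import os
from contextlib import contextmanager
from enum import Enum

_prefix = ""  # current env-var key prefix, e.g. "GRAPHRAG_BASE_"

@contextmanager
def section(key):
    # Section prefixes nest: GRAPHRAG_ -> GRAPHRAG_BASE_ -> ...
    global _prefix
    old, _prefix = _prefix, f"{_prefix}{key.value}_"
    try:
        yield
    finally:
        _prefix = old

def _str(fragment, fallback=None):
    # Fragments may be Fragment enum members or plain strings (e.g. _str("TARGET")).
    name = fragment.value if isinstance(fragment, Enum) else fragment
    value = os.environ.get(f"{_prefix}{name}")
    return value if value is not None else fallback

def _int(fragment, fallback=None):
    # Same lookup, parsed as an integer when present.
    value = _str(fragment)
    return int(value) if value is not None else fallback
```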
@@ -4,6 +4,8 @@

from pydantic import BaseModel, ConfigDict, Field

from graphrag.index.llm.types import LLMType


class LLMParametersModel(BaseModel):
"""LLM Parameters model."""
@@ -12,7 +14,9 @@ class LLMParametersModel(BaseModel):
api_key: str | None = Field(
description="The API key to use for the LLM service.", default=None
)
type: str | None = Field(description="The type of LLM model to use.", default=None)
type: LLMType | None = Field(
description="The type of LLM model to use.", default=None
)
model: str | None = Field(description="The LLM model to use.", default=None)
max_tokens: int | None = Field(
description="The maximum number of tokens to generate.", default=None
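
With `type` now declared as `LLMType | None`, pydantic validates the field against the enum. A small usage sketch, assuming the `"openai_chat"` value carried over from the old string default and the imports shown above:

```python
# Passing the enum member directly:
params = LLMParametersModel(type=LLMType.OpenAIChat, model="gpt-4-turbo-preview")

# Pydantic coerces a matching raw value to the enum member:
params = LLMParametersModel(type="openai_chat")
assert params.type is LLMType.OpenAIChat
```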
