Use GRAPHRAG_API_ prefix for shared, connection-based env-vars #14

Merged · 11 commits · Apr 2, 2024
27 changes: 9 additions & 18 deletions javascript/docsite/_posts/_config/env_vars.md
@@ -51,24 +51,15 @@ Our pipeline can ingest .csv or .txt data from an input folder. These files can

## Base LLM Settings

These settings control the base LLM arguments used by the pipeline, for both text-generation and embedding tasks. This is useful for configuring common connection, parallelization, and retry settings.
These settings control the base LLM arguments used by the pipeline. This is useful for API connection parameters.

| Parameter | Description | Type | Required or Optional | Default Value |
| ----------------------------------- | -------------------------------------------------------------------------------------- | ------- | -------------------- | ------------- |
| `GRAPHRAG_API_KEY` | The API key. (Note: `OPENAI_API_KEY` is also used as a fallback.) | `str` | required | `None` |
| `GRAPHRAG_API_BASE` | The API base URL. | `str` | required for AOAI | `None` |
| `GRAPHRAG_API_VERSION` | The AOAI API version. | `str` | required for AOAI | `None` |
| `GRAPHRAG_ORGANIZATION` | The AOAI organization. | `str` | optional for AOAI | `None` |
| `GRAPHRAG_PROXY` | The AOAI proxy. | `str` | optional for AOAI | `None` |
| `GRAPHRAG_REQUEST_TIMEOUT` | The maximum number of seconds to wait for a response from the chat client. | `int` | optional | `180` |
| `GRAPHRAG_THREAD_COUNT` | The number of threads to use for LLM parallelization. | `int` | optional | 50 |
| `GRAPHRAG_THREAD_STAGGER` | The time to wait (in seconds) between starting each thread. | `float` | optional | 0.3 |
| `GRAPHRAG_CONCURRENT_REQUESTS` | The number of concurrent requests to allow for the embedding client. | `int` | optional | 25 |
| `GRAPHRAG_TPM` | The number of tokens per minute to allow for the LLM client. 0 = Bypass | `int` | optional | 0 |
| `GRAPHRAG_RPM` | The number of requests per minute to allow for the LLM client. 0 = Bypass | `int` | optional | 0 |
| `GRAPHRAG_MAX_RETRIES` | The maximum number of retries to attempt when a request fails. | `int` | optional | 10 |
| `GRAPHRAG_MAX_RETRY_WAIT` | The maximum number of seconds to wait between retries. | `int` | optional | 10 |
| `GRAPHRAG_SLEEP_ON_RATE_LIMIT_RECOMMENDATION` | Whether to sleep on rate limit recommendation. (Azure Only) | `bool` | optional | `True` |
| `GRAPHRAG_API_ORGANIZATION` | The AOAI organization. | `str` | optional for AOAI | `None` |
| `GRAPHRAG_API_PROXY` | The AOAI proxy. | `str` | optional for AOAI | `None` |
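
For illustration (not part of this diff), a minimal sketch of the fallback order implied by the shared prefix; `resolve_api_key` is a hypothetical helper, not the pipeline's actual resolution code:

```python
import os

def resolve_api_key() -> str | None:
    # Hypothetical helper: try the service-specific key first, then the
    # shared GRAPHRAG_API_KEY, then the OPENAI_API_KEY fallback noted above.
    return (
        os.environ.get("GRAPHRAG_LLM_API_KEY")
        or os.environ.get("GRAPHRAG_API_KEY")
        or os.environ.get("OPENAI_API_KEY")
    )
```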


## Text Generation Settings
@@ -81,8 +72,8 @@ These settings control the text generation model used by the pipeline. These set
| `GRAPHRAG_LLM_API_KEY` | The API key. | `str` | required | `None` |
| `GRAPHRAG_LLM_API_BASE` | The API base URL. | `str` | required for AOAI | `None` |
| `GRAPHRAG_LLM_API_VERSION` | The AOAI API version. | `str` | required for AOAI | `None` |
| `GRAPHRAG_LLM_ORGANIZATION` | The AOAI organization. | `str` | optional for AOAI | `None` |
| `GRAPHRAG_LLM_PROXY` | The AOAI proxy. | `str` | optional for AOAI | `None` |
| `GRAPHRAG_LLM_API_ORGANIZATION` | The AOAI organization. | `str` | optional for AOAI | `None` |
| `GRAPHRAG_LLM_API_PROXY` | The AOAI proxy. | `str` | optional for AOAI | `None` |
| `GRAPHRAG_LLM_DEPLOYMENT_NAME` | The AOAI deployment name. | `str` | optional for AOAI | `None` |
| `GRAPHRAG_LLM_MODEL` | The model. | `str` | optional | `gpt-4-turbo-preview` |
| `GRAPHRAG_LLM_MAX_TOKENS` | The maximum number of tokens. | `int` | optional | `4000` |
@@ -107,8 +98,8 @@ These settings control the text embedding model used by the pipeline. These sett
| `GRAPHRAG_EMBEDDING_API_KEY` | The API key to use for the embedding client. | `str` | required | `None` |
| `GRAPHRAG_EMBEDDING_API_BASE` | The API base URL. | `str` | required for AOAI | `None` |
| `GRAPHRAG_EMBEDDING_API_VERSION` | The AOAI API version to use for the embedding client. | `str` | required for AOAI | `None` |
| `GRAPHRAG_EMBEDDING_ORGANIZATION` | The AOAI organization to use for the embedding client. | `str` | optional for AOAI | `None` |
| `GRAPHRAG_EMBEDDING_PROXY` | The AOAI proxy to use for the embedding client. | `str` | optional for AOAI | `None` |
| `GRAPHRAG_EMBEDDING_API_ORGANIZATION` | The AOAI organization to use for the embedding client. | `str` | optional for AOAI | `None` |
| `GRAPHRAG_EMBEDDING_API_PROXY` | The AOAI proxy to use for the embedding client. | `str` | optional for AOAI | `None` |
| `GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME` | The AOAI deployment name. | `str` | optional for AOAI | `None` |
| `GRAPHRAG_EMBEDDING_MODEL` | The model to use for the embedding client. | `str` | optional | `text-embedding-3-small` |
| `GRAPHRAG_EMBEDDING_BATCH_SIZE` | The number of texts to embed at once. [(Azure limit is 16)]( https://learn.microsoft.com/en-us/azure/ai-services/openai/reference) | `int` | optional | 16 |
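
For Azure OpenAI, the shared `GRAPHRAG_API_` values can carry the connection while the embedding client adds its own deployment name. A minimal sketch with placeholder values (the API version string is an assumption; use any version your resource supports):

```python
import os

# Placeholder values for illustration only.
os.environ["GRAPHRAG_API_KEY"] = "<your-aoai-key>"
os.environ["GRAPHRAG_API_BASE"] = "https://<domain>.openai.azure.com"
os.environ["GRAPHRAG_API_VERSION"] = "2024-02-15-preview"  # assumed version string
os.environ["GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME"] = "text-embedding-3-small"
```
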
5 changes: 5 additions & 0 deletions javascript/docsite/_posts/get_started.md
@@ -47,6 +47,11 @@ Next we'll inject some required config variables:
```sh
echo "GRAPHRAG_API_KEY=\"<Your OpenAI API Key>\"" >> ./ragtest/.env
echo "GRAPHRAG_INPUT_TYPE=text" >> ./ragtest/.env

# For Azure OpenAI Users
echo "GRAPHRAG_API_BASE=http://<domain>.openai.azure.com" >> ./ragtest/.env
echo "GRAPHRAG_LLM_DEPLOYMENT_NAME"="gpt-4" >> ./ragtest/.env
echo "GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME"="text-embedding-3-small" >> ./ragtest/.env
```
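
As a quick sanity check that the variables landed in the file, a minimal sketch (assumes the `python-dotenv` package, which is not part of the getting-started steps):

```python
import os

from dotenv import load_dotenv  # assumption: python-dotenv is installed

load_dotenv("./ragtest/.env")
for var in ("GRAPHRAG_API_KEY", "GRAPHRAG_INPUT_TYPE", "GRAPHRAG_API_BASE"):
    print(var, "set" if os.environ.get(var) else "missing")
```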

Finally we'll run the pipeline!
@@ -10,11 +10,12 @@
PipelineStorageType,
)
from graphrag.index.default_config.parameters.models import TextEmbeddingTarget
from graphrag.index.llm.types import LLMType

#
# LLM Parameters
#
DEFAULT_LLM_TYPE = "openai_chat"
DEFAULT_LLM_TYPE = LLMType.OpenAIChat
DEFAULT_LLM_MODEL = "gpt-4-turbo-preview"
DEFAULT_LLM_MAX_TOKENS = 4000
DEFAULT_LLM_REQUEST_TIMEOUT = 180.0
@@ -28,7 +29,7 @@
#
# Text Embedding Parameters
#
DEFAULT_EMBEDDING_TYPE = "openai_embedding"
DEFAULT_EMBEDDING_TYPE = LLMType.OpenAIEmbedding
DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small"
DEFAULT_EMBEDDING_TOKENS_PER_MINUTE = 0
DEFAULT_EMBEDDING_REQUESTS_PER_MINUTE = 0
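
Replacing the raw strings with `LLMType` members keeps the defaults typo-safe while still round-tripping through string-valued config. A minimal sketch of the pattern; the member values are assumptions inferred from the old string defaults, and the real enum lives in `graphrag.index.llm.types`:

```python
from enum import Enum

class LLMType(str, Enum):
    # Sketch only: values assumed from the previous string defaults.
    OpenAIChat = "openai_chat"
    OpenAIEmbedding = "openai_embedding"

# A str-mixin enum compares equal to its value and round-trips from config:
assert LLMType("openai_chat") is LLMType.OpenAIChat
assert LLMType.OpenAIEmbedding == "openai_embedding"
```
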
@@ -17,6 +17,7 @@
)
from graphrag.index.default_config.parameters.models import TextEmbeddingTarget
from graphrag.index.default_config.parameters.read_dotenv import read_dotenv
from graphrag.index.llm.types import LLMType

from .default_config_parameters import DefaultConfigParametersDict
from .default_config_parameters_model import DefaultConfigParametersModel
@@ -57,6 +58,8 @@ class Fragment(str, Enum):
api_base = "API_BASE"
api_key = "API_KEY"
api_version = "API_VERSION"
api_organization = "API_ORGANIZATION"
api_proxy = "API_PROXY"
async_mode = "ASYNC_MODE"
concurrent_requests = "CONCURRENT_REQUESTS"
conn_string = "CONNECTION_STRING"
@@ -73,9 +76,7 @@
max_tokens = "MAX_TOKENS"
model = "MODEL"
model_supports_json = "MODEL_SUPPORTS_JSON"
organization = "ORGANIZATION"
prompt_file = "PROMPT_FILE"
proxy = "PROXY"
request_timeout = "REQUEST_TIMEOUT"
rpm = "RPM"
sleep_recommendation = "SLEEP_ON_RATE_LIMIT_RECOMMENDATION"
@@ -89,6 +90,7 @@
class Section(str, Enum):
"""Configuration Sections."""

base = "BASE"
cache = "CACHE"
chunk = "CHUNK"
claim_extraction = "CLAIM_EXTRACTION"
@@ -106,6 +108,26 @@ class Section(str, Enum):
umap = "UMAP"


LLM_KEY_REQUIRED = "API Key is required for Completion API. Please set either the OPENAI_API_KEY, GRAPHRAG_API_KEY or GRAPHRAG_LLM_API_KEY environment variable."
EMBEDDING_KEY_REQUIRED = "API Key is required for Embedding API. Please set either the OPENAI_API_KEY, GRAPHRAG_API_KEY or GRAPHRAG_EMBEDDING_API_KEY environment variable."
AZURE_LLM_DEPLOYMENT_NAME_REQUIRED = (
"GRAPHRAG_LLM_MODEL or GRAPHRAG_LLM_DEPLOYMENT_NAME is required for Azure OpenAI."
)
AZURE_LLM_API_BASE_REQUIRED = (
"GRAPHRAG_API_BASE or GRAPHRAG_LLM_API_BASE is required for Azure OpenAI."
)
AZURE_EMBEDDING_DEPLOYMENT_NAME_REQUIRED = "GRAPHRAG_EMBEDDING_MODEL or GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME is required for Azure OpenAI."
AZURE_EMBEDDING_API_BASE_REQUIRED = "GRAPHRAG_API_BASE or GRAPHRAG_EMBEDDING_API_BASE is required for Azure OpenAI."


def _is_azure(llm_type: LLMType | None) -> bool:
return (
llm_type == LLMType.AzureOpenAI
or llm_type == LLMType.AzureOpenAIChat
or llm_type == LLMType.AzureOpenAIEmbedding
)
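
For reference, the helper's expected behavior on the members used in this file (illustrative assertions, not part of the diff):

```python
assert _is_azure(LLMType.AzureOpenAI)
assert _is_azure(LLMType.AzureOpenAIEmbedding)
assert not _is_azure(LLMType.OpenAIChat)
assert not _is_azure(None)
```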


def default_config_parameters_from_env_vars(
root_dir: str | None, resume_from: str | None = None
):
@@ -140,69 +162,76 @@ def section(key: Section):
_api_key = _str(Fragment.api_key, fallback_oai_key)
_api_base = _str(Fragment.api_base, fallback_oai_url)
_api_version = _str(Fragment.api_version, fallback_oai_version)
_organization = _str(Fragment.organization, fallback_oai_org)
_proxy = _str(Fragment.proxy)
_tpm = _int(Fragment.tpm)
_rpm = _int(Fragment.rpm)
_request_timeout = _float(Fragment.request_timeout)
_max_retries = _int(Fragment.max_retries)
_max_retry_wait = _float(Fragment.max_retry_wait)
_sleep_recommendation = _bool(Fragment.sleep_recommendation)
_concurrent_requests = _int(Fragment.concurrent_requests)
_async_mode = _str(Fragment.async_mode)
_stagger = _float(Fragment.thread_stagger)
_thread_count = _int(Fragment.thread_count)
_organization = _str(Fragment.api_organization, fallback_oai_org)
_proxy = _str(Fragment.api_proxy)

with section(Section.llm):
api_key = _str(Fragment.api_key, _api_key or fallback_oai_key)
if api_key is None:
msg = "API Key is required for Completion API. Please set either the OPENAI_API_KEY, GRAPHRAG_API_KEY or GRAPHRAG_LLM_API_KEY environment variable."
raise ValueError(msg)
raise ValueError(LLM_KEY_REQUIRED)
llm_type = _str(Fragment.type)
llm_type = LLMType(llm_type) if llm_type else None
deployment_name = _str(Fragment.deployment_name)
model = _str(Fragment.model)

is_azure = _is_azure(llm_type)
api_base = _str(Fragment.api_base, _api_base)
if is_azure and deployment_name is None and model is None:
raise ValueError(AZURE_LLM_DEPLOYMENT_NAME_REQUIRED)
if is_azure and api_base is None:
raise ValueError(AZURE_LLM_API_BASE_REQUIRED)

llm_parameters = LLMParametersModel(
api_key=api_key,
type=_str(Fragment.type),
model=_str(Fragment.model),
type=llm_type,
model=model,
max_tokens=_int(Fragment.max_tokens),
model_supports_json=_bool(Fragment.model_supports_json),
request_timeout=_float(Fragment.request_timeout, _request_timeout),
api_base=_str(Fragment.api_base, _api_base),
request_timeout=_float(Fragment.request_timeout),
api_base=api_base,
api_version=_str(Fragment.api_version, _api_version),
organization=_str(Fragment.organization, _organization),
proxy=_str(Fragment.proxy, _proxy),
deployment_name=_str(Fragment.deployment_name),
tokens_per_minute=_int(Fragment.tpm, _tpm),
requests_per_minute=_int(Fragment.rpm, _rpm),
max_retries=_int(Fragment.max_retries, _max_retries),
max_retry_wait=_float(Fragment.max_retry_wait, _max_retry_wait),
sleep_on_rate_limit_recommendation=_bool(
Fragment.sleep_recommendation, _sleep_recommendation
),
concurrent_requests=_int(
Fragment.concurrent_requests, _concurrent_requests
),
organization=_str(Fragment.api_organization, _organization),
proxy=_str(Fragment.api_proxy, _proxy),
deployment_name=deployment_name,
tokens_per_minute=_int(Fragment.tpm),
requests_per_minute=_int(Fragment.rpm),
max_retries=_int(Fragment.max_retries),
max_retry_wait=_float(Fragment.max_retry_wait),
sleep_on_rate_limit_recommendation=_bool(Fragment.sleep_recommendation),
concurrent_requests=_int(Fragment.concurrent_requests),
)
llm_parallelization = ParallelizationParametersModel(
stagger=_float(Fragment.thread_stagger, _stagger),
num_threads=_int(Fragment.thread_count, _thread_count),
stagger=_float(Fragment.thread_stagger),
num_threads=_int(Fragment.thread_count),
)

with section(Section.embedding):
api_key = _str(Fragment.api_key, _api_key)
if api_key is None:
msg = "API Key is required for Embedding API. Please set either the OPENAI_API_KEY, GRAPHRAG_API_KEY or GRAPHRAG_EMBEDDING_API_KEY environment variable."
raise ValueError(msg)
raise ValueError(EMBEDDING_KEY_REQUIRED)

embedding_target = _str("TARGET")
embedding_target = (
TextEmbeddingTarget(embedding_target) if embedding_target else None
)
async_mode = _str(Fragment.async_mode, _async_mode)
async_mode = _str(Fragment.async_mode)
async_mode_enum = AsyncType(async_mode) if async_mode else None
deployment_name = _str(Fragment.deployment_name)
model = _str(Fragment.model)
llm_type = _str(Fragment.type)
llm_type = LLMType(llm_type) if llm_type else None
is_azure = _is_azure(llm_type)
api_base = _str(Fragment.api_base, _api_base)

if is_azure and deployment_name is None and model is None:
raise ValueError(AZURE_EMBEDDING_DEPLOYMENT_NAME_REQUIRED)
if is_azure and api_base is None:
raise ValueError(AZURE_EMBEDDING_API_BASE_REQUIRED)

text_embeddings = TextEmbeddingConfigModel(
parallelization=ParallelizationParametersModel(
stagger=_float(Fragment.thread_stagger, _stagger),
num_threads=_int(Fragment.thread_count, _thread_count),
stagger=_float(Fragment.thread_stagger),
num_threads=_int(Fragment.thread_count),
),
async_mode=async_mode_enum,
target=embedding_target,
Expand All @@ -211,24 +240,22 @@ def section(key: Section):
skip=_array_string("SKIP"),
llm=LLMParametersModel(
api_key=_str(Fragment.api_key, _api_key),
type=_str(Fragment.type),
model=_str(Fragment.model),
request_timeout=_float(Fragment.request_timeout, _request_timeout),
api_base=_str(Fragment.api_base, _api_base),
type=llm_type,
model=model,
request_timeout=_float(Fragment.request_timeout),
api_base=api_base,
api_version=_str(Fragment.api_version, _api_version),
organization=_str(Fragment.organization, _organization),
proxy=_str(Fragment.proxy, _proxy),
deployment_name=_str(Fragment.deployment_name),
tokens_per_minute=_int(Fragment.tpm, _tpm),
requests_per_minute=_int(Fragment.rpm, _rpm),
max_retries=_int(Fragment.max_retries, _max_retries),
max_retry_wait=_float(Fragment.max_retry_wait, _max_retry_wait),
organization=_str(Fragment.api_organization, _organization),
proxy=_str(Fragment.api_proxy, _proxy),
deployment_name=deployment_name,
tokens_per_minute=_int(Fragment.tpm),
requests_per_minute=_int(Fragment.rpm),
max_retries=_int(Fragment.max_retries),
max_retry_wait=_float(Fragment.max_retry_wait),
sleep_on_rate_limit_recommendation=_bool(
Fragment.sleep_recommendation, _sleep_recommendation
),
concurrent_requests=_int(
Fragment.concurrent_requests, _concurrent_requests
Fragment.sleep_recommendation
),
concurrent_requests=_int(Fragment.concurrent_requests),
),
)
