
Commit

Fix vector store logic and refactor audience parameter (#1259)
KennyZhang1 authored Oct 21, 2024
1 parent 6aae386 commit e0840a2
Showing 27 changed files with 203 additions and 194 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -1,6 +1,7 @@
# Python Artifacts
python/*/lib/
dist/

# Test Output
.coverage
coverage/
@@ -20,7 +21,6 @@ venv/
.conda
.tmp


.env
build.zip

4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20241008161248831044.json
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "refactor use of vector stores and update support for managed identity"
}
14 changes: 11 additions & 3 deletions docs/config/json_yaml.md
@@ -52,7 +52,7 @@ This is the base LLM configuration section. Other steps may override this config
- `api_version` **str** - The API version
- `organization` **str** - The client organization.
- `proxy` **str** - The proxy URL to use.
- `cognitive_services_endpoint` **str** - The url endpoint for cognitive services.
- `audience` **str** - (Azure OpenAI only) The URI of the target Azure resource/service for which a managed identity token is requested. Used if `api_key` is not defined. Default=`https://cognitiveservices.azure.com/.default`
- `deployment_name` **str** - The deployment name to use (Azure).
- `model_supports_json` **bool** - Whether the model supports JSON-mode output.
- `tokens_per_minute` **int** - Set a leaky-bucket throttle on tokens-per-minute.
@@ -84,9 +84,17 @@ This is the base LLM configuration section. Other steps may override this config
- `parallelization` (see Parallelization top-level config)
- `async_mode` (see Async Mode top-level config)
- `batch_size` **int** - The maximum batch size to use.
- `batch_max_tokens` **int** - The maximum batch #-tokens.
- `batch_max_tokens` **int** - The maximum batch # of tokens.
- `target` **required|all** - Determines which set of embeddings to emit.
- `skip` **list[str]** - Which embeddings to skip.
- `vector_store` **dict** - The vector store to use. Configured for lancedb by default.
- `type` **str** - `lancedb` or `azure_ai_search`. Default=`lancedb`
- `db_uri` **str** (only for lancedb) - The database uri. Default=`storage.base_dir/lancedb`
- `url` **str** (only for AI Search) - AI Search endpoint
- `api_key` **str** (optional - only for AI Search) - The AI Search api key to use.
- `audience` **str** (only for AI Search) - Audience for managed identity token if managed identity authentication is used.
- `overwrite` **bool** (only used at index creation time) - Overwrite collection if it exists. Default=`True`
- `collection_name` **str** - The name of a vector collection. Default=`entity_description_embeddings`
- `strategy` **dict** - Fully override the text-embedding strategy.

## chunks
@@ -214,7 +222,7 @@ This is the base LLM configuration section. Other steps may override this config

## encoding_model

**str** - The text encoding model to use. Default is `cl100k_base`.
**str** - The text encoding model to use. Default=`cl100k_base`.

## skip_workflows

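For reference, the `vector_store` block documented above can be expressed as follows. This is a sketch only: the key names mirror the docs in this diff, while the endpoint, path, and audience values are placeholders.

```python
# Illustrative shapes for `embeddings.vector_store`; values are placeholders.
lancedb_store = {
    "type": "lancedb",                                   # default backend
    "db_uri": "output/lancedb",                          # re-rooted under root_dir at index time
    "collection_name": "entity_description_embeddings",
    "overwrite": True,                                   # honored only at index creation time
}

azure_ai_search_store = {
    "type": "azure_ai_search",
    "url": "https://<search-resource>.search.windows.net",  # AI Search endpoint (placeholder)
    # "api_key": "<key>",                                # omit to use managed identity instead
    "audience": "https://search.azure.com",              # assumed audience for the identity token
    "collection_name": "entity_description_embeddings",
}
```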
6 changes: 3 additions & 3 deletions graphrag/api/__init__.py
@@ -7,9 +7,9 @@
Backwards compatibility is not guaranteed at this time.
"""

from .index_api import build_index
from .prompt_tune_api import DocSelectionType, generate_indexing_prompts
from .query_api import (
from graphrag.api.index import build_index
from graphrag.api.prompt_tune import DocSelectionType, generate_indexing_prompts
from graphrag.api.query import (
global_search,
global_search_streaming,
local_search,
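The rename means downstream code can import either from the package root (unchanged) or from the new module names; a quick sketch of both styles:

```python
# Public entry points are still re-exported from the package root...
from graphrag.api import build_index, local_search, global_search

# ...and the underlying modules are now index/query/prompt_tune rather than
# index_api/query_api/prompt_tune_api.
from graphrag.api.index import build_index
from graphrag.api.query import local_search, global_search
```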
13 changes: 12 additions & 1 deletion graphrag/api/index_api.py → graphrag/api/index.py
@@ -8,13 +8,16 @@
Backwards compatibility is not guaranteed at this time.
"""

from pathlib import Path

from graphrag.config import CacheType, GraphRagConfig
from graphrag.index.cache.noop_pipeline_cache import NoopPipelineCache
from graphrag.index.create_pipeline_config import create_pipeline_config
from graphrag.index.emit.types import TableEmitterType
from graphrag.index.run import run_pipeline_with_config
from graphrag.index.typing import PipelineRunResult
from graphrag.logging import ProgressReporter
from graphrag.vector_stores.factory import VectorStoreType


async def build_index(
@@ -30,7 +33,7 @@ async def build_index(
Parameters
----------
config : PipelineConfig
config : GraphRagConfig
The configuration.
run_id : str
The run id. Creates an output directory with this name.
@@ -55,6 +58,14 @@
msg = "Cannot resume and update a run at the same time."
raise ValueError(msg)

# TODO: must update filepath of lancedb (if used) until the new config engine has been implemented
# TODO: remove the type ignore annotations below once the new config engine has been refactored
vector_store_type = config.embeddings.vector_store["type"] # type: ignore
if vector_store_type == VectorStoreType.LanceDB:
db_uri = config.embeddings.vector_store["db_uri"] # type: ignore
lancedb_dir = Path(config.root_dir).resolve() / db_uri
config.embeddings.vector_store["db_uri"] = str(lancedb_dir) # type: ignore

pipeline_config = create_pipeline_config(config)
pipeline_cache = (
NoopPipelineCache() if config.cache.type == CacheType.none is None else None
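The LanceDB handling added above re-roots a relative `db_uri` under the project root so indexing and querying agree on the database location. A standalone sketch of that logic (the helper name is illustrative, not part of the API):

```python
from pathlib import Path

def _resolve_lancedb_uri(root_dir: str, vector_store: dict) -> dict:
    """Mirror the temporary workaround: make a relative lancedb db_uri absolute."""
    if vector_store.get("type") == "lancedb":
        vector_store["db_uri"] = str(Path(root_dir).resolve() / vector_store["db_uri"])
    return vector_store

# e.g. {"type": "lancedb", "db_uri": "output/lancedb"} with root_dir="./ragtest"
# becomes {..., "db_uri": "/abs/path/to/ragtest/output/lancedb"}
```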
File renamed without changes.
97 changes: 26 additions & 71 deletions graphrag/api/query_api.py → graphrag/api/query.py
@@ -26,7 +26,6 @@

from graphrag.config import GraphRagConfig
from graphrag.logging import PrintProgressReporter
from graphrag.model.entity import Entity
from graphrag.query.factories import get_global_search_engine, get_local_search_engine
from graphrag.query.indexer_adapters import (
read_indexer_covariates,
@@ -35,10 +34,9 @@
read_indexer_reports,
read_indexer_text_units,
)
from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings
from graphrag.query.structured_search.base import SearchResult # noqa: TCH001
from graphrag.vector_stores.lancedb import LanceDBVectorStore
from graphrag.vector_stores.typing import VectorStoreFactory, VectorStoreType
from graphrag.utils.cli import redact
from graphrag.vector_stores import VectorStoreFactory, VectorStoreType

reporter = PrintProgressReporter("")

@@ -184,24 +182,20 @@ async def local_search(
------
TODO: Document any exceptions to expect.
"""
vector_store_args = (
config.embeddings.vector_store if config.embeddings.vector_store else {}
)
reporter.info(f"Vector Store Args: {vector_store_args}")

vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB)

_entities = read_indexer_entities(nodes, entities, community_level)

lancedb_dir = Path(config.storage.base_dir) / "lancedb"

vector_store_args.update({"db_uri": str(lancedb_dir)})
# TODO: must update filepath of lancedb (if used) until the new config engine has been implemented
# TODO: remove the type ignore annotations below once the new config engine has been refactored
vector_store_type = config.embeddings.vector_store.get("type") # type: ignore
vector_store_args = config.embeddings.vector_store
if vector_store_type == "lancedb":
db_uri = config.embeddings.vector_store["db_uri"] # type: ignore
lancedb_dir = Path(config.root_dir).resolve() / db_uri
vector_store_args["db_uri"] = str(lancedb_dir) # type: ignore
reporter.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore
description_embedding_store = _get_embedding_description_store(
entities=_entities,
vector_store_type=vector_store_type,
config_args=vector_store_args,
config_args=vector_store_args, # type: ignore
)

_entities = read_indexer_entities(nodes, entities, community_level)
_covariates = read_indexer_covariates(covariates) if covariates is not None else []

search_engine = get_local_search_engine(
@@ -257,24 +251,20 @@ async def local_search_streaming(
------
TODO: Document any exceptions to expect.
"""
vector_store_args = (
config.embeddings.vector_store if config.embeddings.vector_store else {}
)
reporter.info(f"Vector Store Args: {vector_store_args}")

vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB)

_entities = read_indexer_entities(nodes, entities, community_level)

lancedb_dir = Path(config.storage.base_dir) / "lancedb"

vector_store_args.update({"db_uri": str(lancedb_dir)})
# TODO: must update filepath of lancedb (if used) until the new config engine has been implemented
# TODO: remove the type ignore annotations below once the new config engine has been refactored
vector_store_type = config.embeddings.vector_store["type"] # type: ignore
vector_store_args = config.embeddings.vector_store
if vector_store_type == VectorStoreType.LanceDB:
db_uri = config.embeddings.vector_store["db_uri"] # type: ignore
lancedb_dir = Path(config.root_dir).resolve() / db_uri
vector_store_args["db_uri"] = str(lancedb_dir) # type: ignore
reporter.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore
description_embedding_store = _get_embedding_description_store(
entities=_entities,
vector_store_type=vector_store_type,
config_args=vector_store_args,
config_args=vector_store_args, # type: ignore
)

_entities = read_indexer_entities(nodes, entities, community_level)
_covariates = read_indexer_covariates(covariates) if covariates is not None else []

search_engine = get_local_search_engine(
@@ -303,49 +293,14 @@ async def local_search_streaming(


def _get_embedding_description_store(
entities: list[Entity],
vector_store_type: str = VectorStoreType.LanceDB,
config_args: dict | None = None,
config_args: dict,
):
"""Get the embedding description store."""
if not config_args:
config_args = {}

collection_name = config_args.get(
"query_collection_name", "entity_description_embeddings"
)
config_args.update({"collection_name": collection_name})
vector_store_type = config_args["type"]
description_embedding_store = VectorStoreFactory.get_vector_store(
vector_store_type=vector_store_type, kwargs=config_args
)

description_embedding_store.connect(**config_args)

if config_args.get("overwrite", True):
# this step assumes the embeddings were originally stored in a file rather
# than a vector database

# dump embeddings from the entities list to the description_embedding_store
store_entity_semantic_embeddings(
entities=entities, vectorstore=description_embedding_store
)
else:
# load description embeddings to an in-memory lancedb vectorstore
# and connect to a remote db, specify url and port values.
description_embedding_store = LanceDBVectorStore(
collection_name=collection_name
)
description_embedding_store.connect(
db_uri=config_args.get("db_uri", "./lancedb")
)

# load data from an existing table
description_embedding_store.document_collection = (
description_embedding_store.db_connection.open_table(
description_embedding_store.collection_name
)
)

return description_embedding_store


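After this refactor the query path no longer re-embeds entities or special-cases LanceDB; it resolves a store through the factory and connects with the configured arguments. A minimal sketch of that flow (the wrapper function is illustrative):

```python
from graphrag.vector_stores import VectorStoreFactory

def open_description_store(config_args: dict):
    """Connect to the already-populated entity description embedding store."""
    store = VectorStoreFactory.get_vector_store(
        vector_store_type=config_args["type"], kwargs=config_args
    )
    store.connect(**config_args)
    return store
```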
32 changes: 13 additions & 19 deletions graphrag/config/create_graphrag_config.py
@@ -83,10 +83,7 @@ def hydrate_llm_params(
llm_type = LLMType(llm_type) if llm_type else base.type
api_key = reader.str(Fragment.api_key) or base.api_key
api_base = reader.str(Fragment.api_base) or base.api_base
cognitive_services_endpoint = (
reader.str(Fragment.cognitive_services_endpoint)
or base.cognitive_services_endpoint
)
audience = reader.str(Fragment.audience) or base.audience
deployment_name = (
reader.str(Fragment.deployment_name) or base.deployment_name
)
@@ -119,7 +116,7 @@ def hydrate_llm_params(
or base.model_supports_json,
request_timeout=reader.float(Fragment.request_timeout)
or base.request_timeout,
cognitive_services_endpoint=cognitive_services_endpoint,
audience=audience,
deployment_name=deployment_name,
tokens_per_minute=reader.int("tokens_per_minute", Fragment.tpm)
or base.tokens_per_minute,
@@ -141,7 +138,7 @@ def hydrate_embeddings_params(
api_type = LLMType(api_type) if api_type else defs.LLM_TYPE
api_key = reader.str(Fragment.api_key) or base.api_key

# In a unique events where:
# Account for various permutations of config settings such as:
# - same api_bases for LLM and embeddings (both Azure)
# - different api_bases for LLM and embeddings (both Azure)
# - LLM uses Azure OpenAI, while embeddings uses base OpenAI (this one is important)
@@ -158,10 +155,7 @@
)
api_organization = reader.str("organization") or base.organization
api_proxy = reader.str("proxy") or base.proxy
cognitive_services_endpoint = (
reader.str(Fragment.cognitive_services_endpoint)
or base.cognitive_services_endpoint
)
audience = reader.str(Fragment.audience) or base.audience
deployment_name = reader.str(Fragment.deployment_name)

if api_key is None and not _is_azure(api_type):
@@ -186,7 +180,7 @@ def hydrate_embeddings_params(
model=reader.str(Fragment.model) or defs.EMBEDDING_MODEL,
request_timeout=reader.float(Fragment.request_timeout)
or defs.LLM_REQUEST_TIMEOUT,
cognitive_services_endpoint=cognitive_services_endpoint,
audience=audience,
deployment_name=deployment_name,
tokens_per_minute=reader.int("tokens_per_minute", Fragment.tpm)
or defs.LLM_TOKENS_PER_MINUTE,
@@ -237,9 +231,7 @@ def hydrate_parallelization_params(
api_base = reader.str(Fragment.api_base) or fallback_oai_base
api_version = reader.str(Fragment.api_version) or fallback_oai_version
api_proxy = reader.str(Fragment.api_proxy) or fallback_oai_proxy
cognitive_services_endpoint = reader.str(
Fragment.cognitive_services_endpoint
)
audience = reader.str(Fragment.audience)
deployment_name = reader.str(Fragment.deployment_name)

if api_key is None and not _is_azure(llm_type):
@@ -270,7 +262,7 @@ def hydrate_parallelization_params(
model_supports_json=reader.bool(Fragment.model_supports_json),
request_timeout=reader.float(Fragment.request_timeout)
or defs.LLM_REQUEST_TIMEOUT,
cognitive_services_endpoint=cognitive_services_endpoint,
audience=audience,
deployment_name=deployment_name,
tokens_per_minute=reader.int(Fragment.tpm)
or defs.LLM_TOKENS_PER_MINUTE,
@@ -294,13 +286,15 @@ def hydrate_parallelization_params(
embeddings_config = values.get("embeddings") or {}
with reader.envvar_prefix(Section.embedding), reader.use(embeddings_config):
embeddings_target = reader.str("target")
# TODO: remove the type ignore annotations below once the new config engine has been refactored
embeddings_model = TextEmbeddingConfig(
llm=hydrate_embeddings_params(embeddings_config, llm_model),
llm=hydrate_embeddings_params(embeddings_config, llm_model), # type: ignore
parallelization=hydrate_parallelization_params(
embeddings_config, llm_parallelization_model
embeddings_config, # type: ignore
llm_parallelization_model, # type: ignore
),
vector_store=embeddings_config.get("vector_store", None),
async_mode=hydrate_async_type(embeddings_config, async_mode),
async_mode=hydrate_async_type(embeddings_config, async_mode), # type: ignore
target=(
TextEmbeddingTarget(embeddings_target)
if embeddings_target
@@ -579,8 +573,8 @@ class Fragment(str, Enum):
api_organization = "API_ORGANIZATION"
api_proxy = "API_PROXY"
async_mode = "ASYNC_MODE"
audience = "AUDIENCE"
base_dir = "BASE_DIR"
cognitive_services_endpoint = "COGNITIVE_SERVICES_ENDPOINT"
concurrent_requests = "CONCURRENT_REQUESTS"
conn_string = "CONNECTION_STRING"
container_name = "CONTAINER_NAME"
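With the `cognitive_services_endpoint` fragment replaced by `audience`, the value can also be supplied per section through environment variables. The variable names below assume the usual `GRAPHRAG_<SECTION>_<FRAGMENT>` convention used by the env readers and are shown for illustration only:

```python
import os

# Assumed variable names (GRAPHRAG_<SECTION>_AUDIENCE); the value matches the
# documented default audience for Azure OpenAI managed identity tokens.
os.environ["GRAPHRAG_LLM_AUDIENCE"] = "https://cognitiveservices.azure.com/.default"
os.environ["GRAPHRAG_EMBEDDING_AUDIENCE"] = "https://cognitiveservices.azure.com/.default"
```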
13 changes: 12 additions & 1 deletion graphrag/config/defaults.py
@@ -3,8 +3,12 @@

"""Common default configuration values."""

from pathlib import Path

from datashaper import AsyncType

from graphrag.vector_stores import VectorStoreType

from .enums import (
CacheType,
InputFileType,
@@ -74,7 +78,7 @@
NODE2VEC_ITERATIONS = 3
NODE2VEC_RANDOM_SEED = 597832
REPORTING_TYPE = ReportingType.file
REPORTING_BASE_DIR = "output"
REPORTING_BASE_DIR = "logs"
SNAPSHOTS_GRAPHML = False
SNAPSHOTS_RAW_ENTITIES = False
SNAPSHOTS_TOP_LEVEL_NODES = False
@@ -83,6 +87,13 @@
SUMMARIZE_DESCRIPTIONS_MAX_LENGTH = 500
UMAP_ENABLED = False

VECTOR_STORE = f"""
type: {VectorStoreType.LanceDB.value}
db_uri: '{(Path(STORAGE_BASE_DIR) / "lancedb")!s}'
collection_name: entity_description_embeddings
overwrite: true\
"""

# Local Search
LOCAL_SEARCH_TEXT_UNIT_PROP = 0.5
LOCAL_SEARCH_COMMUNITY_PROP = 0.1
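Assuming the existing `STORAGE_BASE_DIR` default of `output` is unchanged by this commit, the new `VECTOR_STORE` constant renders to the YAML fragment applied when no `vector_store` block is configured; a quick way to inspect it:

```python
from graphrag.config import defaults

print(defaults.VECTOR_STORE)
# Expected output (sketch):
#   type: lancedb
#   db_uri: 'output/lancedb'
#   collection_name: entity_description_embeddings
#   overwrite: true
```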