Support ollama #1326 (Open)

wants to merge 8 commits into base: main

Changes from all commits
3 changes: 3 additions & 0 deletions graphrag/config/enums.py
@@ -98,14 +98,17 @@ class LLMType(str, Enum):
# Embeddings
OpenAIEmbedding = "openai_embedding"
AzureOpenAIEmbedding = "azure_openai_embedding"
OllamaEmbedding = "ollama_embedding"

# Raw Completion
OpenAI = "openai"
AzureOpenAI = "azure_openai"
Ollama = "ollama"

# Chat Completion
OpenAIChat = "openai_chat"
AzureOpenAIChat = "azure_openai_chat"
OllamaChat = "ollama_chat"

# Debug
StaticResponse = "static_response"
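
Because LLMType subclasses both str and Enum, the new members resolve directly from the string values used in configuration files. A quick illustration (usage only, not part of the diff):

```python
from graphrag.config.enums import LLMType

# The new enum members round-trip from their configuration string values.
assert LLMType("ollama") is LLMType.Ollama
assert LLMType("ollama_chat") is LLMType.OllamaChat
assert LLMType("ollama_embedding") is LLMType.OllamaEmbedding
assert LLMType.OllamaChat.value == "ollama_chat"
```
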
109 changes: 106 additions & 3 deletions graphrag/index/llm/load_llm.py
@@ -22,6 +22,12 @@
create_openai_completion_llm,
create_openai_embedding_llm,
create_tpm_rpm_limiters,
OllamaConfiguration,
create_ollama_client,
create_ollama_chat_llm,
create_ollama_embedding_llm,
create_ollama_completion_llm,
LLMConfig,
)

if TYPE_CHECKING:
@@ -46,7 +52,6 @@ def load_llm(
) -> CompletionLLM:
"""Load the LLM for the entity extraction chain."""
on_error = _create_error_handler(callbacks)

if llm_type in loaders:
if chat_only and not loaders[llm_type]["chat"]:
msg = f"LLM type {llm_type} does not support chat"
@@ -182,6 +187,50 @@ def _load_azure_openai_embeddings_llm(
return _load_openai_embeddings_llm(on_error, cache, config, True)


def _load_ollama_completion_llm(
on_error: ErrorHandlerFn,
cache: LLMCache,
config: dict[str, Any],
):
return _create_ollama_completion_llm(
OllamaConfiguration({
**_get_base_config(config),
}),
on_error,
cache,
)


def _load_ollama_chat_llm(
on_error: ErrorHandlerFn,
cache: LLMCache,
config: dict[str, Any],
):
return _create_ollama_chat_llm(
OllamaConfiguration({
# Set default values
**_get_base_config(config),
}),
on_error,
cache,
)


def _load_ollama_embeddings_llm(
on_error: ErrorHandlerFn,
cache: LLMCache,
config: dict[str, Any],
):
# TODO: Inject Cache
return _create_ollama_embeddings_llm(
OllamaConfiguration({
**_get_base_config(config),
}),
on_error,
cache,
)


def _get_base_config(config: dict[str, Any]) -> dict[str, Any]:
api_key = config.get("api_key")

@@ -218,6 +267,10 @@ def _load_static_response(
"load": _load_azure_openai_completion_llm,
"chat": False,
},
LLMType.Ollama: {
"load": _load_ollama_completion_llm,
"chat": False,
},
LLMType.OpenAIChat: {
"load": _load_openai_chat_llm,
"chat": True,
@@ -226,6 +279,10 @@ def _load_static_response(
"load": _load_azure_openai_chat_llm,
"chat": True,
},
LLMType.OllamaChat: {
"load": _load_ollama_chat_llm,
"chat": True,
},
LLMType.OpenAIEmbedding: {
"load": _load_openai_embeddings_llm,
"chat": False,
@@ -234,6 +291,10 @@ def _load_static_response(
"load": _load_azure_openai_embeddings_llm,
"chat": False,
},
LLMType.OllamaEmbedding: {
"load": _load_ollama_embeddings_llm,
"chat": False,
},
LLMType.StaticResponse: {
"load": _load_static_response,
"chat": False,
@@ -286,7 +347,49 @@ def _create_openai_embeddings_llm(
)


def _create_limiter(configuration: OpenAIConfiguration) -> LLMLimiter:
def _create_ollama_chat_llm(
configuration: OllamaConfiguration,
on_error: ErrorHandlerFn,
cache: LLMCache,
) -> CompletionLLM:
"""Create an Ollama chat llm."""
client = create_ollama_client(configuration=configuration)
limiter = _create_limiter(configuration)
semaphore = _create_semaphore(configuration)
return create_ollama_chat_llm(
client, configuration, cache, limiter, semaphore, on_error=on_error
)


def _create_ollama_completion_llm(
configuration: OllamaConfiguration,
on_error: ErrorHandlerFn,
cache: LLMCache,
) -> CompletionLLM:
"""Create an Ollama completion llm."""
client = create_ollama_client(configuration=configuration)
limiter = _create_limiter(configuration)
semaphore = _create_semaphore(configuration)
return create_ollama_completion_llm(
client, configuration, cache, limiter, semaphore, on_error=on_error
)


def _create_ollama_embeddings_llm(
configuration: OllamaConfiguration,
on_error: ErrorHandlerFn,
cache: LLMCache,
) -> EmbeddingLLM:
"""Create an Ollama embeddings llm."""
client = create_ollama_client(configuration=configuration)
limiter = _create_limiter(configuration)
semaphore = _create_semaphore(configuration)
return create_ollama_embedding_llm(
client, configuration, cache, limiter, semaphore, on_error=on_error
)


def _create_limiter(configuration: LLMConfig) -> LLMLimiter:
limit_name = configuration.model or configuration.deployment_name or "default"
if limit_name not in _rate_limiters:
tpm = configuration.tokens_per_minute
@@ -296,7 +399,7 @@ def _create_limiter(configuration: OpenAIConfiguration) -> LLMLimiter:
return _rate_limiters[limit_name]


def _create_semaphore(configuration: OpenAIConfiguration) -> asyncio.Semaphore | None:
def _create_semaphore(configuration: LLMConfig) -> asyncio.Semaphore | None:
limit_name = configuration.model or configuration.deployment_name or "default"
concurrency = configuration.concurrent_requests
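
The three new loaders are registered in the same loaders mapping that already dispatches the OpenAI variants: each LLMType key points at a loader callable plus a chat capability flag that load_llm checks before use. A simplified, self-contained sketch of that dispatch pattern (the registry entries and loader bodies below are illustrative stand-ins, not the real ones):

```python
from typing import Any

from graphrag.config.enums import LLMType

# Illustrative stand-ins for the real loader functions in load_llm.py.
_loaders: dict[LLMType, dict[str, Any]] = {
    LLMType.Ollama: {"load": lambda cfg: f"completion<{cfg['model']}>", "chat": False},
    LLMType.OllamaChat: {"load": lambda cfg: f"chat<{cfg['model']}>", "chat": True},
    LLMType.OllamaEmbedding: {"load": lambda cfg: f"embedding<{cfg['model']}>", "chat": False},
}


def resolve(llm_type: LLMType, config: dict[str, Any], chat_only: bool = False) -> str:
    """Mirror load_llm's dispatch: unknown types and non-chat loaders are rejected."""
    if llm_type not in _loaders:
        msg = f"Unknown LLM type {llm_type}"
        raise ValueError(msg)
    entry = _loaders[llm_type]
    if chat_only and not entry["chat"]:
        msg = f"LLM type {llm_type} does not support chat"
        raise ValueError(msg)
    return entry["load"](config)


print(resolve(LLMType.OllamaChat, {"model": "llama3.1"}, chat_only=True))  # chat<llama3.1>
```
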

21 changes: 21 additions & 0 deletions graphrag/llm/__init__.py
@@ -42,6 +42,17 @@
LLMOutput,
OnCacheActionFn,
)
from .ollama import (
OllamaChatLLM,
OllamaClientType,
OllamaCompletionLLM,
OllamaConfiguration,
OllamaEmbeddingsLLM,
create_ollama_chat_llm,
create_ollama_client,
create_ollama_completion_llm,
create_ollama_embedding_llm,
)

__all__ = [
# LLM Types
@@ -79,13 +90,23 @@
"OpenAIConfiguration",
"OpenAIEmbeddingsLLM",
"RateLimitingLLM",
# Ollama
"OllamaChatLLM",
"OllamaClientType",
"OllamaCompletionLLM",
"OllamaConfiguration",
"OllamaEmbeddingsLLM",
# Errors
"RetriesExhaustedError",
"TpmRpmLLMLimiter",
"create_openai_chat_llm",
"create_openai_client",
"create_openai_completion_llm",
"create_openai_embedding_llm",
"create_ollama_chat_llm",
"create_ollama_client",
"create_ollama_completion_llm",
"create_ollama_embedding_llm",
# Limiters
"create_tpm_rpm_limiters",
]
29 changes: 29 additions & 0 deletions graphrag/llm/ollama/__init__.py
@@ -0,0 +1,29 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Ollama LLM implementations."""

from .create_ollama_client import create_ollama_client
from .factories import (
create_ollama_chat_llm,
create_ollama_completion_llm,
create_ollama_embedding_llm,
)
from .ollama_chat_llm import OllamaChatLLM
from .ollama_completion_llm import OllamaCompletionLLM
from .ollama_configuration import OllamaConfiguration
from .ollama_embeddings_llm import OllamaEmbeddingsLLM
from .types import OllamaClientType


__all__ = [
"OllamaChatLLM",
"OllamaClientType",
"OllamaCompletionLLM",
"OllamaConfiguration",
"OllamaEmbeddingsLLM",
"create_ollama_chat_llm",
"create_ollama_client",
"create_ollama_completion_llm",
"create_ollama_embedding_llm",
]
36 changes: 36 additions & 0 deletions graphrag/llm/ollama/create_ollama_client.py
@@ -0,0 +1,36 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Create OpenAI client instance."""

import logging
from functools import cache

from ollama import AsyncClient, Client

from .ollama_configuration import OllamaConfiguration
from .types import OllamaClientType

log = logging.getLogger(__name__)

API_BASE_REQUIRED_FOR_AZURE = "api_base is required for Azure OpenAI client"


@cache
def create_ollama_client(
configuration: OllamaConfiguration,
sync: bool = False,
) -> OllamaClientType:
"""Create a new Ollama client instance."""

log.info("Creating OpenAI client base_url=%s", configuration.api_base)

Review comment: Fyi, "Creating OpenAI client base_url=%s" should be "Creating Ollama client base_url=%s"

if sync:
return Client(
host=configuration.api_base,
timeout=configuration.request_timeout or 180.0,
)
return AsyncClient(
host=configuration.api_base,
        # Use Tenacity for retries; only the timeout is configured here.
timeout=configuration.request_timeout or 180.0,
)
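
Because create_ollama_client is wrapped in functools.cache, the configuration object has to be hashable and each (configuration, sync) pair gets its own cached client. A hedged usage sketch, assuming a local Ollama server on the default port (the host, model, and timeout values are placeholders):

```python
from graphrag.llm import OllamaConfiguration, create_ollama_client

config = OllamaConfiguration({
    "api_base": "http://localhost:11434",  # default Ollama endpoint
    "model": "llama3.1",
    "request_timeout": 120.0,
})

async_client = create_ollama_client(configuration=config)             # ollama.AsyncClient
sync_client = create_ollama_client(configuration=config, sync=True)   # ollama.Client

# The functools.cache decorator returns the same instance for identical arguments.
assert create_ollama_client(configuration=config) is async_client
```
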
139 changes: 139 additions & 0 deletions graphrag/llm/ollama/factories.py
@@ -0,0 +1,139 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Factory functions for creating OpenAI LLMs."""

import asyncio

from graphrag.llm.base import CachingLLM, RateLimitingLLM
from graphrag.llm.limiting import LLMLimiter
from graphrag.llm.types import (
LLM,
CompletionLLM,
EmbeddingLLM,
ErrorHandlerFn,
LLMCache,
LLMInvocationFn,
OnCacheActionFn,
)
from graphrag.llm.utils import (
RATE_LIMIT_ERRORS,
RETRYABLE_ERRORS,
get_sleep_time_from_error,
get_token_counter,
)
from graphrag.llm.openai.openai_history_tracking_llm import OpenAIHistoryTrackingLLM
from graphrag.llm.openai.openai_token_replacing_llm import OpenAITokenReplacingLLM

from .json_parsing_llm import JsonParsingLLM
from .ollama_chat_llm import OllamaChatLLM
from .ollama_completion_llm import OllamaCompletionLLM
from .ollama_configuration import OllamaConfiguration
from .ollama_embeddings_llm import OllamaEmbeddingsLLM
from .types import OllamaClientType


def create_ollama_chat_llm(
client: OllamaClientType,
config: OllamaConfiguration,
cache: LLMCache | None = None,
limiter: LLMLimiter | None = None,
semaphore: asyncio.Semaphore | None = None,
on_invoke: LLMInvocationFn | None = None,
on_error: ErrorHandlerFn | None = None,
on_cache_hit: OnCacheActionFn | None = None,
on_cache_miss: OnCacheActionFn | None = None,
) -> CompletionLLM:
"""Create an OpenAI chat LLM."""
operation = "chat"
result = OllamaChatLLM(client, config)
result.on_error(on_error)
if limiter is not None or semaphore is not None:
result = _rate_limited(result, config, operation, limiter, semaphore, on_invoke)
if cache is not None:
result = _cached(result, config, operation, cache, on_cache_hit, on_cache_miss)
result = OpenAIHistoryTrackingLLM(result)
result = OpenAITokenReplacingLLM(result)
return JsonParsingLLM(result)


def create_ollama_completion_llm(
client: OllamaClientType,
config: OllamaConfiguration,
cache: LLMCache | None = None,
limiter: LLMLimiter | None = None,
semaphore: asyncio.Semaphore | None = None,
on_invoke: LLMInvocationFn | None = None,
on_error: ErrorHandlerFn | None = None,
on_cache_hit: OnCacheActionFn | None = None,
on_cache_miss: OnCacheActionFn | None = None,
) -> CompletionLLM:
"""Create an OpenAI completion LLM."""
operation = "completion"
result = OllamaCompletionLLM(client, config)
result.on_error(on_error)
if limiter is not None or semaphore is not None:
result = _rate_limited(result, config, operation, limiter, semaphore, on_invoke)
if cache is not None:
result = _cached(result, config, operation, cache, on_cache_hit, on_cache_miss)
return OpenAITokenReplacingLLM(result)


def create_ollama_embedding_llm(
client: OllamaClientType,
config: OllamaConfiguration,
cache: LLMCache | None = None,
limiter: LLMLimiter | None = None,
semaphore: asyncio.Semaphore | None = None,
on_invoke: LLMInvocationFn | None = None,
on_error: ErrorHandlerFn | None = None,
on_cache_hit: OnCacheActionFn | None = None,
on_cache_miss: OnCacheActionFn | None = None,
) -> EmbeddingLLM:
"""Create an OpenAI embeddings LLM."""
operation = "embedding"
result = OllamaEmbeddingsLLM(client, config)
result.on_error(on_error)
if limiter is not None or semaphore is not None:
result = _rate_limited(result, config, operation, limiter, semaphore, on_invoke)
if cache is not None:
result = _cached(result, config, operation, cache, on_cache_hit, on_cache_miss)
return result


def _rate_limited(
delegate: LLM,
config: OllamaConfiguration,
operation: str,
limiter: LLMLimiter | None,
semaphore: asyncio.Semaphore | None,
on_invoke: LLMInvocationFn | None,
):
result = RateLimitingLLM(
delegate,
config,
operation,
RETRYABLE_ERRORS,
RATE_LIMIT_ERRORS,
limiter,
semaphore,
get_token_counter(config),
get_sleep_time_from_error,
)
result.on_invoke(on_invoke)
return result


def _cached(
delegate: LLM,
config: OllamaConfiguration,
operation: str,
cache: LLMCache,
on_cache_hit: OnCacheActionFn | None,
on_cache_miss: OnCacheActionFn | None,
):
cache_args = config.get_completion_cache_args()
result = CachingLLM(delegate, cache_args, operation, cache)
result.on_cache_hit(on_cache_hit)
result.on_cache_miss(on_cache_miss)
return result
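
The Ollama factories compose the same wrapper stack as their OpenAI counterparts: the base LLM is optionally rate-limited and cached, then (for chat) wrapped with the shared history-tracking, token-replacing, and JSON-parsing layers. A hedged end-to-end sketch of building and calling a chat LLM without the optional cache, limiter, and semaphore; it assumes a reachable local Ollama server, and the host, model, and prompt are placeholders:

```python
import asyncio

from graphrag.llm import (
    OllamaConfiguration,
    create_ollama_chat_llm,
    create_ollama_client,
)


async def main() -> None:
    config = OllamaConfiguration({"api_base": "http://localhost:11434", "model": "llama3.1"})
    client = create_ollama_client(configuration=config)

    # With no cache/limiter/semaphore, only the history/token/JSON wrappers are applied.
    llm = create_ollama_chat_llm(client, config)

    result = await llm("Summarize what GraphRAG does in one sentence.")
    print(result.output)


asyncio.run(main())
```
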
38 changes: 38 additions & 0 deletions graphrag/llm/ollama/json_parsing_llm.py
@@ -0,0 +1,38 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""An LLM that unpacks cached JSON responses."""

from typing_extensions import Unpack

from graphrag.llm.types import (
LLM,
CompletionInput,
CompletionLLM,
CompletionOutput,
LLMInput,
LLMOutput,
)

from graphrag.llm.utils import try_parse_json_object


class JsonParsingLLM(LLM[CompletionInput, CompletionOutput]):
"""An OpenAI History-Tracking LLM."""

_delegate: CompletionLLM

def __init__(self, delegate: CompletionLLM):
self._delegate = delegate

async def __call__(
self,
input: CompletionInput,
**kwargs: Unpack[LLMInput],
) -> LLMOutput[CompletionOutput]:
"""Call the LLM with the input and kwargs."""
result = await self._delegate(input, **kwargs)
if kwargs.get("json") and result.json is None and result.output is not None:
_, parsed_json = try_parse_json_object(result.output)
result.json = parsed_json
return result
90 changes: 90 additions & 0 deletions graphrag/llm/ollama/ollama_chat_llm.py
@@ -0,0 +1,90 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""The Chat-based language model."""

import logging

from typing_extensions import Unpack

from graphrag.llm.base import BaseLLM
from graphrag.llm.types import (
CompletionInput,
CompletionOutput,
LLMInput,
LLMOutput,
)
from graphrag.llm.utils import try_parse_json_object

from .ollama_configuration import OllamaConfiguration
from .types import OllamaClientType

log = logging.getLogger(__name__)

_MAX_GENERATION_RETRIES = 3
FAILED_TO_CREATE_JSON_ERROR = "Failed to generate valid JSON output"


class OllamaChatLLM(BaseLLM[CompletionInput, CompletionOutput]):
"""A Chat-based LLM."""

_client: OllamaClientType
_configuration: OllamaConfiguration

def __init__(self, client: OllamaClientType, configuration: OllamaConfiguration):
self.client = client
self.configuration = configuration

async def _execute_llm(
self, input: CompletionInput, **kwargs: Unpack[LLMInput]
) -> CompletionOutput | None:
args = {
**self.configuration.get_chat_cache_args(),
}
history = kwargs.get("history") or []
messages = [
*history,
{"role": "user", "content": input},
]
completion = await self.client.chat(
messages=messages, **args
)
return completion["message"]["content"]

async def _invoke_json(
self,
input: CompletionInput,
**kwargs: Unpack[LLMInput],
) -> LLMOutput[CompletionOutput]:
"""Generate JSON output."""
name = kwargs.get("name") or "unknown"
is_response_valid = kwargs.get("is_response_valid") or (lambda _x: True)

async def generate(
attempt: int | None = None,
) -> LLMOutput[CompletionOutput]:
call_name = name if attempt is None else f"{name}@{attempt}"
result = await self._invoke(input, **{**kwargs, "name": call_name})
print("output:\n", result)
output, json_output = try_parse_json_object(result.output or "")

return LLMOutput[CompletionOutput](
output=output,
json=json_output,
history=result.history,
)

def is_valid(x: dict | None) -> bool:
return x is not None and is_response_valid(x)

result = await generate()
retry = 0
while not is_valid(result.json) and retry < _MAX_GENERATION_RETRIES:
result = await generate(retry)
retry += 1

if is_valid(result.json):
return result

error_msg = f"{FAILED_TO_CREATE_JSON_ERROR} - Faulty JSON: {result.json!s}"
raise RuntimeError(error_msg)
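
When JSON output is requested, _invoke_json simply re-invokes the model (up to _MAX_GENERATION_RETRIES additional attempts) until try_parse_json_object yields something the caller's is_response_valid accepts. A hedged sketch of how calling code would exercise that path; the llm argument, prompt, and validator here are illustrative:

```python
# Hedged sketch: `llm` is assumed to be a CompletionLLM returned by create_ollama_chat_llm.
async def extract_entities(llm, text: str) -> list:
    result = await llm(
        f"List the named entities in the text as JSON with an `entities` array:\n{text}",
        json=True,  # routes the call through _invoke_json
        name="entity_extraction",  # retries are labelled entity_extraction@0, @1, ...
        is_response_valid=lambda parsed: "entities" in parsed,
    )
    return result.json["entities"] if result.json else []
```
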
44 changes: 44 additions & 0 deletions graphrag/llm/ollama/ollama_completion_llm.py
@@ -0,0 +1,44 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A text-completion based LLM."""

import logging

from typing_extensions import Unpack

from graphrag.llm.base import BaseLLM
from graphrag.llm.types import (
CompletionInput,
CompletionOutput,
LLMInput,
)
from graphrag.llm.utils import get_completion_llm_args

from .ollama_configuration import OllamaConfiguration
from .types import OllamaClientType


log = logging.getLogger(__name__)


class OllamaCompletionLLM(BaseLLM[CompletionInput, CompletionOutput]):
"""A text-completion based LLM."""

_client: OllamaClientType
_configuration: OllamaConfiguration

def __init__(self, client: OllamaClientType, configuration: OllamaConfiguration):
self.client = client
self.configuration = configuration

async def _execute_llm(
self,
input: CompletionInput,
**kwargs: Unpack[LLMInput],
) -> CompletionOutput | None:
args = get_completion_llm_args(
kwargs.get("model_parameters"), self.configuration
)
completion = await self.client.generate(prompt=input, **args)
return completion["response"]
516 changes: 516 additions & 0 deletions graphrag/llm/ollama/ollama_configuration.py

Large diffs are not rendered by default.
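
The 516-line ollama_configuration.py is not rendered here, but the rest of the changeset pins down the surface it has to expose. The sketch below is inferred purely from call sites in this PR (load_llm.py, factories.py, chat_ollama.py, embeding.py); the property and method names come from usage, while everything else is an assumption:

```python
# Inferred interface, not the PR's actual implementation.
class OllamaConfigurationSurface:
    """What OllamaConfiguration must provide, judging by its callers in this PR."""

    # create_ollama_client
    api_base: str | None
    request_timeout: float | None

    # _create_limiter / _create_semaphore / get_token_counter (via LLMConfig)
    model: str | None
    deployment_name: str | None
    tokens_per_minute: int | None
    requests_per_minute: int | None
    concurrent_requests: int | None
    encoding_model: str | None

    # ChatOllama retry loop
    max_retries: int

    def get_completion_cache_args(self) -> dict: ...  # required by the LLMConfig protocol
    def get_chat_cache_args(self) -> dict: ...        # used by OllamaChatLLM and ChatOllama
    def get_embed_cache_args(self) -> dict: ...       # used by OllamaEmbedding
```
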

40 changes: 40 additions & 0 deletions graphrag/llm/ollama/ollama_embeddings_llm.py
@@ -0,0 +1,40 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""The EmbeddingsLLM class."""

from typing_extensions import Unpack

from graphrag.llm.base import BaseLLM
from graphrag.llm.types import (
EmbeddingInput,
EmbeddingOutput,
LLMInput,
)

from .ollama_configuration import OllamaConfiguration
from .types import OllamaClientType


class OllamaEmbeddingsLLM(BaseLLM[EmbeddingInput, EmbeddingOutput]):
"""A text-embedding generator LLM."""

_client: OllamaClientType
_configuration: OllamaConfiguration

def __init__(self, client: OllamaClientType, configuration: OllamaConfiguration):
self.client = client
self.configuration = configuration

async def _execute_llm(
self, input: EmbeddingInput, **kwargs: Unpack[LLMInput]
) -> EmbeddingOutput | None:
args = {
"model": self.configuration.model,
**(kwargs.get("model_parameters") or {}),
}
embedding = await self.client.embed(
input=input,
**args,
)
return embedding["embeddings"]
8 changes: 8 additions & 0 deletions graphrag/llm/ollama/types.py
@@ -0,0 +1,8 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A base class for OpenAI-based LLMs."""

from ollama import AsyncClient, Client

OllamaClientType = AsyncClient | Client
15 changes: 7 additions & 8 deletions graphrag/llm/openai/factories.py
@@ -16,6 +16,12 @@
LLMInvocationFn,
OnCacheActionFn,
)
from graphrag.llm.utils import (
RATE_LIMIT_ERRORS,
RETRYABLE_ERRORS,
get_sleep_time_from_error,
get_token_counter,
)

from .json_parsing_llm import JsonParsingLLM
from .openai_chat_llm import OpenAIChatLLM
@@ -25,13 +31,6 @@
from .openai_history_tracking_llm import OpenAIHistoryTrackingLLM
from .openai_token_replacing_llm import OpenAITokenReplacingLLM
from .types import OpenAIClientTypes
from .utils import (
RATE_LIMIT_ERRORS,
RETRYABLE_ERRORS,
get_completion_cache_args,
get_sleep_time_from_error,
get_token_counter,
)


def create_openai_chat_llm(
@@ -133,7 +132,7 @@ def _cached(
on_cache_hit: OnCacheActionFn | None,
on_cache_miss: OnCacheActionFn | None,
):
cache_args = get_completion_cache_args(config)
cache_args = config.get_completion_cache_args()
result = CachingLLM(delegate, cache_args, operation, cache)
result.on_cache_hit(on_cache_hit)
result.on_cache_miss(on_cache_miss)
2 changes: 1 addition & 1 deletion graphrag/llm/openai/json_parsing_llm.py
@@ -14,7 +14,7 @@
LLMOutput,
)

from .utils import try_parse_json_object
from graphrag.llm.utils import try_parse_json_object


class JsonParsingLLM(LLM[CompletionInput, CompletionOutput]):
8 changes: 4 additions & 4 deletions graphrag/llm/openai/openai_chat_llm.py
@@ -14,14 +14,14 @@
LLMInput,
LLMOutput,
)
from graphrag.llm.utils import (
get_completion_llm_args,
try_parse_json_object,
)

from ._prompts import JSON_CHECK_PROMPT
from .openai_configuration import OpenAIConfiguration
from .types import OpenAIClientTypes
from .utils import (
get_completion_llm_args,
try_parse_json_object,
)

log = logging.getLogger(__name__)

2 changes: 1 addition & 1 deletion graphrag/llm/openai/openai_completion_llm.py
@@ -13,10 +13,10 @@
CompletionOutput,
LLMInput,
)
from graphrag.llm.utils import get_completion_llm_args

from .openai_configuration import OpenAIConfiguration
from .types import OpenAIClientTypes
from .utils import get_completion_llm_args

log = logging.getLogger(__name__)

36 changes: 21 additions & 15 deletions graphrag/llm/openai/openai_configuration.py
@@ -8,13 +8,7 @@
from typing import Any, cast

from graphrag.llm.types import LLMConfig


def _non_blank(value: str | None) -> str | None:
if value is None:
return None
stripped = value.strip()
return None if stripped == "" else value
from graphrag.llm.utils import non_blank


class OpenAIConfiguration(Hashable, LLMConfig):
@@ -141,34 +135,34 @@ def model(self) -> str:
@property
def deployment_name(self) -> str | None:
"""Deployment name property definition."""
return _non_blank(self._deployment_name)
return non_blank(self._deployment_name)

@property
def api_base(self) -> str | None:
"""API base property definition."""
result = _non_blank(self._api_base)
result = non_blank(self._api_base)
# Remove trailing slash
return result[:-1] if result and result.endswith("/") else result

@property
def api_version(self) -> str | None:
"""API version property definition."""
return _non_blank(self._api_version)
return non_blank(self._api_version)

@property
def audience(self) -> str | None:
"""API version property definition."""
return _non_blank(self._audience)
return non_blank(self._audience)

@property
def organization(self) -> str | None:
"""Organization property definition."""
return _non_blank(self._organization)
return non_blank(self._organization)

@property
def proxy(self) -> str | None:
"""Proxy property definition."""
return _non_blank(self._proxy)
return non_blank(self._proxy)

@property
def n(self) -> int | None:
@@ -203,7 +197,7 @@ def max_tokens(self) -> int | None:
@property
def response_format(self) -> str | None:
"""Response format property definition."""
return _non_blank(self._response_format)
return non_blank(self._response_format)

@property
def logit_bias(self) -> dict[str, float] | None:
@@ -253,7 +247,7 @@ def concurrent_requests(self) -> int | None:
@property
def encoding_model(self) -> str | None:
"""Encoding model property definition."""
return _non_blank(self._encoding_model)
return non_blank(self._encoding_model)

@property
def sleep_on_rate_limit_recommendation(self) -> bool | None:
@@ -269,6 +263,18 @@ def lookup(self, name: str, default_value: Any = None) -> Any:
"""Lookup method definition."""
return self._raw_config.get(name, default_value)

def get_completion_cache_args(self) -> dict:
"""Get the cache arguments for a completion LLM."""
return {
"model": self.model,
"temperature": self.temperature,
"frequency_penalty": self.frequency_penalty,
"presence_penalty": self.presence_penalty,
"top_p": self.top_p,
"max_tokens": self.max_tokens,
"n": self.n,
}

def __str__(self) -> str:
"""Str method definition."""
return json.dumps(self.raw_config, indent=4)
2 changes: 1 addition & 1 deletion graphrag/llm/openai/openai_token_replacing_llm.py
@@ -14,7 +14,7 @@
LLMOutput,
)

from .utils import perform_variable_replacements
from graphrag.llm.utils import perform_variable_replacements


class OpenAITokenReplacingLLM(LLM[CompletionInput, CompletionOutput]):
4 changes: 4 additions & 0 deletions graphrag/llm/types/llm_config.py
@@ -33,3 +33,7 @@ def tokens_per_minute(self) -> int | None:
def requests_per_minute(self) -> int | None:
"""Get the number of requests per minute."""
...

def get_completion_cache_args(self) -> dict:
"""Get the cache arguments for a completion LLM."""
...
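
Adding get_completion_cache_args to the LLMConfig protocol is what lets the shared helpers in graphrag/llm/utils.py stay backend-agnostic: get_completion_llm_args now asks the configuration for its own cache arguments instead of hard-coding the OpenAI fields. A minimal hedged sketch of a configuration object exposing the new method alongside the rate-limit properties visible in this hunk (the field names and values are illustrative):

```python
# Hedged sketch: a toy configuration exposing the new protocol method.
class ToyOllamaConfig:
    model = "llama3.1"
    encoding_model: str | None = None
    tokens_per_minute: int | None = None
    requests_per_minute: int | None = None

    def get_completion_cache_args(self) -> dict:
        # Only parameters that change the generated output belong in the cache key.
        return {"model": self.model, "temperature": 0.0}
```
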
35 changes: 18 additions & 17 deletions graphrag/llm/openai/utils.py → graphrag/llm/utils.py
@@ -17,7 +17,7 @@
RateLimitError,
)

from .openai_configuration import OpenAIConfiguration
from .types import LLMConfig

DEFAULT_ENCODING = "cl100k_base"

@@ -33,7 +33,7 @@
log = logging.getLogger(__name__)


def get_token_counter(config: OpenAIConfiguration) -> Callable[[str], int]:
def get_token_counter(config: LLMConfig) -> Callable[[str], int]:
"""Get a function that counts the number of tokens in a string."""
model = config.encoding_model or "cl100k_base"
enc = _encoders.get(model)
@@ -66,25 +66,12 @@ def replace_all(input: str) -> str:
return result


def get_completion_cache_args(configuration: OpenAIConfiguration) -> dict:
"""Get the cache arguments for a completion LLM."""
return {
"model": configuration.model,
"temperature": configuration.temperature,
"frequency_penalty": configuration.frequency_penalty,
"presence_penalty": configuration.presence_penalty,
"top_p": configuration.top_p,
"max_tokens": configuration.max_tokens,
"n": configuration.n,
}


def get_completion_llm_args(
parameters: dict | None, configuration: OpenAIConfiguration
parameters: dict | None, configuration: LLMConfig
) -> dict:
"""Get the arguments for a completion LLM."""
return {
**get_completion_cache_args(configuration),
**configuration.get_completion_cache_args(),
**(parameters or {}),
}

@@ -158,3 +145,17 @@ def get_sleep_time_from_error(e: Any) -> float:


_please_retry_after = "Please retry after "


def non_blank(value: str | None) -> str | None:
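    """Return None when the value is None or blank; otherwise return it unchanged."""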
if value is None:
return None
stripped = value.strip()
return None if stripped == "" else value


def non_none_value_key(data: dict | None) -> dict:
"""Remove key from dict where value is None"""
if data is None:
return {}
return {k: v for k, v in data.items() if v is not None}
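
The two helpers now living at the bottom of graphrag/llm/utils.py are small; their behavior is easiest to show directly (usage only, not part of the diff):

```python
from graphrag.llm.utils import non_blank, non_none_value_key

assert non_blank(None) is None
assert non_blank("   ") is None            # whitespace-only collapses to None
assert non_blank(" gpt-4 ") == " gpt-4 "   # non-blank input is returned as-is, not stripped

assert non_none_value_key({"model": "llama3.1", "format": None}) == {"model": "llama3.1"}
assert non_none_value_key(None) == {}
```
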
12 changes: 10 additions & 2 deletions graphrag/query/factories.py
@@ -10,6 +10,7 @@
GraphRagConfig,
LLMType,
)
from graphrag.llm import OllamaConfiguration
from graphrag.model import (
CommunityReport,
Covariate,
@@ -18,9 +19,12 @@
TextUnit,
)
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.llm.base import BaseLLM, BaseTextEmbedding
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
from graphrag.query.llm.oai.embedding import OpenAIEmbedding
from graphrag.query.llm.oai.typing import OpenaiApiType
from graphrag.query.llm.ollama.chat_ollama import ChatOllama
from graphrag.query.llm.ollama.embeding import OllamaEmbedding
from graphrag.query.structured_search.global_search.community_context import (
GlobalCommunityContext,
)
@@ -32,8 +36,10 @@
from graphrag.vector_stores import BaseVectorStore


def get_llm(config: GraphRagConfig) -> ChatOpenAI:
def get_llm(config: GraphRagConfig) -> BaseLLM:
"""Get the LLM client."""
if config.llm.type in (LLMType.Ollama, LLMType.OllamaChat):
return ChatOllama(OllamaConfiguration(dict(config.llm)))
is_azure_client = (
config.llm.type == LLMType.AzureOpenAIChat
or config.llm.type == LLMType.AzureOpenAI
@@ -67,8 +73,10 @@ def get_llm(config: GraphRagConfig) -> ChatOpenAI:
)


def get_text_embedder(config: GraphRagConfig) -> OpenAIEmbedding:
def get_text_embedder(config: GraphRagConfig) -> BaseTextEmbedding:
"""Get the LLM client for embeddings."""
if config.embeddings.llm.type == LLMType.OllamaEmbedding:
return OllamaEmbedding(OllamaConfiguration(dict(config.embeddings.llm)))
is_azure_client = config.embeddings.llm.type == LLMType.AzureOpenAIEmbedding
debug_embedding_api_key = config.embeddings.llm.api_key or ""
llm_debug_info = {
10 changes: 10 additions & 0 deletions graphrag/query/llm/ollama/__init__.py
@@ -0,0 +1,10 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

from .chat_ollama import ChatOllama
from .embeding import OllamaEmbedding

__all__ = [
"ChatOllama",
"OllamaEmbedding",
]
88 changes: 88 additions & 0 deletions graphrag/query/llm/ollama/chat_ollama.py
@@ -0,0 +1,88 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Chat-based Ollama LLM implementation."""
from typing import Any, AsyncGenerator, Generator

from tenacity import (
AsyncRetrying,
RetryError,
Retrying,
retry_if_exception_type,
stop_after_attempt,
wait_exponential_jitter,
)

from graphrag.callbacks.llm_callbacks import BaseLLMCallback
from graphrag.llm import OllamaConfiguration, create_ollama_client
from graphrag.query.llm.base import BaseLLM


class ChatOllama(BaseLLM):
"""Wrapper for Ollama ChatCompletion models."""

def __init__(self, configuration: OllamaConfiguration):
self.configuration = configuration
self.sync_client = create_ollama_client(configuration, sync=True)
self.async_client = create_ollama_client(configuration)

def generate(
self,
messages: str | list[Any],
streaming: bool = True,
callbacks: list[BaseLLMCallback] | None = None,
**kwargs: Any,
) -> str:
"""Generate a response."""
response = self.sync_client.chat(
            messages=messages,
**self.configuration.get_chat_cache_args(),
)
return response["message"]["content"]

def stream_generate(
self,
messages: str | list[Any],
callbacks: list[BaseLLMCallback] | None = None,
**kwargs: Any,
) -> Generator[str, None, None]:
"""Generate a response with streaming."""

async def agenerate(
self,
messages: str | list[Any],
streaming: bool = True,
callbacks: list[BaseLLMCallback] | None = None,
**kwargs: Any,
) -> str:
"""Generate a response asynchronously."""
"""Generate text asynchronously."""
try:
retryer = AsyncRetrying(
stop=stop_after_attempt(self.configuration.max_retries),
wait=wait_exponential_jitter(max=10),
reraise=True,
retry=retry_if_exception_type(Exception), # type: ignore
)
async for attempt in retryer:
with attempt:
response = await self.async_client.chat(
messages=messages,
**{
**self.configuration.get_chat_cache_args(),
"stream": False,
}
)
return response["message"]["content"]
except Exception as e:
raise e

async def astream_generate(
self,
messages: str | list[Any],
callbacks: list[BaseLLMCallback] | None = None,
**kwargs: Any,
) -> AsyncGenerator[str, None]:
"""Generate a response asynchronously with streaming."""
...

34 changes: 34 additions & 0 deletions graphrag/query/llm/ollama/embeding.py
@@ -0,0 +1,34 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Ollama Embedding model implementation."""

from typing import Any

from graphrag.llm import OllamaConfiguration, create_ollama_client
from graphrag.query.llm.base import BaseTextEmbedding


class OllamaEmbedding(BaseTextEmbedding):
"""Wrapper for Ollama Embedding models."""

def __init__(self, configuration: OllamaConfiguration):
self.configuration = configuration
self.sync_client = create_ollama_client(configuration, sync=True)
self.async_client = create_ollama_client(configuration)

def embed(self, text: str, **kwargs: Any) -> list[float]:
"""Embed a text string."""
response = self.sync_client.embed(
input=text,
**self.configuration.get_embed_cache_args(),
)
return response["embeddings"][0]

async def aembed(self, text: str, **kwargs: Any) -> list[float]:
"""Embed a text string asynchronously."""
response = await self.async_client.embed(
input=text,
**self.configuration.get_embed_cache_args(),
)
return response["embeddings"][0]
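
A hedged usage sketch of the query-side embedding wrapper; it assumes a reachable local Ollama server, the host and model names are placeholders, and get_embed_cache_args is assumed to carry at least the configured model:

```python
from graphrag.llm import OllamaConfiguration
from graphrag.query.llm.ollama import OllamaEmbedding

embedder = OllamaEmbedding(
    OllamaConfiguration({
        "api_base": "http://localhost:11434",
        "model": "nomic-embed-text",  # placeholder embedding model
    })
)

vector = embedder.embed("graph communities and their summaries")
print(len(vector))  # dimensionality depends on the chosen model
```
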
2 changes: 1 addition & 1 deletion graphrag/query/structured_search/global_search/search.py
@@ -15,7 +15,7 @@
import tiktoken

from graphrag.callbacks.global_search_callbacks import GlobalSearchLLMCallback
from graphrag.llm.openai.utils import try_parse_json_object
from graphrag.llm.utils import try_parse_json_object
from graphrag.query.context_builder.builders import GlobalContextBuilder
from graphrag.query.context_builder.conversation_history import (
ConversationHistory,
1 change: 1 addition & 0 deletions pyproject.toml
@@ -91,6 +91,7 @@ json-repair = "^0.30.0"

future = "^1.0.0" # Needed until graspologic fixes their dependency
typer = "^0.12.5"
ollama = "^0.3.3"

mkdocs-typer = "^0.0.3"
[tool.poetry.group.dev.dependencies]