Support ollama #1326
Open
Smpests wants to merge 8 commits into microsoft:main from Smpests:feature/ollama-support
+1,271 −53
Changes from all commits (8 commits):

b2736a9  Smpests  ollama support.
5c9dcdb  Smpests  remove useless print code
a98ae6a  Smpests  resolve generate community report error
82d3a07  Smpests  ollama support.
f57f4a3  Smpests  remove useless print code
7452efc  Smpests  resolve generate community report error
970d324  Smpests  resolve coflict
84be1e5  Smpests  support search by ollama
New file (path inferred from its contents): graphrag/llm/ollama/__init__.py

@@ -0,0 +1,29 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Ollama LLM implementations."""

from .create_ollama_client import create_ollama_client
from .factories import (
    create_ollama_chat_llm,
    create_ollama_completion_llm,
    create_ollama_embedding_llm,
)
from .ollama_chat_llm import OllamaChatLLM
from .ollama_completion_llm import OllamaCompletionLLM
from .ollama_configuration import OllamaConfiguration
from .ollama_embeddings_llm import OllamaEmbeddingsLLM
from .types import OllamaClientType


__all__ = [
    "OllamaChatLLM",
    "OllamaClientType",
    "OllamaCompletionLLM",
    "OllamaConfiguration",
    "OllamaEmbeddingsLLM",
    "create_ollama_chat_llm",
    "create_ollama_client",
    "create_ollama_completion_llm",
    "create_ollama_embedding_llm",
]
New file (path inferred from the package __init__ above): graphrag/llm/ollama/create_ollama_client.py

@@ -0,0 +1,36 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Create OpenAI client instance."""

import logging
from functools import cache

from ollama import AsyncClient, Client

from .ollama_configuration import OllamaConfiguration
from .types import OllamaClientType

log = logging.getLogger(__name__)

API_BASE_REQUIRED_FOR_AZURE = "api_base is required for Azure OpenAI client"


@cache
def create_ollama_client(
    configuration: OllamaConfiguration,
    sync: bool = False,
) -> OllamaClientType:
    """Create a new Ollama client instance."""

    log.info("Creating OpenAI client base_url=%s", configuration.api_base)
    if sync:
        return Client(
            host=configuration.api_base,
            timeout=configuration.request_timeout or 180.0,
        )
    return AsyncClient(
        host=configuration.api_base,
        # Timeout/Retry Configuration - Use Tenacity for Retries, so disable them here
        timeout=configuration.request_timeout or 180.0,
    )
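For orientation, here is a minimal usage sketch (not part of the diff). It assumes only what the PR itself shows: create_ollama_client is re-exported from graphrag.llm (the query wrappers later in this diff import it from there), and the configuration object exposes api_base and request_timeout. Construction of OllamaConfiguration is omitted because ollama_configuration.py is not rendered in this capture.

# Usage sketch (illustrative). `config` is an OllamaConfiguration instance;
# its constructor lives in ollama_configuration.py, which is collapsed below.
from graphrag.llm import create_ollama_client

def make_clients(config):
    sync_client = create_ollama_client(config, sync=True)   # ollama.Client
    async_client = create_ollama_client(config)             # ollama.AsyncClient
    return sync_client, async_client

Because of the functools.cache decorator, repeated calls with the same configuration return the same client instance, which also implies OllamaConfiguration must be hashable.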
New file (path inferred from the package __init__ above): graphrag/llm/ollama/factories.py

@@ -0,0 +1,139 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Factory functions for creating OpenAI LLMs."""

import asyncio

from graphrag.llm.base import CachingLLM, RateLimitingLLM
from graphrag.llm.limiting import LLMLimiter
from graphrag.llm.types import (
    LLM,
    CompletionLLM,
    EmbeddingLLM,
    ErrorHandlerFn,
    LLMCache,
    LLMInvocationFn,
    OnCacheActionFn,
)
from graphrag.llm.utils import (
    RATE_LIMIT_ERRORS,
    RETRYABLE_ERRORS,
    get_sleep_time_from_error,
    get_token_counter,
)
from graphrag.llm.openai.openai_history_tracking_llm import OpenAIHistoryTrackingLLM
from graphrag.llm.openai.openai_token_replacing_llm import OpenAITokenReplacingLLM

from .json_parsing_llm import JsonParsingLLM
from .ollama_chat_llm import OllamaChatLLM
from .ollama_completion_llm import OllamaCompletionLLM
from .ollama_configuration import OllamaConfiguration
from .ollama_embeddings_llm import OllamaEmbeddingsLLM
from .types import OllamaClientType


def create_ollama_chat_llm(
    client: OllamaClientType,
    config: OllamaConfiguration,
    cache: LLMCache | None = None,
    limiter: LLMLimiter | None = None,
    semaphore: asyncio.Semaphore | None = None,
    on_invoke: LLMInvocationFn | None = None,
    on_error: ErrorHandlerFn | None = None,
    on_cache_hit: OnCacheActionFn | None = None,
    on_cache_miss: OnCacheActionFn | None = None,
) -> CompletionLLM:
    """Create an OpenAI chat LLM."""
    operation = "chat"
    result = OllamaChatLLM(client, config)
    result.on_error(on_error)
    if limiter is not None or semaphore is not None:
        result = _rate_limited(result, config, operation, limiter, semaphore, on_invoke)
    if cache is not None:
        result = _cached(result, config, operation, cache, on_cache_hit, on_cache_miss)
    result = OpenAIHistoryTrackingLLM(result)
    result = OpenAITokenReplacingLLM(result)
    return JsonParsingLLM(result)


def create_ollama_completion_llm(
    client: OllamaClientType,
    config: OllamaConfiguration,
    cache: LLMCache | None = None,
    limiter: LLMLimiter | None = None,
    semaphore: asyncio.Semaphore | None = None,
    on_invoke: LLMInvocationFn | None = None,
    on_error: ErrorHandlerFn | None = None,
    on_cache_hit: OnCacheActionFn | None = None,
    on_cache_miss: OnCacheActionFn | None = None,
) -> CompletionLLM:
    """Create an OpenAI completion LLM."""
    operation = "completion"
    result = OllamaCompletionLLM(client, config)
    result.on_error(on_error)
    if limiter is not None or semaphore is not None:
        result = _rate_limited(result, config, operation, limiter, semaphore, on_invoke)
    if cache is not None:
        result = _cached(result, config, operation, cache, on_cache_hit, on_cache_miss)
    return OpenAITokenReplacingLLM(result)


def create_ollama_embedding_llm(
    client: OllamaClientType,
    config: OllamaConfiguration,
    cache: LLMCache | None = None,
    limiter: LLMLimiter | None = None,
    semaphore: asyncio.Semaphore | None = None,
    on_invoke: LLMInvocationFn | None = None,
    on_error: ErrorHandlerFn | None = None,
    on_cache_hit: OnCacheActionFn | None = None,
    on_cache_miss: OnCacheActionFn | None = None,
) -> EmbeddingLLM:
    """Create an OpenAI embeddings LLM."""
    operation = "embedding"
    result = OllamaEmbeddingsLLM(client, config)
    result.on_error(on_error)
    if limiter is not None or semaphore is not None:
        result = _rate_limited(result, config, operation, limiter, semaphore, on_invoke)
    if cache is not None:
        result = _cached(result, config, operation, cache, on_cache_hit, on_cache_miss)
    return result


def _rate_limited(
    delegate: LLM,
    config: OllamaConfiguration,
    operation: str,
    limiter: LLMLimiter | None,
    semaphore: asyncio.Semaphore | None,
    on_invoke: LLMInvocationFn | None,
):
    result = RateLimitingLLM(
        delegate,
        config,
        operation,
        RETRYABLE_ERRORS,
        RATE_LIMIT_ERRORS,
        limiter,
        semaphore,
        get_token_counter(config),
        get_sleep_time_from_error,
    )
    result.on_invoke(on_invoke)
    return result


def _cached(
    delegate: LLM,
    config: OllamaConfiguration,
    operation: str,
    cache: LLMCache,
    on_cache_hit: OnCacheActionFn | None,
    on_cache_miss: OnCacheActionFn | None,
):
    cache_args = config.get_completion_cache_args()
    result = CachingLLM(delegate, cache_args, operation, cache)
    result.on_cache_hit(on_cache_hit)
    result.on_cache_miss(on_cache_miss)
    return result
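A sketch of how these factories compose (illustrative, not part of the diff; the cache argument is a placeholder for whichever LLMCache implementation graphrag normally supplies, and the import path assumes the new package is graphrag/llm/ollama as inferred above):

# Wiring sketch: an Ollama chat LLM with caching but no rate limiter.
from graphrag.llm.ollama import (
    create_ollama_chat_llm,
    create_ollama_client,
)

def build_chat_llm(config, llm_cache):
    client = create_ollama_client(config)
    # Resulting wrapper chain, outermost first:
    # JsonParsingLLM -> OpenAITokenReplacingLLM -> OpenAIHistoryTrackingLLM
    #   -> CachingLLM -> OllamaChatLLM
    return create_ollama_chat_llm(client, config, cache=llm_cache)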
New file (path inferred from the package __init__ above): graphrag/llm/ollama/json_parsing_llm.py

@@ -0,0 +1,38 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""An LLM that unpacks cached JSON responses."""

from typing_extensions import Unpack

from graphrag.llm.types import (
    LLM,
    CompletionInput,
    CompletionLLM,
    CompletionOutput,
    LLMInput,
    LLMOutput,
)

from graphrag.llm.utils import try_parse_json_object


class JsonParsingLLM(LLM[CompletionInput, CompletionOutput]):
    """An OpenAI History-Tracking LLM."""

    _delegate: CompletionLLM

    def __init__(self, delegate: CompletionLLM):
        self._delegate = delegate

    async def __call__(
        self,
        input: CompletionInput,
        **kwargs: Unpack[LLMInput],
    ) -> LLMOutput[CompletionOutput]:
        """Call the LLM with the input and kwargs."""
        result = await self._delegate(input, **kwargs)
        if kwargs.get("json") and result.json is None and result.output is not None:
            _, parsed_json = try_parse_json_object(result.output)
            result.json = parsed_json
        return result
New file (path inferred from the package __init__ above): graphrag/llm/ollama/ollama_chat_llm.py

@@ -0,0 +1,90 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""The Chat-based language model."""

import logging

from typing_extensions import Unpack

from graphrag.llm.base import BaseLLM
from graphrag.llm.types import (
    CompletionInput,
    CompletionOutput,
    LLMInput,
    LLMOutput,
)
from graphrag.llm.utils import try_parse_json_object

from .ollama_configuration import OllamaConfiguration
from .types import OllamaClientType

log = logging.getLogger(__name__)

_MAX_GENERATION_RETRIES = 3
FAILED_TO_CREATE_JSON_ERROR = "Failed to generate valid JSON output"


class OllamaChatLLM(BaseLLM[CompletionInput, CompletionOutput]):
    """A Chat-based LLM."""

    _client: OllamaClientType
    _configuration: OllamaConfiguration

    def __init__(self, client: OllamaClientType, configuration: OllamaConfiguration):
        self.client = client
        self.configuration = configuration

    async def _execute_llm(
        self, input: CompletionInput, **kwargs: Unpack[LLMInput]
    ) -> CompletionOutput | None:
        args = {
            **self.configuration.get_chat_cache_args(),
        }
        history = kwargs.get("history") or []
        messages = [
            *history,
            {"role": "user", "content": input},
        ]
        completion = await self.client.chat(
            messages=messages, **args
        )
        return completion["message"]["content"]

    async def _invoke_json(
        self,
        input: CompletionInput,
        **kwargs: Unpack[LLMInput],
    ) -> LLMOutput[CompletionOutput]:
        """Generate JSON output."""
        name = kwargs.get("name") or "unknown"
        is_response_valid = kwargs.get("is_response_valid") or (lambda _x: True)

        async def generate(
            attempt: int | None = None,
        ) -> LLMOutput[CompletionOutput]:
            call_name = name if attempt is None else f"{name}@{attempt}"
            result = await self._invoke(input, **{**kwargs, "name": call_name})
            print("output:\n", result)
            output, json_output = try_parse_json_object(result.output or "")

            return LLMOutput[CompletionOutput](
                output=output,
                json=json_output,
                history=result.history,
            )

        def is_valid(x: dict | None) -> bool:
            return x is not None and is_response_valid(x)

        result = await generate()
        retry = 0
        while not is_valid(result.json) and retry < _MAX_GENERATION_RETRIES:
            result = await generate(retry)
            retry += 1

        if is_valid(result.json):
            return result

        error_msg = f"{FAILED_TO_CREATE_JSON_ERROR} - Faulty JSON: {result.json!s}"
        raise RuntimeError(error_msg)
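To make the retry path above concrete, a hypothetical caller requesting JSON output might look like the sketch below. It assumes graphrag's BaseLLM dispatches to _invoke_json when json=True is passed, as the OpenAI implementation does; the prompt and validator are invented for illustration.

# Hypothetical caller of OllamaChatLLM with JSON output and validation.
async def extract_entities(chat_llm, text: str):
    result = await chat_llm(
        f"Return the entities in this text as JSON: {text}",
        json=True,
        name="entity_extraction",
        is_response_valid=lambda parsed: "entities" in parsed,
    )
    # result.json is the parsed dict; after the initial call plus
    # _MAX_GENERATION_RETRIES failed attempts, a RuntimeError is raised.
    return result.json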
New file (path inferred from the package __init__ above): graphrag/llm/ollama/ollama_completion_llm.py

@@ -0,0 +1,44 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A text-completion based LLM."""

import logging

from typing_extensions import Unpack

from graphrag.llm.base import BaseLLM
from graphrag.llm.types import (
    CompletionInput,
    CompletionOutput,
    LLMInput,
)
from graphrag.llm.utils import get_completion_llm_args

from .ollama_configuration import OllamaConfiguration
from .types import OllamaClientType


log = logging.getLogger(__name__)


class OllamaCompletionLLM(BaseLLM[CompletionInput, CompletionOutput]):
    """A text-completion based LLM."""

    _client: OllamaClientType
    _configuration: OllamaConfiguration

    def __init__(self, client: OllamaClientType, configuration: OllamaConfiguration):
        self.client = client
        self.configuration = configuration

    async def _execute_llm(
        self,
        input: CompletionInput,
        **kwargs: Unpack[LLMInput],
    ) -> CompletionOutput | None:
        args = get_completion_llm_args(
            kwargs.get("model_parameters"), self.configuration
        )
        completion = await self.client.generate(prompt=input, **args)
        return completion["response"]
[One large new file in this diff is collapsed and not rendered in this capture; given the alphabetical file order and the imports above, it is likely graphrag/llm/ollama/ollama_configuration.py.]
New file (path inferred from the package __init__ above): graphrag/llm/ollama/ollama_embeddings_llm.py

@@ -0,0 +1,40 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""The EmbeddingsLLM class."""

from typing_extensions import Unpack

from graphrag.llm.base import BaseLLM
from graphrag.llm.types import (
    EmbeddingInput,
    EmbeddingOutput,
    LLMInput,
)

from .ollama_configuration import OllamaConfiguration
from .types import OllamaClientType


class OllamaEmbeddingsLLM(BaseLLM[EmbeddingInput, EmbeddingOutput]):
    """A text-embedding generator LLM."""

    _client: OllamaClientType
    _configuration: OllamaConfiguration

    def __init__(self, client: OllamaClientType, configuration: OllamaConfiguration):
        self.client = client
        self.configuration = configuration

    async def _execute_llm(
        self, input: EmbeddingInput, **kwargs: Unpack[LLMInput]
    ) -> EmbeddingOutput | None:
        args = {
            "model": self.configuration.model,
            **(kwargs.get("model_parameters") or {}),
        }
        embedding = await self.client.embed(
            input=input,
            **args,
        )
        return embedding["embeddings"]
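A short sketch of the embeddings path (illustrative; client and config are built as in the factories above, and the model name is taken from configuration.model):

# Embedding sketch: ollama's embed() accepts a single string or a list of
# strings and returns {"embeddings": [...]}, one vector per input.
async def embed_chunks(client, config, chunks: list[str]):
    llm = OllamaEmbeddingsLLM(client, config)
    result = await llm(chunks)
    return result.output  # list of embedding vectors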
New file (path inferred from the package __init__ above): graphrag/llm/ollama/types.py

@@ -0,0 +1,8 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A base class for OpenAI-based LLMs."""

from ollama import AsyncClient, Client

OllamaClientType = AsyncClient | Client
[Several additional changed files are collapsed here and not rendered in this capture.]
New file (path inferred; the query-side Ollama package __init__, likely graphrag/query/llm/ollama/__init__.py):

@@ -0,0 +1,10 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

from .chat_ollama import ChatOllama
from .embeding import OllamaEmbedding

__all__ = [
    "ChatOllama",
    "OllamaEmbedding",
]
New file (path inferred from the query package __init__ above; likely graphrag/query/llm/ollama/chat_ollama.py):

@@ -0,0 +1,88 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Chat-based Ollama LLM implementation."""
from typing import Any, AsyncGenerator, Generator

from tenacity import (
    AsyncRetrying,
    RetryError,
    Retrying,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential_jitter,
)

from graphrag.callbacks.llm_callbacks import BaseLLMCallback
from graphrag.llm import OllamaConfiguration, create_ollama_client
from graphrag.query.llm.base import BaseLLM


class ChatOllama(BaseLLM):
    """Wrapper for Ollama ChatCompletion models."""

    def __init__(self, configuration: OllamaConfiguration):
        self.configuration = configuration
        self.sync_client = create_ollama_client(configuration, sync=True)
        self.async_client = create_ollama_client(configuration)

    def generate(
        self,
        messages: str | list[Any],
        streaming: bool = True,
        callbacks: list[BaseLLMCallback] | None = None,
        **kwargs: Any,
    ) -> str:
        """Generate a response."""
        response = self.sync_client.chat(
            messages,
            **self.configuration.get_chat_cache_args(),
        )
        return response["message"]["content"]

    def stream_generate(
        self,
        messages: str | list[Any],
        callbacks: list[BaseLLMCallback] | None = None,
        **kwargs: Any,
    ) -> Generator[str, None, None]:
        """Generate a response with streaming."""

    async def agenerate(
        self,
        messages: str | list[Any],
        streaming: bool = True,
        callbacks: list[BaseLLMCallback] | None = None,
        **kwargs: Any,
    ) -> str:
        """Generate a response asynchronously."""
        """Generate text asynchronously."""
        try:
            retryer = AsyncRetrying(
                stop=stop_after_attempt(self.configuration.max_retries),
                wait=wait_exponential_jitter(max=10),
                reraise=True,
                retry=retry_if_exception_type(Exception),  # type: ignore
            )
            async for attempt in retryer:
                with attempt:
                    response = await self.async_client.chat(
                        messages=messages,
                        **{
                            **self.configuration.get_chat_cache_args(),
                            "stream": False,
                        }
                    )
                    return response["message"]["content"]
        except Exception as e:
            raise e

    async def astream_generate(
        self,
        messages: str | list[Any],
        callbacks: list[BaseLLMCallback] | None = None,
        **kwargs: Any,
    ) -> AsyncGenerator[str, None]:
        """Generate a response asynchronously with streaming."""
        ...
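A usage sketch for the query-side chat wrapper (illustrative; the message format follows the Ollama chat API, and configuration construction is omitted because ollama_configuration.py is not rendered here):

# Sketch: search code would call the wrapper roughly like this.
async def ask(config, question: str) -> str:
    llm = ChatOllama(config)
    return await llm.agenerate(
        messages=[{"role": "user", "content": question}],
    )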
New file (path inferred from the query package __init__ above; likely graphrag/query/llm/ollama/embeding.py, spelled as in the import):

@@ -0,0 +1,34 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Ollama Embedding model implementation."""

from typing import Any

from graphrag.llm import OllamaConfiguration, create_ollama_client
from graphrag.query.llm.base import BaseTextEmbedding


class OllamaEmbedding(BaseTextEmbedding):
    """Wrapper for Ollama Embedding models."""

    def __init__(self, configuration: OllamaConfiguration):
        self.configuration = configuration
        self.sync_client = create_ollama_client(configuration, sync=True)
        self.async_client = create_ollama_client(configuration)

    def embed(self, text: str, **kwargs: Any) -> list[float]:
        """Embed a text string."""
        response = self.sync_client.embed(
            input=text,
            **self.configuration.get_embed_cache_args(),
        )
        return response["embeddings"][0]

    async def aembed(self, text: str, **kwargs: Any) -> list[float]:
        """Embed a text string asynchronously."""
        response = await self.async_client.embed(
            input=text,
            **self.configuration.get_embed_cache_args(),
        )
        return response["embeddings"][0]
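And the corresponding embedding wrapper, used at query time for vector search (same caveat: configuration construction is not shown in this capture):

# Sketch: embedding a query string with the wrapper defined above.
def embed_query(config, query: str) -> list[float]:
    return OllamaEmbedding(config).embed(query)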
Review comment (on create_ollama_client.py): Fyi, "Creating OpenAI client base_url=%s" should be "Creating Ollama client base_url=%s".
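The suggested fix is a one-word change to the log call in create_ollama_client.py:

log.info("Creating Ollama client base_url=%s", configuration.api_base)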