Support Fireworks batching (#6)
* Support batch
ZixinYang authored Oct 6, 2023
1 parent 5e2d504 commit f6517d6
Showing 2 changed files with 203 additions and 3 deletions.
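This commit adds synchronous and asynchronous prompt batching to the Fireworks LLM wrapper. A minimal usage sketch of the new feature (assuming FIREWORKS_API_KEY is set in the environment; the prompts are illustrative):

from langchain.llms.fireworks import Fireworks

# batch_size (added in this commit) caps how many prompts form one sub-batch.
llm = Fireworks(batch_size=20)

# batch()/abatch() route through the new _generate()/_agenerate() below.
outputs = llm.batch(["How is the weather in New York today?", "I'm pickle rick"])
for text in outputs:
    print(text)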
185 changes: 182 additions & 3 deletions libs/langchain/langchain/llms/fireworks.py
@@ -1,4 +1,16 @@
from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, Union
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import (
    Any,
    AsyncIterator,
    Callable,
    Dict,
    Iterator,
    List,
    Optional,
    Set,
    Union,
)

from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
@@ -7,7 +19,7 @@
from langchain.llms.base import LLM, create_base_retry_decorator
from langchain.pydantic_v1 import Field, root_validator
from langchain.schema.language_model import LanguageModelInput
from langchain.schema.output import GenerationChunk
from langchain.schema.output import Generation, GenerationChunk, LLMResult
from langchain.schema.runnable.config import RunnableConfig
from langchain.utils.env import get_from_dict_or_env

@@ -38,6 +50,7 @@ class Fireworks(LLM):
    )
    fireworks_api_key: Optional[str] = None
    max_retries: int = 20
    batch_size: int = 20

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
@@ -95,6 +108,87 @@ async def _acall(

        return response.choices[0].text

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Call out to Fireworks endpoint with k unique prompts.

        Args:
            prompts: The prompts to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The full LLM output.
        """
        params = {
            "model": self.model,
            **self.model_kwargs,
        }
        sub_prompts = self.get_batch_prompts(params, prompts, stop)
        choices = []
        for _prompts in sub_prompts:
            response = completion_with_retry_batching(self, prompt=_prompts, **params)
            choices.extend(response)

        return self.create_llm_result(choices, prompts)

    async def _agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Call out to Fireworks endpoint async with k unique prompts."""
        params = {
            "model": self.model,
            **self.model_kwargs,
        }
        sub_prompts = self.get_batch_prompts(params, prompts, stop)
        choices = []
        for _prompts in sub_prompts:
            response = await acompletion_with_retry_batching(
                self, prompt=_prompts, **params
            )
            choices.extend(response)

        return self.create_llm_result(choices, prompts)

    def get_batch_prompts(
        self,
        params: Dict[str, Any],
        prompts: List[str],
        stop: Optional[List[str]] = None,
    ) -> List[List[str]]:
        """Get the sub prompts for llm call."""
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")
            # NOTE: `stop` is only validated here; it is not merged into `params`.

        sub_prompts = [
            prompts[i : i + self.batch_size]
            for i in range(0, len(prompts), self.batch_size)
        ]
        return sub_prompts
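The slicing above chunks the prompt list into batch_size-sized groups. A standalone sketch of the same slicing, with hypothetical values:

prompts = [f"q{i}" for i in range(5)]
batch_size = 2
chunks = [prompts[i : i + batch_size] for i in range(0, len(prompts), batch_size)]
assert chunks == [["q0", "q1"], ["q2", "q3"], ["q4"]]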

    def create_llm_result(self, choices: Any, prompts: List[str]) -> LLMResult:
        """Create the LLMResult from the choices and prompts."""
        generations = []
        for i, _ in enumerate(prompts):
            # Each entry in `choices` is one full API response; pick the one
            # that corresponds to prompt i.
            sub_choices = choices[i : (i + 1)]
            generations.append(
                [
                    Generation(
                        text=choice.__dict__["choices"][0].text,
                    )
                    for choice in sub_choices
                ]
            )
        llm_output = {"model": self.model}
        return LLMResult(generations=generations, llm_output=llm_output)
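Each prompt maps to exactly one response, so generations becomes a list of single-element lists, which is the shape LLMResult expects. A shape-only sketch with plain strings standing in for Fireworks response objects:

prompts = ["p0", "p1"]
choices = ["r0", "r1"]  # hypothetical: one response object per prompt
generations = [choices[i : (i + 1)] for i in range(len(prompts))]
assert generations == [["r0"], ["r1"]]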

    def _stream(
        self,
        prompt: str,
@@ -108,7 +202,7 @@ def _stream(
"stream": True,
**self.model_kwargs,
}
for stream_resp in completion_with_retry(
for stream_resp in completion_with_retry_streaming(
self, run_manager=run_manager, stop=stop, **params
):
chunk = _stream_response_to_generation_chunk(stream_resp)
@@ -210,6 +304,91 @@ async def _completion_with_retry(**kwargs: Any) -> Any:
    return await _completion_with_retry(**kwargs)


def completion_with_retry_batching(
    llm: Fireworks,
    *,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the completion call."""
    import fireworks.client

    prompt = kwargs["prompt"]
    del kwargs["prompt"]

    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

    @retry_decorator
    def _completion_with_retry(prompt) -> Any:
        return fireworks.client.Completion.create(**kwargs, prompt=prompt)

    def batch_sync_run():
        # One request per prompt, fanned out across worker threads;
        # executor.map preserves the order of the input prompts.
        with ThreadPoolExecutor() as executor:
            results = list(executor.map(_completion_with_retry, prompt))
        return results

    return batch_sync_run()
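An isolated sketch of the thread fan-out used above, with a stand-in for the API call (fake_completion is hypothetical; executor.map returns results in input order):

from concurrent.futures import ThreadPoolExecutor

def fake_completion(prompt: str) -> str:
    # Stand-in for fireworks.client.Completion.create
    return prompt.upper()

with ThreadPoolExecutor() as executor:
    results = list(executor.map(fake_completion, ["a", "b", "c"]))
assert results == ["A", "B", "C"]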


async def acompletion_with_retry_batching(
    llm: Fireworks,
    *,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the completion call."""
    import fireworks.client

    prompt = kwargs["prompt"]
    del kwargs["prompt"]

    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

    @retry_decorator
    async def _completion_with_retry(prompt) -> Any:
        return await fireworks.client.Completion.acreate(**kwargs, prompt=prompt)

    def run_coroutine_in_new_loop(coroutine_func, *args, **kwargs):
        # An event loop is bound to a single thread, so each worker thread
        # creates its own loop to drive one coroutine to completion.
        new_loop = asyncio.new_event_loop()
        try:
            asyncio.set_event_loop(new_loop)
            return new_loop.run_until_complete(coroutine_func(*args, **kwargs))
        finally:
            new_loop.close()

    async def batch_sync_run():
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(
                    run_coroutine_in_new_loop,
                    [_completion_with_retry] * len(prompt),
                    prompt,
                )
            )
        return results

    return await batch_sync_run()
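An isolated sketch of the loop-per-thread pattern above (fake_acompletion is a hypothetical stand-in; gathering the coroutines with asyncio.gather on the current loop would be a leaner alternative, but this mirrors the committed code):

import asyncio
from concurrent.futures import ThreadPoolExecutor

async def fake_acompletion(prompt: str) -> str:
    await asyncio.sleep(0)  # stand-in for the awaited API call
    return prompt[::-1]

def run_in_new_loop(coro_func, *args):
    # Each worker thread owns a private event loop.
    loop = asyncio.new_event_loop()
    try:
        asyncio.set_event_loop(loop)
        return loop.run_until_complete(coro_func(*args))
    finally:
        loop.close()

with ThreadPoolExecutor() as executor:
    out = list(executor.map(run_in_new_loop, [fake_acompletion] * 2, ["ab", "cd"]))
assert out == ["ba", "dc"]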


def completion_with_retry_streaming(
    llm: Fireworks,
    *,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the completion call for streaming."""
    import fireworks.client

    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

    @retry_decorator
    def _completion_with_retry(**kwargs: Any) -> Any:
        return fireworks.client.Completion.create(
            **kwargs,
        )

    return _completion_with_retry(**kwargs)


async def acompletion_with_retry_streaming(
    llm: Fireworks,
    *,
21 changes: 21 additions & 0 deletions libs/langchain/tests/integration_tests/llms/test_fireworks.py
@@ -86,3 +86,24 @@ async def test_fireworks_multiple_prompts_async_agenerate() -> None:
    assert isinstance(output, LLMResult)
    assert isinstance(output.generations, list)
    assert len(output.generations) == 2


def test_fireworks_batch() -> None:
    """Test batching multiple prompts with Fireworks."""
    llm = Fireworks()

    result = llm.batch(["How is the weather in New York today?", "I'm pickle rick"])
    for token in result:
        assert isinstance(token, str)


@pytest.mark.asyncio
async def test_fireworks_abatch() -> None:
    """Test async batching of multiple prompts with Fireworks."""
    llm = Fireworks()

    result = await llm.abatch(
        ["How is the weather in New York today?", "I'm not Pickle Rick"]
    )
    for token in result:
        assert isinstance(token, str)
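Note: batch() and abatch() are part of the base Runnable interface; for string-in/string-out LLMs each input maps to one output string, which is why both tests assert isinstance(token, str) for every result.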
