From f6517d627ec7e4d371963ecc6a38b67d36195e87 Mon Sep 17 00:00:00 2001
From: Cynthia Yang
Date: Fri, 6 Oct 2023 10:11:59 -0700
Subject: [PATCH] Support Fireworks batching (#6)

* Support batch
---
 libs/langchain/langchain/llms/fireworks.py    | 185 +++++++++++++++++-
 .../integration_tests/llms/test_fireworks.py  |  21 ++
 2 files changed, 203 insertions(+), 3 deletions(-)

diff --git a/libs/langchain/langchain/llms/fireworks.py b/libs/langchain/langchain/llms/fireworks.py
index 6922b2a6e92b7..a2913310e5a98 100644
--- a/libs/langchain/langchain/llms/fireworks.py
+++ b/libs/langchain/langchain/llms/fireworks.py
@@ -1,4 +1,16 @@
-from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, Union
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
+from typing import (
+    Any,
+    AsyncIterator,
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Set,
+    Union,
+)
 
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
@@ -7,7 +19,7 @@
 from langchain.llms.base import LLM, create_base_retry_decorator
 from langchain.pydantic_v1 import Field, root_validator
 from langchain.schema.language_model import LanguageModelInput
-from langchain.schema.output import GenerationChunk
+from langchain.schema.output import Generation, GenerationChunk, LLMResult
 from langchain.schema.runnable.config import RunnableConfig
 from langchain.utils.env import get_from_dict_or_env
 
@@ -38,6 +50,7 @@ class Fireworks(LLM):
     )
     fireworks_api_key: Optional[str] = None
     max_retries: int = 20
+    batch_size: int = 20
 
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
@@ -95,6 +108,87 @@ async def _acall(
 
         return response.choices[0].text
 
+    def _generate(
+        self,
+        prompts: List[str],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> LLMResult:
+        """Call out to Fireworks endpoint with k unique prompts.
+        Args:
+            prompts: The prompts to pass into the model.
+            stop: Optional list of stop words to use when generating.
+        Returns:
+            The full LLM output.
+ """ + params = { + "model": self.model, + **self.model_kwargs, + } + sub_prompts = self.get_batch_prompts(params, prompts, stop) + choices = [] + for _prompts in sub_prompts: + response = completion_with_retry_batching(self, prompt=_prompts, **params) + choices.extend(response) + + return self.create_llm_result(choices, prompts) + + async def _agenerate( + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> LLMResult: + """Call out to Fireworks endpoint async with k unique prompts.""" + params = { + "model": self.model, + **self.model_kwargs, + } + sub_prompts = self.get_batch_prompts(params, prompts, stop) + choices = [] + for _prompts in sub_prompts: + response = await acompletion_with_retry_batching( + self, prompt=_prompts, **params + ) + choices.extend(response) + + return self.create_llm_result(choices, prompts) + + def get_batch_prompts( + self, + params: Dict[str, Any], + prompts: List[str], + stop: Optional[List[str]] = None, + ) -> List[List[str]]: + """Get the sub prompts for llm call.""" + if stop is not None: + if "stop" in params: + raise ValueError("`stop` found in both the input and default params.") + + sub_prompts = [ + prompts[i : i + self.batch_size] + for i in range(0, len(prompts), self.batch_size) + ] + return sub_prompts + + def create_llm_result(self, choices: Any, prompts: List[str]) -> LLMResult: + """Create the LLMResult from the choices and prompts.""" + generations = [] + for i, _ in enumerate(prompts): + sub_choices = choices[i : (i + 1)] + generations.append( + [ + Generation( + text=choice.__dict__["choices"][0].text, + ) + for choice in sub_choices + ] + ) + llm_output = {"model": self.model} + return LLMResult(generations=generations, llm_output=llm_output) + def _stream( self, prompt: str, @@ -108,7 +202,7 @@ def _stream( "stream": True, **self.model_kwargs, } - for stream_resp in completion_with_retry( + for stream_resp in completion_with_retry_streaming( self, run_manager=run_manager, stop=stop, **params ): chunk = _stream_response_to_generation_chunk(stream_resp) @@ -210,6 +304,91 @@ async def _completion_with_retry(**kwargs: Any) -> Any: return await _completion_with_retry(**kwargs) +def completion_with_retry_batching( + llm: Fireworks, + *, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, +) -> Any: + """Use tenacity to retry the completion call.""" + import fireworks.client + + prompt = kwargs["prompt"] + del kwargs["prompt"] + + retry_decorator = _create_retry_decorator(llm, run_manager=run_manager) + + @retry_decorator + def _completion_with_retry(prompt) -> Any: + return fireworks.client.Completion.create(**kwargs, prompt=prompt) + + def batch_sync_run(): + with ThreadPoolExecutor() as executor: + results = list(executor.map(_completion_with_retry, prompt)) + return results + + return batch_sync_run() + + +async def acompletion_with_retry_batching( + llm: Fireworks, + *, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, +) -> Any: + """Use tenacity to retry the completion call.""" + import fireworks.client + + prompt = kwargs["prompt"] + del kwargs["prompt"] + + retry_decorator = _create_retry_decorator(llm, run_manager=run_manager) + + @retry_decorator + async def _completion_with_retry(prompt) -> Any: + return await fireworks.client.Completion.acreate(**kwargs, prompt=prompt) + + def run_coroutine_in_new_loop(coroutine_func, *args, **kwargs): + new_loop = asyncio.new_event_loop() + try: + 
+            asyncio.set_event_loop(new_loop)
+            return new_loop.run_until_complete(coroutine_func(*args, **kwargs))
+        finally:
+            new_loop.close()
+
+    async def batch_sync_run():
+        with ThreadPoolExecutor() as executor:
+            results = list(
+                executor.map(
+                    run_coroutine_in_new_loop,
+                    [_completion_with_retry] * len(prompt),
+                    prompt,
+                )
+            )
+        return results
+
+    return await batch_sync_run()
+
+
+def completion_with_retry_streaming(
+    llm: Fireworks,
+    *,
+    run_manager: Optional[CallbackManagerForLLMRun] = None,
+    **kwargs: Any,
+) -> Any:
+    import fireworks.client
+
+    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+
+    @retry_decorator
+    def _completion_with_retry(**kwargs: Any) -> Any:
+        return fireworks.client.Completion.create(
+            **kwargs,
+        )
+
+    return _completion_with_retry(**kwargs)
+
+
 async def acompletion_with_retry_streaming(
     llm: Fireworks,
     *,
diff --git a/libs/langchain/tests/integration_tests/llms/test_fireworks.py b/libs/langchain/tests/integration_tests/llms/test_fireworks.py
index cbc7473665a4b..2c1585769b4ad 100644
--- a/libs/langchain/tests/integration_tests/llms/test_fireworks.py
+++ b/libs/langchain/tests/integration_tests/llms/test_fireworks.py
@@ -86,3 +86,24 @@ async def test_fireworks_multiple_prompts_async_agenerate() -> None:
     assert isinstance(output, LLMResult)
     assert isinstance(output.generations, list)
     assert len(output.generations) == 2
+
+
+def test_fireworks_batch() -> None:
+    """Test batching prompts with Fireworks."""
+    llm = Fireworks()
+
+    result = llm.batch(["How is the weather in New York today?", "I'm pickle rick"])
+    for token in result:
+        assert isinstance(token, str)
+
+
+@pytest.mark.asyncio
+async def test_fireworks_abatch() -> None:
+    """Test async batching of prompts with Fireworks."""
+    llm = Fireworks()
+
+    result = await llm.abatch(
+        ["How is the weather in New York today?", "I'm not Pickle Rick"]
+    )
+    for token in result:
+        assert isinstance(token, str)
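
A minimal usage sketch of the batching support added above (not part of the commit itself). It assumes the fireworks-ai client package is installed and that an API key is available to validate_environment (the FIREWORKS_API_KEY environment variable name is an assumption here), and it mirrors the new integration tests:

    # Sketch only: exercises the batch support introduced in this patch.
    import asyncio

    from langchain.llms.fireworks import Fireworks

    # batch_size (added above) controls how many prompts go into each chunk;
    # completion_with_retry_batching fans each chunk out over a ThreadPoolExecutor.
    llm = Fireworks(batch_size=2)

    prompts = ["How is the weather in New York today?", "I'm Pickle Rick"]

    # Sync path: batch() -> _generate() -> completion_with_retry_batching().
    for text in llm.batch(prompts):
        assert isinstance(text, str)

    # Async path: abatch() -> _agenerate() -> acompletion_with_retry_batching(),
    # which runs each retried coroutine in a fresh event loop on a worker thread.
    results = asyncio.run(llm.abatch(prompts))
    assert len(results) == len(prompts)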