From c2f6ce408f75475f6be05046b0ff94f884d08002 Mon Sep 17 00:00:00 2001
From: Cynthia Yang
Date: Thu, 26 Oct 2023 18:12:16 +0000
Subject: [PATCH] Add invoke/batch test cases and pass stop to batch retry

---
 libs/langchain/langchain/llms/fireworks.py    | 25 ++++----
 .../chat_models/test_fireworks.py             | 33 +++++++++-
 .../integration_tests/llms/test_fireworks.py  | 61 ++++++++++++++++++-
 3 files changed, 104 insertions(+), 15 deletions(-)

diff --git a/libs/langchain/langchain/llms/fireworks.py b/libs/langchain/langchain/llms/fireworks.py
index 2ef559fa13345..867c916c1a586 100644
--- a/libs/langchain/langchain/llms/fireworks.py
+++ b/libs/langchain/langchain/llms/fireworks.py
@@ -88,11 +88,16 @@ def _generate(
             "model": self.model,
             **self.model_kwargs,
         }
-        sub_prompts = self.get_batch_prompts(params, prompts, stop)
+        sub_prompts = self.get_batch_prompts(prompts)
         choices = []
         for _prompts in sub_prompts:
             response = completion_with_retry_batching(
-                self, self.use_retry, prompt=_prompts, run_manager=run_manager, **params
+                self,
+                self.use_retry,
+                prompt=_prompts,
+                run_manager=run_manager,
+                stop=stop,
+                **params,
             )
             choices.extend(response)
 
@@ -110,11 +115,16 @@ async def _agenerate(
             "model": self.model,
             **self.model_kwargs,
         }
-        sub_prompts = self.get_batch_prompts(params, prompts, stop)
+        sub_prompts = self.get_batch_prompts(prompts)
         choices = []
         for _prompts in sub_prompts:
             response = await acompletion_with_retry_batching(
-                self, self.use_retry, prompt=_prompts, run_manager=run_manager, **params
+                self,
+                self.use_retry,
+                prompt=_prompts,
+                run_manager=run_manager,
+                stop=stop,
+                **params,
             )
             choices.extend(response)
 
@@ -122,16 +132,9 @@ async def _agenerate(
 
     def get_batch_prompts(
         self,
-        params: Dict[str, Any],
         prompts: List[str],
-        stop: Optional[List[str]] = None,
     ) -> List[List[str]]:
         """Get the sub prompts for llm call."""
-        if stop is not None:
-            if "stop" in params:
-                raise ValueError("`stop` found in both the input and default params.")
-            params["stop"] = stop
-
         sub_prompts = [
             prompts[i : i + self.batch_size]
             for i in range(0, len(prompts), self.batch_size)
diff --git a/libs/langchain/tests/integration_tests/chat_models/test_fireworks.py b/libs/langchain/tests/integration_tests/chat_models/test_fireworks.py
index 782ea80c15594..9553c9f35c2da 100644
--- a/libs/langchain/tests/integration_tests/chat_models/test_fireworks.py
+++ b/libs/langchain/tests/integration_tests/chat_models/test_fireworks.py
@@ -78,6 +78,23 @@ def test_chat_fireworks_llm_output_stop_words() -> None:
     assert llm_result.generations[0][0].text[-1] == ","
 
 
+def test_fireworks_invoke() -> None:
+    """Tests chat completion with invoke."""
+    chat = ChatFireworks()
+    result = chat.invoke("How is the weather in New York today?", stop=[","])
+    assert isinstance(result.content, str)
+    assert result.content[-1] == ","
+
+
+@pytest.mark.asyncio
+async def test_fireworks_ainvoke() -> None:
+    """Tests chat completion with ainvoke."""
+    chat = ChatFireworks()
+    result = await chat.ainvoke("How is the weather in New York today?", stop=[","])
+    assert isinstance(result.content, str)
+    assert result.content[-1] == ","
+
+
 def test_fireworks_batch() -> None:
     """Test batch tokens from ChatFireworks."""
     chat = ChatFireworks()
@@ -91,15 +108,18 @@ def test_fireworks_batch() -> None:
             "What is the weather in Redwood City, CA today",
         ],
         config={"max_concurrency": 5},
+        stop=[","],
     )
     for token in result:
         assert isinstance(token.content, str)
+        assert token.content[-1] == ","
 
 
+@pytest.mark.asyncio
 async def test_fireworks_abatch() -> None:
     """Test batch tokens from ChatFireworks."""
-    llm = ChatFireworks()
-    result = await llm.abatch(
+    chat = ChatFireworks()
+    result = await chat.abatch(
         [
             "What is the weather in Redwood City, CA today",
             "What is the weather in Redwood City, CA today",
             "What is the weather in Redwood City, CA today",
             "What is the weather in Redwood City, CA today",
             "What is the weather in Redwood City, CA today",
             "What is the weather in Redwood City, CA today",
         ],
         config={"max_concurrency": 5},
+        stop=[","],
     )
     for token in result:
         assert isinstance(token.content, str)
+        assert token.content[-1] == ","
 
 
 def test_fireworks_streaming() -> None:
@@ -154,5 +176,10 @@ async def test_fireworks_astream() -> None:
     """Test streaming tokens from Fireworks."""
     llm = ChatFireworks()
 
-    async for token in llm.astream("Who's the best quarterback in the NFL?"):
+    last_token = ""
+    async for token in llm.astream(
+        "Who's the best quarterback in the NFL?", stop=[","]
+    ):
+        last_token = token.content
         assert isinstance(token.content, str)
+    assert last_token[-1] == ","
diff --git a/libs/langchain/tests/integration_tests/llms/test_fireworks.py b/libs/langchain/tests/integration_tests/llms/test_fireworks.py
index 63ec6a4282d2a..dfddb1224e83f 100644
--- a/libs/langchain/tests/integration_tests/llms/test_fireworks.py
+++ b/libs/langchain/tests/integration_tests/llms/test_fireworks.py
@@ -41,6 +41,60 @@ def test_fireworks_model_param() -> None:
     assert llm.model == "foo"
 
 
+def test_fireworks_invoke() -> None:
+    """Tests completion with invoke."""
+    llm = Fireworks()
+    output = llm.invoke("How is the weather in New York today?", stop=[","])
+    assert isinstance(output, str)
+    assert output[-1] == ","
+
+
+@pytest.mark.asyncio
+async def test_fireworks_ainvoke() -> None:
+    """Tests completion with ainvoke."""
+    llm = Fireworks()
+    output = await llm.ainvoke("How is the weather in New York today?", stop=[","])
+    assert isinstance(output, str)
+    assert output[-1] == ","
+
+
+def test_fireworks_batch() -> None:
+    """Tests completion with batch."""
+    llm = Fireworks()
+    output = llm.batch(
+        [
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+        ],
+        stop=[","],
+    )
+    for token in output:
+        assert isinstance(token, str)
+        assert token[-1] == ","
+
+
+@pytest.mark.asyncio
+async def test_fireworks_abatch() -> None:
+    """Tests completion with abatch."""
+    llm = Fireworks()
+    output = await llm.abatch(
+        [
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+        ],
+        stop=[","],
+    )
+    for token in output:
+        assert isinstance(token, str)
+        assert token[-1] == ","
+
+
 def test_fireworks_multiple_prompts() -> None:
     """Test completion with multiple prompts."""
     llm = Fireworks()
@@ -87,8 +141,13 @@ async def test_fireworks_streaming_async() -> None:
     """Test stream completion."""
     llm = Fireworks()
 
-    async for token in llm.astream("Who's the best quarterback in the NFL?"):
+    last_token = ""
+    async for token in llm.astream(
+        "Who's the best quarterback in the NFL?", stop=[","]
+    ):
+        last_token = token
         assert isinstance(token, str)
+    assert last_token[-1] == ","
 
 
 @pytest.mark.asyncio
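
Usage sketch: the behavior these tests exercise is that `stop` sequences are now honored by invoke and by batch/abatch on both the Fireworks LLM and ChatFireworks, because `stop` is forwarded to the batched retry call instead of being folded into the default params. This is a minimal sketch, assuming the FIREWORKS_API_KEY environment variable is set and using the import paths of the modules touched above; outputs come from the live API.

    # Sketch only: mirrors the new integration tests in this patch.
    # Assumes FIREWORKS_API_KEY is set and the fireworks dependency is installed.
    from langchain.chat_models.fireworks import ChatFireworks
    from langchain.llms.fireworks import Fireworks

    llm = Fireworks()
    # `stop` is now passed through to the batched retry call, so each
    # completion in the batch is expected to end at the first comma.
    outputs = llm.batch(
        ["How is the weather in New York today?"] * 3,
        stop=[","],
    )
    for output in outputs:
        assert output[-1] == ","

    chat = ChatFireworks()
    # The same stop sequence applies to chat invocations.
    result = chat.invoke("How is the weather in New York today?", stop=[","])
    assert result.content[-1] == ","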