chore: stream function calling tests (#4)
* feat: stream function calling

* chore: update stream auto tool_choice test
samuelint authored Jul 23, 2024
1 parent 1ed89d8 commit 7c454d4
Showing 2 changed files with 67 additions and 0 deletions.
3 changes: 3 additions & 0 deletions tests/test_functional/models_configuration.py
@@ -1,3 +1,4 @@
import multiprocessing
import os
from llama_cpp import Llama
from llama_cpp.server.app import LlamaProxy
@@ -50,6 +51,8 @@ def create_llama(params) -> Llama:
return Llama(
model_path=local_path,
n_gpu_layers=n_gpu_layers,
offload_kqv=True, # Equivalent of f16_kv=True
n_threads=multiprocessing.cpu_count() - 1,
chat_format="chatml-function-calling",
)

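For context on the two new parameters: offload_kqv=True keeps the key/value cache on the GPU next to the offloaded layers (the in-line comment above notes it stands in for the old f16_kv=True flag), and n_threads=multiprocessing.cpu_count() - 1 leaves one CPU core free for the host. A minimal standalone sketch of the resulting constructor call; the model path and n_gpu_layers value are placeholder assumptions, not values from this repo:

import multiprocessing

from llama_cpp import Llama

llama = Llama(
    model_path="model.gguf",                    # placeholder path (assumption)
    n_gpu_layers=-1,                            # offload every layer to the GPU (assumption)
    offload_kqv=True,                           # keep the KV cache on the GPU too
    n_threads=multiprocessing.cpu_count() - 1,  # leave one core for the host
    chat_format="chatml-function-calling",      # chat template with tool-call support
)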
64 changes: 64 additions & 0 deletions tests/test_functional/test_function_calling.py
@@ -66,3 +66,67 @@ def magic_number_tool(input: int) -> int:
)

assert result.tool_calls[0]["name"] == "magic_number_tool"


class TestFunctionCallingWithStream:
@pytest.fixture(
params=models_to_test, ids=[config["repo_id"] for config in models_to_test]
)
def llama(self, request) -> Llama:
return create_llama(request.param)

@pytest.fixture
def instance(self, llama):
return LlamaChatModel(llama=llama, temperature=0, streaming=True)
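(Outside pytest, the equivalent of these two fixtures is a two-line sketch, reusing the create_llama helper and models_to_test list from the configuration module above:)

# Non-pytest equivalent of the llama/instance fixtures, assuming a
# config dict taken from models_to_test.
llama = create_llama(models_to_test[0])
instance = LlamaChatModel(llama=llama, temperature=0, streaming=True)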

def test_force_function_calling(self, instance: LlamaChatModel):
@tool
def magic_number_tool(input: int) -> int:
"""Applies a magic function to an input."""
return input + 2

llm_with_tool = instance.bind_tools(
[magic_number_tool], tool_choice="magic_number_tool"
)

stream = llm_with_tool.stream(
[
HumanMessage(content="What is the magic mumber of 2?"),
]
)

tool_call_chunks = []
for chunk in stream:
if len(chunk.tool_call_chunks) > 0:
tool_call_chunks.extend(chunk.tool_call_chunks)

assert len(tool_call_chunks) > 0
assert tool_call_chunks[0]["name"] == "magic_number_tool"
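Beyond checking the first chunk's name, the streamed chunks can be merged back into a complete call. A sketch of that pattern, assuming langchain-core's chunk-addition semantics (AIMessageChunk supports +, and the accumulated chunk exposes parsed tool_calls):

# Accumulate the stream into a single message chunk, then read the
# fully parsed tool call from it. Assumes the model emits valid JSON
# arguments by the time the stream ends.
merged = None
for chunk in llm_with_tool.stream([HumanMessage(content="What is the magic number of 2?")]):
    merged = chunk if merged is None else merged + chunk

assert merged is not None
assert merged.tool_calls[0]["name"] == "magic_number_tool"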

@pytest.mark.skip(
reason="""\
Stream + auto tool choice not supported yet. \
https://github.com/abetlen/llama-cpp-python/discussions/1615\
"""
)
def test_auto_function_calling(self, instance: LlamaChatModel):
@tool
def magic_number_tool(input: int) -> int:
"""Applies a magic function to an input."""
return input + 2

llm_with_tool = instance.bind_tools([magic_number_tool], tool_choice="auto")

stream = llm_with_tool.stream(
[
HumanMessage(content="What is the magic mumber of 2?"),
]
)

tool_call_chunks = []
for chunk in stream:
if len(chunk.tool_call_chunks) > 0:
tool_call_chunks.extend(chunk.tool_call_chunks)

assert len(tool_call_chunks) > 0
assert tool_call_chunks[0]["name"] == "magic_number_tool"
