-
Notifications
You must be signed in to change notification settings - Fork 258
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Streamlit app tests with mocked client (#102)
* Add Streamlit app tests with mocked client * Minor cleanups
- Loading branch information
1 parent
dfd30f9
commit 8bfaa9a
Showing
2 changed files
with
152 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from unittest.mock import patch | ||
|
||
import pytest | ||
|
||
|
||
@pytest.fixture | ||
def mock_agent_client(mock_env): | ||
"""Fixture for creating a mock AgentClient with a clean environment.""" | ||
|
||
with patch("client.AgentClient") as mock_agent_client: | ||
mock_agent_client_instance = mock_agent_client.return_value | ||
yield mock_agent_client_instance |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
from collections.abc import AsyncGenerator | ||
from unittest.mock import AsyncMock, Mock | ||
|
||
import pytest | ||
from streamlit.testing.v1 import AppTest | ||
|
||
from schema import ChatHistory, ChatMessage | ||
from schema.models import AnthropicModelName | ||
|
||
|
||
def test_app_simple_non_streaming(mock_agent_client): | ||
"""Test the full app - happy path""" | ||
at = AppTest.from_file("../../src/streamlit_app.py").run() | ||
|
||
WELCOME_START = "Hello! I'm an AI-powered research assistant" | ||
PROMPT = "Know any jokes?" | ||
RESPONSE = "Sure! Here's a joke:" | ||
|
||
mock_agent_client.ainvoke = AsyncMock( | ||
return_value=ChatMessage(type="ai", content=RESPONSE), | ||
) | ||
|
||
assert at.chat_message[0].avatar == "assistant" | ||
assert at.chat_message[0].markdown[0].value.startswith(WELCOME_START) | ||
|
||
at.sidebar.toggle[0].set_value(False) # Use Streaming = False | ||
at.chat_input[0].set_value(PROMPT).run() | ||
print(at) | ||
assert at.chat_message[0].avatar == "user" | ||
assert at.chat_message[0].markdown[0].value == PROMPT | ||
assert at.chat_message[1].avatar == "assistant" | ||
assert at.chat_message[1].markdown[0].value == RESPONSE | ||
assert not at.exception | ||
|
||
|
||
def test_app_settings(mock_agent_client): | ||
"""Test the full app - happy path""" | ||
at = AppTest.from_file("../../src/streamlit_app.py").run() | ||
|
||
PROMPT = "Know any jokes?" | ||
RESPONSE = "Sure! Here's a joke:" | ||
|
||
mock_agent_client.ainvoke = AsyncMock( | ||
return_value=ChatMessage(type="ai", content=RESPONSE), | ||
) | ||
|
||
at.sidebar.toggle[0].set_value(False) # Use Streaming = False | ||
at.sidebar.radio[0].set_value("Claude 3 Haiku (streaming)") | ||
assert mock_agent_client.agent == "research-assistant" | ||
at.sidebar.selectbox[0].set_value("chatbot") | ||
at.chat_input[0].set_value(PROMPT).run() | ||
print(at) | ||
|
||
# Basic checks | ||
assert at.chat_message[0].avatar == "user" | ||
assert at.chat_message[0].markdown[0].value == PROMPT | ||
assert at.chat_message[1].avatar == "assistant" | ||
assert at.chat_message[1].markdown[0].value == RESPONSE | ||
|
||
# Check the args match the settings | ||
assert mock_agent_client.agent == "chatbot" | ||
mock_agent_client.ainvoke.assert_called_with( | ||
message=PROMPT, | ||
model=AnthropicModelName.HAIKU_3, | ||
thread_id="test session id", | ||
) | ||
assert not at.exception | ||
|
||
|
||
def test_app_thread_id_history(mock_agent_client): | ||
"""Test the thread_id is generated""" | ||
|
||
at = AppTest.from_file("../../src/streamlit_app.py").run() | ||
assert at.sidebar.markdown[-2].value == "Thread ID: **test session id**" | ||
|
||
# Reset and set thread_id | ||
at = AppTest.from_file("../../src/streamlit_app.py") | ||
at.query_params["thread_id"] = "1234" | ||
HISTORY = [ | ||
ChatMessage(type="human", content="What is the weather?"), | ||
ChatMessage(type="ai", content="The weather is sunny."), | ||
] | ||
mock_agent_client.get_history.return_value = ChatHistory(messages=HISTORY) | ||
at.run() | ||
print(at) | ||
assert at.sidebar.markdown[-2].value == "Thread ID: **1234**" | ||
mock_agent_client.get_history.assert_called_with(thread_id="1234") | ||
assert at.chat_message[0].avatar == "user" | ||
assert at.chat_message[0].markdown[0].value == "What is the weather?" | ||
assert at.chat_message[1].avatar == "assistant" | ||
assert at.chat_message[1].markdown[0].value == "The weather is sunny." | ||
assert not at.exception | ||
|
||
|
||
def test_app_feedback(mock_agent_client): | ||
"""TODO: Can't figure out how to interact with st.feedback""" | ||
|
||
pass | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_app_streaming(mock_agent_client): | ||
"""Test the app with streaming enabled - including tool messages""" | ||
at = AppTest.from_file("../../src/streamlit_app.py").run() | ||
|
||
# Setup mock streaming response | ||
PROMPT = "What is 6 * 7?" | ||
ai_with_tool = ChatMessage( | ||
type="ai", | ||
content="", | ||
tool_calls=[{"name": "calculator", "id": "test_call_id", "args": {"expression": "6 * 7"}}], | ||
) | ||
tool_message = ChatMessage(type="tool", content="42", tool_call_id="test_call_id") | ||
final_ai_message = ChatMessage(type="ai", content="The answer is 42") | ||
|
||
messages = [ai_with_tool, tool_message, final_ai_message] | ||
|
||
async def amessage_iter() -> AsyncGenerator[ChatMessage, None]: | ||
for m in messages: | ||
yield m | ||
|
||
mock_agent_client.astream = Mock(return_value=amessage_iter()) | ||
|
||
at.toggle[0].set_value(True) # Use Streaming = True | ||
at.chat_input[0].set_value(PROMPT).run() | ||
print(at) | ||
|
||
assert at.chat_message[0].avatar == "user" | ||
assert at.chat_message[0].markdown[0].value == PROMPT | ||
response = at.chat_message[1] | ||
tool_status = response.status[0] | ||
assert response.avatar == "assistant" | ||
assert tool_status.label == "Tool Call: calculator" | ||
assert tool_status.icon == ":material/check:" | ||
assert tool_status.markdown[0].value == "Input:" | ||
assert tool_status.json[0].value == '{"expression": "6 * 7"}' | ||
assert tool_status.markdown[1].value == "Output:" | ||
assert tool_status.markdown[2].value == "42" | ||
assert response.markdown[-1].value == "The answer is 42" | ||
assert not at.exception |