diff --git a/pyproject.toml b/pyproject.toml
index d5d169a..09b9884 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,7 +42,13 @@ dependencies = [
 ]
 
 [project.optional-dependencies]
-dev = ["pre-commit", "pytest", "pytest-env", "ruff"]
+dev = [
+    "pre-commit",
+    "pytest",
+    "pytest-env",
+    "pytest-asyncio",
+    "ruff",
+]
 
 [tool.ruff]
 line-length = 100
@@ -53,6 +59,7 @@ extend-select = ["I", "U"]
 
 [tool.pytest.ini_options]
 pythonpath = ["src"]
+asyncio_default_fixture_loop_scope = "function"
 
 [tool.pytest_env]
 OPENAI_API_KEY = "sk-fake-openai-key"
diff --git a/src/core/llm.py b/src/core/llm.py
index 9b2a57c..08f112e 100644
--- a/src/core/llm.py
+++ b/src/core/llm.py
@@ -40,6 +40,8 @@ def get_model(model_name: AllModelEnum, /) -> ModelT:
     # NOTE: models with streaming=True will send tokens as they are generated
     # if the /stream endpoint is called with stream_tokens=True (the default)
     api_model_name = _MODEL_TABLE.get(model_name)
+    if not api_model_name:
+        raise ValueError(f"Unsupported model: {model_name}")
 
     if model_name in OpenAIModelName:
         return ChatOpenAI(model=api_model_name, temperature=0.5, streaming=True)
@@ -55,4 +57,3 @@ def get_model(model_name: AllModelEnum, /) -> ModelT:
         return ChatBedrock(model_id=api_model_name, temperature=0.5)
     if model_name in FakeModelName:
         return FakeListChatModel(responses=["This is a test response from the fake model."])
-    raise ValueError(f"Unsupported model: {model_name}")
diff --git a/src/schema/schema.py b/src/schema/schema.py
index 7800829..bc14f27 100644
--- a/src/schema/schema.py
+++ b/src/schema/schema.py
@@ -13,7 +13,7 @@ class UserInput(BaseModel):
         description="User input to the agent.",
         examples=["What is the weather in Tokyo?"],
     )
-    model: SerializeAsAny[AllModelEnum] = Field(
+    model: SerializeAsAny[AllModelEnum] | None = Field(
         title="Model",
         description="LLM Model to use for the agent.",
         default="gpt-4o-mini",
diff --git a/src/service/service.py b/src/service/service.py
index b078954..130b45d 100644
--- a/src/service/service.py
+++ b/src/service/service.py
@@ -39,17 +39,17 @@
 def verify_bearer(
     http_auth: Annotated[
-        HTTPAuthorizationCredentials,
-        Depends(HTTPBearer(description="Please provide AUTH_SECRET api key.")),
+        HTTPAuthorizationCredentials | None,
+        Depends(HTTPBearer(description="Please provide AUTH_SECRET api key.", auto_error=False)),
     ],
 ) -> None:
-    if http_auth.credentials != settings.AUTH_SECRET.get_secret_value():
+    if not settings.AUTH_SECRET:
+        return
+    auth_secret = settings.AUTH_SECRET.get_secret_value()
+    if not http_auth or http_auth.credentials != auth_secret:
         raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
 
 
-bearer_depend = [Depends(verify_bearer)] if settings.AUTH_SECRET else None
-
-
 @asynccontextmanager
 async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
     # Construct agent with Sqlite checkpointer
@@ -62,7 +62,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
 
 
 app = FastAPI(lifespan=lifespan)
-router = APIRouter(dependencies=bearer_depend)
+router = APIRouter(dependencies=[Depends(verify_bearer)])
 
 
 def _parse_input(user_input: UserInput) -> tuple[dict[str, Any], UUID]:
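Note on the service.py change above: auto_error=False makes HTTPBearer yield None instead of raising a 403 when the Authorization header is missing, so the handler decides the outcome, and moving the check inside verify_bearer (rather than the old import-time bearer_depend) means settings.AUTH_SECRET is read per request, which is what lets the new auth tests toggle it at runtime. A minimal self-contained sketch of the same pattern; SECRET here is an illustrative stand-in for settings.AUTH_SECRET, and the dependency is attached app-wide for brevity where the patch attaches it to the APIRouter:

    from typing import Annotated

    from fastapi import Depends, FastAPI, HTTPException, status
    from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

    SECRET: str | None = "change-me"  # stand-in for settings.AUTH_SECRET


    def verify_bearer(
        http_auth: Annotated[
            HTTPAuthorizationCredentials | None,
            Depends(HTTPBearer(auto_error=False)),  # missing header -> None, not 403
        ],
    ) -> None:
        if not SECRET:  # auth disabled: every request passes
            return
        if not http_auth or http_auth.credentials != SECRET:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)


    app = FastAPI(dependencies=[Depends(verify_bearer)])

Because the secret is consulted inside the dependency body on every request, a test can flip auth on or off by patching the settings object after import, which the old module-level conditional made impossible.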
diff --git a/tests/client/conftest.py b/tests/client/conftest.py
new file mode 100644
index 0000000..a9bbbb4
--- /dev/null
+++ b/tests/client/conftest.py
@@ -0,0 +1,9 @@
+import pytest
+
+from client import AgentClient
+
+
+@pytest.fixture
+def agent_client(mock_env):
+    """Fixture for creating a test client with a clean environment."""
+    return AgentClient(base_url="http://test", agent="test-agent")
diff --git a/tests/client/test_client.py b/tests/client/test_client.py
new file mode 100644
index 0000000..817ffd7
--- /dev/null
+++ b/tests/client/test_client.py
@@ -0,0 +1,279 @@
+import json
+import os
+from unittest.mock import AsyncMock, Mock, patch
+
+import pytest
+from httpx import Response
+
+from client import AgentClient
+from schema import ChatHistory, ChatMessage
+
+
+def test_init(mock_env):
+    """Test client initialization with different parameters."""
+    # Test default values
+    client = AgentClient()
+    assert client.base_url == "http://localhost"
+    assert client.agent == "research-assistant"
+    assert client.timeout is None
+
+    # Test custom values
+    client = AgentClient(
+        base_url="http://test",
+        agent="custom-agent",
+        timeout=30.0,
+    )
+    assert client.base_url == "http://test"
+    assert client.agent == "custom-agent"
+    assert client.timeout == 30.0
+
+
+def test_headers(mock_env):
+    """Test header generation with and without auth."""
+    # Test without auth
+    client = AgentClient()
+    assert client._headers == {}
+
+    # Test with auth
+    with patch.dict(os.environ, {"AUTH_SECRET": "test-secret"}, clear=True):
+        client = AgentClient()
+        assert client._headers == {"Authorization": "Bearer test-secret"}
+
+
+def test_invoke(agent_client):
+    """Test synchronous invocation."""
+    QUESTION = "What is the weather?"
+    ANSWER = "The weather is sunny."
+
+    # Mock successful response
+    mock_response = Response(
+        200,
+        json={"type": "ai", "content": ANSWER},
+    )
+    with patch("httpx.post", return_value=mock_response):
+        response = agent_client.invoke(QUESTION)
+        assert isinstance(response, ChatMessage)
+        assert response.type == "ai"
+        assert response.content == ANSWER
+
+    # Test with model and thread_id
+    with patch("httpx.post", return_value=mock_response) as mock_post:
+        response = agent_client.invoke(
+            QUESTION,
+            model="gpt-4o",
+            thread_id="test-thread",
+        )
+        assert isinstance(response, ChatMessage)
+        # Verify request
+        args, kwargs = mock_post.call_args
+        assert kwargs["json"]["message"] == QUESTION
+        assert kwargs["json"]["model"] == "gpt-4o"
+        assert kwargs["json"]["thread_id"] == "test-thread"
+
+    # Test error response
+    error_response = Response(500, text="Internal Server Error")
+    with patch("httpx.post", return_value=error_response):
+        with pytest.raises(Exception) as exc:
+            agent_client.invoke(QUESTION)
+        assert "Error: 500" in str(exc.value)
+
+
+@pytest.mark.asyncio
+async def test_ainvoke(agent_client):
+    """Test asynchronous invocation."""
+    QUESTION = "What is the weather?"
+    ANSWER = "The weather is sunny."
+
+    # Test successful response
+    mock_response = Response(200, json={"type": "ai", "content": ANSWER})
+    with patch("httpx.AsyncClient.post", return_value=mock_response):
+        response = await agent_client.ainvoke(QUESTION)
+        assert isinstance(response, ChatMessage)
+        assert response.type == "ai"
+        assert response.content == ANSWER
+
+    # Test with model and thread_id
+    with patch("httpx.AsyncClient.post", return_value=mock_response) as mock_post:
+        response = await agent_client.ainvoke(
+            QUESTION,
+            model="gpt-4o",
+            thread_id="test-thread",
+        )
+        assert isinstance(response, ChatMessage)
+        assert response.type == "ai"
+        assert response.content == ANSWER
+        # Verify request
+        args, kwargs = mock_post.call_args
+        assert kwargs["json"]["message"] == QUESTION
+        assert kwargs["json"]["model"] == "gpt-4o"
+        assert kwargs["json"]["thread_id"] == "test-thread"
+
+    # Test error response
+    with patch("httpx.AsyncClient.post", return_value=Response(500, text="Internal Server Error")):
+        with pytest.raises(Exception) as exc:
+            await agent_client.ainvoke(QUESTION)
+        assert "Error: 500" in str(exc.value)
+
+
+def test_stream(agent_client):
+    """Test synchronous streaming."""
+    QUESTION = "What is the weather?"
+    TOKENS = ["The", " weather", " is", " sunny", "."]
+    FINAL_ANSWER = "The weather is sunny."
+
+    # Create mock response with streaming events
+    events = (
+        [f"data: {json.dumps({'type': 'token', 'content': token})}" for token in TOKENS]
+        + [
+            f"data: {json.dumps({'type': 'message', 'content': {'type': 'ai', 'content': FINAL_ANSWER}})}"
+        ]
+        + ["data: [DONE]"]
+    )
+
+    # Mock the streaming response
+    mock_response = Mock()
+    mock_response.status_code = 200
+    mock_response.iter_lines.return_value = events
+    mock_response.__enter__ = Mock(return_value=mock_response)
+    mock_response.__exit__ = Mock(return_value=None)
+
+    with patch("httpx.stream", return_value=mock_response):
+        # Collect all streamed responses
+        responses = list(agent_client.stream(QUESTION))
+
+        # Verify tokens were streamed
+        assert len(responses) == len(TOKENS) + 1  # tokens + final message
+        for i, token in enumerate(TOKENS):
+            assert responses[i] == token
+
+        # Verify final message
+        final_message = responses[-1]
+        assert isinstance(final_message, ChatMessage)
+        assert final_message.type == "ai"
+        assert final_message.content == FINAL_ANSWER
+
+    # Test error response
+    error_response = Mock()
+    error_response.status_code = 500
+    error_response.text = "Internal Server Error"
+    error_response.__enter__ = Mock(return_value=error_response)
+    error_response.__exit__ = Mock(return_value=None)
+    with patch("httpx.stream", return_value=error_response):
+        with pytest.raises(Exception) as exc:
+            list(agent_client.stream(QUESTION))
+        assert "Error: 500" in str(exc.value)
+
+
+@pytest.mark.asyncio
+async def test_astream(agent_client):
+    """Test asynchronous streaming."""
+    QUESTION = "What is the weather?"
+    TOKENS = ["The", " weather", " is", " sunny", "."]
+    FINAL_ANSWER = "The weather is sunny."
+
+    # Create mock response with streaming events
+    events = (
+        [f"data: {json.dumps({'type': 'token', 'content': token})}" for token in TOKENS]
+        + [
+            f"data: {json.dumps({'type': 'message', 'content': {'type': 'ai', 'content': FINAL_ANSWER}})}"
+        ]
+        + ["data: [DONE]"]
+    )
+
+    # Create an async iterator for the events
+    async def async_events():
+        for event in events:
+            yield event
+
+    # Mock the streaming response
+    mock_response = AsyncMock()
+    mock_response.status_code = 200
+    mock_response.aiter_lines = Mock(return_value=async_events())
+    mock_response.__aenter__ = AsyncMock(return_value=mock_response)
+
+    mock_client = AsyncMock()
+    mock_client.__aenter__.return_value = mock_client
+    mock_client.stream = Mock(return_value=mock_response)
+
+    with patch("httpx.AsyncClient", return_value=mock_client):
+        # Collect all streamed responses
+        responses = []
+        async for response in agent_client.astream(QUESTION):
+            responses.append(response)
+
+        # Verify tokens were streamed
+        assert len(responses) == len(TOKENS) + 1  # tokens + final message
+        for i, token in enumerate(TOKENS):
+            assert responses[i] == token
+
+        # Verify final message
+        final_message = responses[-1]
+        assert isinstance(final_message, ChatMessage)
+        assert final_message.type == "ai"
+        assert final_message.content == FINAL_ANSWER
+
+    # Test error response
+    error_response = AsyncMock()
+    error_response.status_code = 500
+    error_response.text = "Internal Server Error"
+    error_response.__aenter__ = AsyncMock(return_value=error_response)
+
+    mock_client.stream.return_value = error_response
+
+    with patch("httpx.AsyncClient", return_value=mock_client):
+        with pytest.raises(Exception) as exc:
+            async for _ in agent_client.astream(QUESTION):
+                pass
+        assert "Error: 500" in str(exc.value)
+
+
+@pytest.mark.asyncio
+async def test_acreate_feedback(agent_client):
+    """Test asynchronous feedback creation."""
+    RUN_ID = "test-run"
+    KEY = "test-key"
+    SCORE = 0.8
+    KWARGS = {"comment": "Great response!"}
+
+    # Test successful response
+    with patch("httpx.AsyncClient.post", return_value=Response(200, json={})) as mock_post:
+        await agent_client.acreate_feedback(RUN_ID, KEY, SCORE, KWARGS)
+        # Verify request
+        args, kwargs = mock_post.call_args
+        assert kwargs["json"]["run_id"] == RUN_ID
+        assert kwargs["json"]["key"] == KEY
+        assert kwargs["json"]["score"] == SCORE
+        assert kwargs["json"]["kwargs"] == KWARGS
+
+    # Test error response
+    with patch("httpx.AsyncClient.post", return_value=Response(500, text="Internal Server Error")):
+        with pytest.raises(Exception) as exc:
+            await agent_client.acreate_feedback(RUN_ID, KEY, SCORE)
+        assert "Error: 500" in str(exc.value)
+
+
+def test_get_history(agent_client):
+    """Test chat history retrieval."""
+    THREAD_ID = "test-thread"
+    HISTORY = {
+        "messages": [
+            {"type": "human", "content": "What is the weather?"},
+            {"type": "ai", "content": "The weather is sunny."},
+        ]
+    }
+
+    # Mock successful response
+    mock_response = Response(200, json=HISTORY)
+    with patch("httpx.post", return_value=mock_response):
+        history = agent_client.get_history(THREAD_ID)
+        assert isinstance(history, ChatHistory)
+        assert len(history.messages) == 2
+        assert history.messages[0].type == "human"
+        assert history.messages[1].type == "ai"
+
+    # Test error response
+    error_response = Response(500, text="Internal Server Error")
+    with patch("httpx.post", return_value=error_response):
+        with pytest.raises(Exception) as exc:
+            agent_client.get_history(THREAD_ID)
+        assert "Error: 500" in str(exc.value)
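One subtlety worth calling out in test_astream above: httpx.AsyncClient.stream() is a plain method that returns an async context manager, not a coroutine, which is why the test wires it up as Mock(return_value=...) rather than AsyncMock. A stripped-down illustration of the difference, with illustrative names only:

    import asyncio
    from unittest.mock import AsyncMock, Mock

    response = AsyncMock()
    response.status_code = 200
    response.__aenter__ = AsyncMock(return_value=response)

    client = AsyncMock()
    # Correct: calling client.stream(...) must return the context manager directly.
    client.stream = Mock(return_value=response)


    async def demo() -> None:
        async with client.stream("POST", "/stream") as r:
            assert r.status_code == 200
        # Had client.stream been an AsyncMock, client.stream(...) would return
        # a coroutine, and `async with` on a coroutine fails because it does not
        # implement the asynchronous context manager protocol.


    asyncio.run(demo())

The explicit __aenter__ assignment is needed because an AsyncMock's auto-created __aenter__ would otherwise return a fresh mock rather than the response object itself.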
diff --git a/tests/conftest.py b/tests/conftest.py
index b5e8729..d4fd1a3 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,3 +1,6 @@
+import os
+from unittest.mock import patch
+
 import pytest
 
 
@@ -17,3 +20,10 @@ def pytest_collection_modifyitems(config, items):
     for item in items:
         if "docker" in item.keywords:
             item.add_marker(skip_docker)
+
+
+@pytest.fixture
+def mock_env():
+    """Fixture to ensure environment is clean for each test."""
+    with patch.dict(os.environ, {}, clear=True):
+        yield
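The mock_env fixture underpins most of the new tests, so it is worth spelling out what patch.dict(os.environ, {}, clear=True) actually does: the process environment is emptied for the duration of the with block and fully restored on exit. A quick standalone demonstration; DEMO_VAR is just an illustrative name:

    import os
    from unittest.mock import patch

    os.environ["DEMO_VAR"] = "set"
    with patch.dict(os.environ, {}, clear=True):
        assert "DEMO_VAR" not in os.environ  # wiped inside the block
    assert os.environ["DEMO_VAR"] == "set"  # restored on exit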
diff --git a/tests/core/test_llm.py b/tests/core/test_llm.py
new file mode 100644
index 0000000..8163243
--- /dev/null
+++ b/tests/core/test_llm.py
@@ -0,0 +1,62 @@
+import os
+from unittest.mock import patch
+
+import pytest
+from langchain_anthropic import ChatAnthropic
+from langchain_community.chat_models import FakeListChatModel
+from langchain_groq import ChatGroq
+from langchain_openai import ChatOpenAI
+
+from core.llm import get_model
+from schema.models import (
+    AnthropicModelName,
+    FakeModelName,
+    GroqModelName,
+    OpenAIModelName,
+)
+
+
+def test_get_model_openai():
+    with patch.dict(os.environ, {"OPENAI_API_KEY": "test_key"}):
+        model = get_model(OpenAIModelName.GPT_4O_MINI)
+        assert isinstance(model, ChatOpenAI)
+        assert model.model_name == "gpt-4o-mini"
+        assert model.temperature == 0.5
+        assert model.streaming is True
+
+
+def test_get_model_anthropic():
+    with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test_key"}):
+        model = get_model(AnthropicModelName.HAIKU_3)
+        assert isinstance(model, ChatAnthropic)
+        assert model.model == "claude-3-haiku-20240307"
+        assert model.temperature == 0.5
+        assert model.streaming is True
+
+
+def test_get_model_groq():
+    with patch.dict(os.environ, {"GROQ_API_KEY": "test_key"}):
+        model = get_model(GroqModelName.LLAMA_31_8B)
+        assert isinstance(model, ChatGroq)
+        assert model.model_name == "llama-3.1-8b-instant"
+        assert model.temperature == 0.5
+
+
+def test_get_model_groq_guard():
+    with patch.dict(os.environ, {"GROQ_API_KEY": "test_key"}):
+        model = get_model(GroqModelName.LLAMA_GUARD_3_8B)
+        assert isinstance(model, ChatGroq)
+        assert model.model_name == "llama-guard-3-8b"
+        assert model.temperature < 0.01
+
+
+def test_get_model_fake():
+    model = get_model(FakeModelName.FAKE)
+    assert isinstance(model, FakeListChatModel)
+    assert model.responses == ["This is a test response from the fake model."]
+
+
+def test_get_model_invalid():
+    with pytest.raises(ValueError, match="Unsupported model:"):
+        # Using type: ignore since we're intentionally testing invalid input
+        get_model("invalid_model")  # type: ignore
diff --git a/tests/core/test_settings.py b/tests/core/test_settings.py
new file mode 100644
index 0000000..5145e49
--- /dev/null
+++ b/tests/core/test_settings.py
@@ -0,0 +1,62 @@
+import os
+from unittest.mock import patch
+
+import pytest
+from pydantic import SecretStr, ValidationError
+
+from core.settings import Settings, check_str_is_http
+from schema.models import AnthropicModelName, OpenAIModelName
+
+
+def test_check_str_is_http():
+    # Test valid HTTP URLs
+    assert check_str_is_http("http://example.com/") == "http://example.com/"
+    assert check_str_is_http("https://api.test.com/") == "https://api.test.com/"
+
+    # Test invalid URLs
+    with pytest.raises(ValidationError):
+        check_str_is_http("not_a_url")
+    with pytest.raises(ValidationError):
+        check_str_is_http("ftp://invalid.com")
+
+
+def test_settings_default_values():
+    settings = Settings(_env_file=None)
+    assert settings.HOST == "0.0.0.0"
+    assert settings.PORT == 80
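Every Settings(...) construction in this file passes _env_file=None; that is the pydantic-settings switch that disables .env loading, so a developer's local .env cannot leak into the assertions. A minimal sketch of the pattern with a stand-in class; DemoSettings is illustrative, not the project's Settings:

    from pydantic_settings import BaseSettings


    class DemoSettings(BaseSettings):  # stand-in for core.settings.Settings
        HOST: str = "0.0.0.0"
        PORT: int = 80


    settings = DemoSettings(_env_file=None)  # skip any .env file entirely
    assert settings.PORT == 80

Combined with the mock_env fixture and patch.dict(..., clear=True), this makes the settings tests hermetic: neither dotfiles nor the ambient environment can change their outcome.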
+    assert settings.USE_AWS_BEDROCK is False
+    assert settings.USE_FAKE_MODEL is False
+
+
+def test_settings_no_api_keys():
+    # Test that settings raises error when no API keys are provided
+    with patch.dict(os.environ, {}, clear=True):
+        with pytest.raises(ValueError, match="At least one LLM API key must be provided"):
+            _ = Settings(_env_file=None)
+
+
+def test_settings_with_openai_key():
+    with patch.dict(os.environ, {"OPENAI_API_KEY": "test_key"}, clear=True):
+        settings = Settings(_env_file=None)
+        assert settings.OPENAI_API_KEY == SecretStr("test_key")
+        assert settings.DEFAULT_MODEL == OpenAIModelName.GPT_4O_MINI
+
+
+def test_settings_with_anthropic_key():
+    with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test_key"}, clear=True):
+        settings = Settings(_env_file=None)
+        assert settings.ANTHROPIC_API_KEY == SecretStr("test_key")
+        assert settings.DEFAULT_MODEL == AnthropicModelName.HAIKU_3
+
+
+def test_settings_base_url():
+    settings = Settings(HOST="localhost", PORT=8000, _env_file=None)
+    assert settings.BASE_URL == "http://localhost:8000"
+
+
+def test_settings_is_dev():
+    settings = Settings(MODE="dev", _env_file=None)
+    assert settings.is_dev() is True
+
+    settings = Settings(MODE="prod", _env_file=None)
+    assert settings.is_dev() is False
diff --git a/tests/service/test_auth.py b/tests/service/test_auth.py
new file mode 100644
index 0000000..f7d989d
--- /dev/null
+++ b/tests/service/test_auth.py
@@ -0,0 +1,42 @@
+from pydantic import SecretStr
+
+
+def test_no_auth_secret(mock_settings, mock_agent, test_client):
+    """Test that when AUTH_SECRET is not set, all requests are allowed"""
+    mock_settings.AUTH_SECRET = None
+    response = test_client.post(
+        "/invoke",
+        json={"message": "test"},
+        headers={"Authorization": "Bearer any-token"},
+    )
+    assert response.status_code == 200
+
+    # Should also work without any auth header
+    response = test_client.post("/invoke", json={"message": "test"})
+    assert response.status_code == 200
+
+
+def test_auth_secret_correct(mock_settings, mock_agent, test_client):
+    """Test that when AUTH_SECRET is set, requests with correct token are allowed"""
+    mock_settings.AUTH_SECRET = SecretStr("test-secret")
+    response = test_client.post(
+        "/invoke",
+        json={"message": "test"},
+        headers={"Authorization": "Bearer test-secret"},
+    )
+    assert response.status_code == 200
+
+
+def test_auth_secret_incorrect(mock_settings, mock_agent, test_client):
+    """Test that when AUTH_SECRET is set, requests with wrong token are rejected"""
+    mock_settings.AUTH_SECRET = SecretStr("test-secret")
+    response = test_client.post(
+        "/invoke",
+        json={"message": "test"},
+        headers={"Authorization": "Bearer wrong-secret"},
+    )
+    assert response.status_code == 401
+
+    # Should also reject requests with no auth header
+    response = test_client.post("/invoke", json={"message": "test"})
+    assert response.status_code == 401
+ + # Create a separate mock for the default agent + default_mock = AsyncMock() + default_mock.ainvoke.return_value = {"messages": [AIMessage(content=DEFAULT_ANSWER)]} + + # Configure our custom mock agent + mock_agent.ainvoke.return_value = {"messages": [AIMessage(content=CUSTOM_ANSWER)]} + + # Patch the agents dictionary to include both agents + with patch("service.service.agents", {CUSTOM_AGENT: mock_agent, DEFAULT_AGENT: default_mock}): + response = test_client.post(f"/{CUSTOM_AGENT}/invoke", json={"message": QUESTION}) + assert response.status_code == 200 + + # Verify custom agent was called and default wasn't + mock_agent.ainvoke.assert_awaited_once() + default_mock.ainvoke.assert_not_awaited() + + input_message = mock_agent.ainvoke.await_args.kwargs["input"]["messages"][0] + assert input_message.content == QUESTION + + output = ChatMessage.model_validate(response.json()) + assert output.type == "ai" + assert output.content == CUSTOM_ANSWER # Verify we got the custom agent's response + + +def test_invoke_model_param(test_client, mock_agent) -> None: + """Test that the model parameter is correctly passed to the agent.""" + QUESTION = "What is the weather in Tokyo?" + ANSWER = "The weather in Tokyo is sunny." + CUSTOM_MODEL = "claude-3.5-sonnet" + mock_agent.ainvoke.return_value = {"messages": [AIMessage(content=ANSWER)]} + + response = test_client.post("/invoke", json={"message": QUESTION, "model": CUSTOM_MODEL}) + assert response.status_code == 200 + + # Verify the model was passed correctly in the config + mock_agent.ainvoke.assert_awaited_once() + config = mock_agent.ainvoke.await_args.kwargs["config"] + assert config["configurable"]["model"] == CUSTOM_MODEL + + # Verify the response is still correct + output = ChatMessage.model_validate(response.json()) + assert output.type == "ai" + assert output.content == ANSWER + + # Verify an invalid model throws a validation error + INVALID_MODEL = "gpt-7-notreal" + response = test_client.post("/invoke", json={"message": QUESTION, "model": INVALID_MODEL}) + assert response.status_code == 422 + + @patch("service.service.LangsmithClient") -def test_feedback(mock_client: langsmith.Client) -> None: +def test_feedback(mock_client: langsmith.Client, test_client) -> None: ls_instance = mock_client.return_value ls_instance.create_feedback.return_value = None body = { @@ -52,33 +104,133 @@ def test_feedback(mock_client: langsmith.Client) -> None: ) -def test_history() -> None: +def test_history(test_client, mock_agent) -> None: QUESTION = "What is the weather in Tokyo?" ANSWER = "The weather in Tokyo is 70 degrees." 
user_question = HumanMessage(content=QUESTION) agent_response = AIMessage(content=ANSWER) - agent_mock = AsyncMock() - agent_mock.get_state = Mock( - return_value=StateSnapshot( - values={"messages": [user_question, agent_response]}, - next=(), - config={}, - metadata=None, - created_at=None, - parent_config=None, - tasks=(), - ) + mock_agent.get_state.return_value = StateSnapshot( + values={"messages": [user_question, agent_response]}, + next=(), + config={}, + metadata=None, + created_at=None, + parent_config=None, + tasks=(), ) - with patch.dict("service.service.agents", {DEFAULT_AGENT: agent_mock}): - with test_client as c: - response = c.post( - "/history", json={"thread_id": "7bcc7cc1-99d7-4b1d-bdb5-e6f90ed44de6"} - ) - assert response.status_code == 200 + response = test_client.post( + "/history", json={"thread_id": "7bcc7cc1-99d7-4b1d-bdb5-e6f90ed44de6"} + ) + assert response.status_code == 200 output = ChatHistory.model_validate(response.json()) assert output.messages[0].type == "human" assert output.messages[0].content == QUESTION assert output.messages[1].type == "ai" assert output.messages[1].content == ANSWER + + +@pytest.mark.asyncio +async def test_stream(test_client, mock_agent) -> None: + """Test streaming tokens and messages.""" + QUESTION = "What is the weather in Tokyo?" + TOKENS = ["The", " weather", " in", " Tokyo", " is", " sunny", "."] + FINAL_ANSWER = "The weather in Tokyo is sunny." + + # Configure mock to use our async iterator function + events = [ + { + "event": "on_chat_model_stream", + "data": {"chunk": SimpleNamespace(content=token)}, + "tags": [], + } + for token in TOKENS + ] + [ + { + "event": "on_chain_end", + "data": {"output": {"messages": [AIMessage(content=FINAL_ANSWER)]}}, + "tags": ["graph:step:1"], + } + ] + + async def mock_astream_events(**kwargs): + for event in events: + yield event + + mock_agent.astream_events = mock_astream_events + + # Make request with streaming + with test_client.stream( + "POST", "/stream", json={"message": QUESTION, "stream_tokens": True} + ) as response: + assert response.status_code == 200 + + # Collect all SSE messages + messages = [] + for line in response.iter_lines(): + if line and line.strip() != "data: [DONE]": # Skip [DONE] message + messages.append(json.loads(line.lstrip("data: "))) + + # Verify streamed tokens + token_messages = [msg for msg in messages if msg["type"] == "token"] + assert len(token_messages) == len(TOKENS) + for i, msg in enumerate(token_messages): + assert msg["content"] == TOKENS[i] + + # Verify final message + final_messages = [msg for msg in messages if msg["type"] == "message"] + assert len(final_messages) == 1 + assert final_messages[0]["content"]["content"] == FINAL_ANSWER + assert final_messages[0]["content"]["type"] == "ai" + + +@pytest.mark.asyncio +async def test_stream_no_tokens(test_client, mock_agent) -> None: + """Test streaming without tokens.""" + QUESTION = "What is the weather in Tokyo?" + FINAL_ANSWER = "The weather in Tokyo is sunny." 
+ + # Configure mock to use our async iterator function + events = [ + { + "event": "on_chat_model_stream", + "data": {"chunk": SimpleNamespace(content=token)}, + "tags": [], + } + for token in ["The", " weather", " in", " Tokyo", " is", " sunny", "."] + ] + [ + { + "event": "on_chain_end", + "data": {"output": {"messages": [AIMessage(content=FINAL_ANSWER)]}}, + "tags": ["graph:step:1"], + } + ] + + async def mock_astream_events(**kwargs): + for event in events: + yield event + + mock_agent.astream_events = mock_astream_events + + # Make request with streaming disabled + with test_client.stream( + "POST", "/stream", json={"message": QUESTION, "stream_tokens": False} + ) as response: + assert response.status_code == 200 + + # Collect all SSE messages + messages = [] + for line in response.iter_lines(): + if line and line.strip() != "data: [DONE]": # Skip [DONE] message + messages.append(json.loads(line.lstrip("data: "))) + + # Verify no token messages + token_messages = [msg for msg in messages if msg["type"] == "token"] + assert len(token_messages) == 0 + + # Verify final message + final_messages = [msg for msg in messages if msg["type"] == "message"] + assert len(final_messages) == 1 + assert final_messages[0]["content"]["content"] == FINAL_ANSWER + assert final_messages[0]["content"]["type"] == "ai" diff --git a/uv.lock b/uv.lock index 317a222..6bd534c 100644 --- a/uv.lock +++ b/uv.lock @@ -41,6 +41,7 @@ dependencies = [ dev = [ { name = "pre-commit" }, { name = "pytest" }, + { name = "pytest-asyncio" }, { name = "pytest-env" }, { name = "ruff" }, ] @@ -67,6 +68,7 @@ requires-dist = [ { name = "pydantic-settings", specifier = "~=2.6.1" }, { name = "pyowm", specifier = "~=3.3.0" }, { name = "pytest", marker = "extra == 'dev'" }, + { name = "pytest-asyncio", marker = "extra == 'dev'" }, { name = "pytest-env", marker = "extra == 'dev'" }, { name = "python-dotenv", specifier = "~=1.0.1" }, { name = "ruff", marker = "extra == 'dev'" }, @@ -1754,6 +1756,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6b/77/7440a06a8ead44c7757a64362dd22df5760f9b12dc5f11b6188cd2fc27a0/pytest-8.3.3-py3-none-any.whl", hash = "sha256:a6853c7375b2663155079443d2e45de913a911a11d669df02a50814944db57b2", size = 342341 }, ] +[[package]] +name = "pytest-asyncio" +version = "0.24.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/52/6d/c6cf50ce320cf8611df7a1254d86233b3df7cc07f9b5f5cbcb82e08aa534/pytest_asyncio-0.24.0.tar.gz", hash = "sha256:d081d828e576d85f875399194281e92bf8a68d60d72d1a2faf2feddb6c46b276", size = 49855 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/96/31/6607dab48616902f76885dfcf62c08d929796fc3b2d2318faf9fd54dbed9/pytest_asyncio-0.24.0-py3-none-any.whl", hash = "sha256:a811296ed596b69bf0b6f3dc40f83bcaf341b155a269052d82efa2b25ac7037b", size = 18024 }, +] + [[package]] name = "pytest-env" version = "1.1.5"