From fff0a3feabda406bd4e18c007b323c05c8c39b3b Mon Sep 17 00:00:00 2001 From: Georgiy Tarasov Date: Mon, 13 Jan 2025 16:59:51 +0100 Subject: [PATCH] feat: base url retrieval --- posthog/ai/langchain/callbacks.py | 14 ++++++++++++-- posthog/test/ai/langchain/test_callbacks.py | 18 ++++++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/posthog/ai/langchain/callbacks.py b/posthog/ai/langchain/callbacks.py index 1524679..dd0dc1e 100644 --- a/posthog/ai/langchain/callbacks.py +++ b/posthog/ai/langchain/callbacks.py @@ -34,6 +34,7 @@ class RunMetadata(TypedDict, total=False): provider: str model: str model_params: Dict[str, Any] + base_url: str start_time: float end_time: float @@ -105,7 +106,7 @@ def on_chat_model_start( ): self._set_parent_of_run(run_id, parent_run_id) input = [_convert_message_to_dict(message) for row in messages for message in row] - self._set_run_metadata(run_id, input, **kwargs) + self._set_run_metadata(serialized, run_id, input, **kwargs) def on_llm_start( self, @@ -117,7 +118,7 @@ def on_llm_start( **kwargs: Any, ): self._set_parent_of_run(run_id, parent_run_id) - self._set_run_metadata(run_id, prompts, **kwargs) + self._set_run_metadata(serialized, run_id, prompts, **kwargs) def on_chain_end( self, @@ -171,6 +172,7 @@ def on_llm_end( "$ai_latency": latency, "$ai_trace_id": trace_id, "$ai_posthog_properties": self._properties, + "$ai_base_url": run.get("base_url"), } if self._distinct_id is None: event_properties["$process_person_profile"] = False @@ -215,6 +217,7 @@ def on_llm_error( "$ai_latency": latency, "$ai_trace_id": trace_id, "$ai_posthog_properties": self._properties, + "$ai_base_url": run.get("base_url"), } if self._distinct_id is None: event_properties["$process_person_profile"] = False @@ -251,6 +254,7 @@ def _find_root_run(self, run_id: UUID) -> UUID: def _set_run_metadata( self, + serialized: Dict[str, Any], run_id: UUID, messages: Union[List[Dict[str, Any]], List[str]], metadata: Optional[Dict[str, Any]] = None, @@ -268,6 +272,12 @@ def _set_run_metadata( run["model"] = model if provider := metadata.get("ls_provider"): run["provider"] = provider + try: + base_url = serialized["kwargs"]["openai_api_base"] + if base_url is not None: + run["base_url"] = base_url + except KeyError: + pass self._runs[run_id] = run def _pop_run_metadata(self, run_id: UUID) -> Optional[RunMetadata]: diff --git a/posthog/test/ai/langchain/test_callbacks.py b/posthog/test/ai/langchain/test_callbacks.py index a3aa5e4..007f983 100644 --- a/posthog/test/ai/langchain/test_callbacks.py +++ b/posthog/test/ai/langchain/test_callbacks.py @@ -61,6 +61,7 @@ def test_metadata_capture(mock_client): run_id = uuid.uuid4() with patch("time.time", return_value=1234567890): callbacks._set_run_metadata( + {"kwargs": {"openai_api_base": "https://us.posthog.com"}}, run_id, messages=[{"role": "user", "content": "Who won the world series in 2020?"}], invocation_params={"temperature": 0.5}, @@ -72,6 +73,7 @@ def test_metadata_capture(mock_client): "start_time": 1234567890, "model_params": {"temperature": 0.5}, "provider": "posthog", + "base_url": "https://us.posthog.com", } assert callbacks._runs[run_id] == expected with patch("time.time", return_value=1234567891): @@ -577,3 +579,19 @@ async def test_async_openai_streaming(mock_client): assert first_call_props["$ai_http_status"] == 200 assert first_call_props["$ai_input_tokens"] == 20 assert first_call_props["$ai_output_tokens"] == 1 + + +def test_base_url_retrieval(mock_client): + prompt = ChatPromptTemplate.from_messages([("user", "Foo")]) + chain = prompt | ChatOpenAI( + api_key="test", + model="posthog-mini", + base_url="https://test.posthog.com", + ) + callbacks = CallbackHandler(mock_client) + with pytest.raises(Exception): + chain.invoke({}, config={"callbacks": [callbacks]}) + + assert mock_client.capture.call_count == 1 + call = mock_client.capture.call_args[1] + assert call["properties"]["$ai_base_url"] == "https://test.posthog.com"