Skip to content

Commit

Permalink
feat: base url retrieval
Browse files Browse the repository at this point in the history
  • Loading branch information
skoob13 committed Jan 13, 2025
1 parent 739c88a commit fff0a3f
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
14 changes: 12 additions & 2 deletions posthog/ai/langchain/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class RunMetadata(TypedDict, total=False):
provider: str
model: str
model_params: Dict[str, Any]
base_url: str
start_time: float
end_time: float

Expand Down Expand Up @@ -105,7 +106,7 @@ def on_chat_model_start(
):
self._set_parent_of_run(run_id, parent_run_id)
input = [_convert_message_to_dict(message) for row in messages for message in row]
self._set_run_metadata(run_id, input, **kwargs)
self._set_run_metadata(serialized, run_id, input, **kwargs)

def on_llm_start(
self,
Expand All @@ -117,7 +118,7 @@ def on_llm_start(
**kwargs: Any,
):
self._set_parent_of_run(run_id, parent_run_id)
self._set_run_metadata(run_id, prompts, **kwargs)
self._set_run_metadata(serialized, run_id, prompts, **kwargs)

def on_chain_end(
self,
Expand Down Expand Up @@ -171,6 +172,7 @@ def on_llm_end(
"$ai_latency": latency,
"$ai_trace_id": trace_id,
"$ai_posthog_properties": self._properties,
"$ai_base_url": run.get("base_url"),
}
if self._distinct_id is None:
event_properties["$process_person_profile"] = False
Expand Down Expand Up @@ -215,6 +217,7 @@ def on_llm_error(
"$ai_latency": latency,
"$ai_trace_id": trace_id,
"$ai_posthog_properties": self._properties,
"$ai_base_url": run.get("base_url"),
}
if self._distinct_id is None:
event_properties["$process_person_profile"] = False
Expand Down Expand Up @@ -251,6 +254,7 @@ def _find_root_run(self, run_id: UUID) -> UUID:

def _set_run_metadata(
self,
serialized: Dict[str, Any],
run_id: UUID,
messages: Union[List[Dict[str, Any]], List[str]],
metadata: Optional[Dict[str, Any]] = None,
Expand All @@ -268,6 +272,12 @@ def _set_run_metadata(
run["model"] = model
if provider := metadata.get("ls_provider"):
run["provider"] = provider
try:
base_url = serialized["kwargs"]["openai_api_base"]
if base_url is not None:
run["base_url"] = base_url
except KeyError:
pass
self._runs[run_id] = run

def _pop_run_metadata(self, run_id: UUID) -> Optional[RunMetadata]:
Expand Down
18 changes: 18 additions & 0 deletions posthog/test/ai/langchain/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def test_metadata_capture(mock_client):
run_id = uuid.uuid4()
with patch("time.time", return_value=1234567890):
callbacks._set_run_metadata(
{"kwargs": {"openai_api_base": "https://us.posthog.com"}},
run_id,
messages=[{"role": "user", "content": "Who won the world series in 2020?"}],
invocation_params={"temperature": 0.5},
Expand All @@ -72,6 +73,7 @@ def test_metadata_capture(mock_client):
"start_time": 1234567890,
"model_params": {"temperature": 0.5},
"provider": "posthog",
"base_url": "https://us.posthog.com",
}
assert callbacks._runs[run_id] == expected
with patch("time.time", return_value=1234567891):
Expand Down Expand Up @@ -577,3 +579,19 @@ async def test_async_openai_streaming(mock_client):
assert first_call_props["$ai_http_status"] == 200
assert first_call_props["$ai_input_tokens"] == 20
assert first_call_props["$ai_output_tokens"] == 1


def test_base_url_retrieval(mock_client):
prompt = ChatPromptTemplate.from_messages([("user", "Foo")])
chain = prompt | ChatOpenAI(
api_key="test",
model="posthog-mini",
base_url="https://test.posthog.com",
)
callbacks = CallbackHandler(mock_client)
with pytest.raises(Exception):
chain.invoke({}, config={"callbacks": [callbacks]})

assert mock_client.capture.call_count == 1
call = mock_client.capture.call_args[1]
assert call["properties"]["$ai_base_url"] == "https://test.posthog.com"

0 comments on commit fff0a3f

Please sign in to comment.