Skip to content

Commit

Permalink
fix: review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
skoob13 committed Jan 13, 2025
1 parent 9f7e094 commit 739c88a
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 37 deletions.
4 changes: 2 additions & 2 deletions posthog/ai/langchain/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .callbacks import PosthogCallbackHandler
from .callbacks import CallbackHandler

__all__ = ["PosthogCallbackHandler"]
__all__ = ["CallbackHandler"]
9 changes: 5 additions & 4 deletions posthog/ai/langchain/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from posthog.ai.utils import get_model_params
from posthog.client import Client

log = logging.getLogger("posthog")


class RunMetadata(TypedDict, total=False):
messages: Union[List[Dict[str, Any]], List[str]]
Expand All @@ -39,7 +41,7 @@ class RunMetadata(TypedDict, total=False):
RunStorage = Dict[UUID, RunMetadata]


class PosthogCallbackHandler(BaseCallbackHandler):
class CallbackHandler(BaseCallbackHandler):
"""
A callback handler for LangChain that sends events to PostHog LLM Observability.
"""
Expand Down Expand Up @@ -80,7 +82,6 @@ def __init__(
self._properties = properties
self._runs = {}
self._parent_tree = {}
self.log = logging.getLogger("posthog")

def on_chain_start(
self,
Expand Down Expand Up @@ -274,7 +275,7 @@ def _pop_run_metadata(self, run_id: UUID) -> Optional[RunMetadata]:
try:
run = self._runs.pop(run_id)
except KeyError:
self.log.warning(f"No run metadata found for run {run_id}")
log.warning(f"No run metadata found for run {run_id}")
return None
run["end_time"] = end_time
return run
Expand Down Expand Up @@ -395,7 +396,7 @@ def _parse_usage(response: LLMResult):


def _get_http_status(error: BaseException) -> int:
# OpenAI: https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/_exceptions.py
# OpenAI: https://github.com/openai/openai-python/blob/main/src/openai/_exceptions.py
# Anthropic: https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/_exceptions.py
# Google: https://github.com/googleapis/python-api-core/blob/main/google/api_core/exceptions.py
status_code = getattr(error, "status_code", getattr(error, "code", 0))
Expand Down
3 changes: 1 addition & 2 deletions posthog/ai/openai/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@

try:
import openai
import openai.resources
except ImportError:
raise ModuleNotFoundError("Please install the OpenAI SDK to use this feature: 'pip install openai'")

import openai.resources

from posthog.ai.utils import call_llm_and_track_usage, get_model_params
from posthog.client import Client as PostHogClient

Expand Down
3 changes: 1 addition & 2 deletions posthog/ai/openai/openai_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@

try:
import openai
import openai.resources
except ImportError:
raise ModuleNotFoundError("Please install the OpenAI SDK to use this feature: 'pip install openai'")

import openai.resources

from posthog.ai.utils import call_llm_and_track_usage_async, get_model_params
from posthog.client import Client as PostHogClient

Expand Down
50 changes: 23 additions & 27 deletions posthog/test/ai/langchain/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from langchain_core.runnables import RunnableLambda
from langchain_openai.chat_models import ChatOpenAI

from posthog.ai.langchain import PosthogCallbackHandler
from posthog.ai.langchain import CallbackHandler

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

Expand All @@ -24,7 +24,7 @@ def mock_client():


def test_parent_capture(mock_client):
callbacks = PosthogCallbackHandler(mock_client)
callbacks = CallbackHandler(mock_client)
parent_run_id = uuid.uuid4()
run_id = uuid.uuid4()
callbacks._set_parent_of_run(run_id, parent_run_id)
Expand All @@ -35,7 +35,7 @@ def test_parent_capture(mock_client):


def test_find_root_run(mock_client):
callbacks = PosthogCallbackHandler(mock_client)
callbacks = CallbackHandler(mock_client)
root_run_id = uuid.uuid4()
parent_run_id = uuid.uuid4()
run_id = uuid.uuid4()
Expand All @@ -47,17 +47,17 @@ def test_find_root_run(mock_client):


def test_trace_id_generation(mock_client):
callbacks = PosthogCallbackHandler(mock_client)
callbacks = CallbackHandler(mock_client)
run_id = uuid.uuid4()
with patch("uuid.uuid4", return_value=run_id):
assert callbacks._get_trace_id(run_id) == run_id
run_id = uuid.uuid4()
callbacks = PosthogCallbackHandler(mock_client, trace_id=run_id)
callbacks = CallbackHandler(mock_client, trace_id=run_id)
assert callbacks._get_trace_id(uuid.uuid4()) == run_id


def test_metadata_capture(mock_client):
callbacks = PosthogCallbackHandler(mock_client)
callbacks = CallbackHandler(mock_client)
run_id = uuid.uuid4()
with patch("time.time", return_value=1234567890):
callbacks._set_run_metadata(
Expand Down Expand Up @@ -97,7 +97,7 @@ def test_basic_chat_chain(mock_client, stream):
)
]
)
callbacks = [PosthogCallbackHandler(mock_client)]
callbacks = [CallbackHandler(mock_client)]
chain = prompt | model
if stream:
result = [m for m in chain.stream({}, config={"callbacks": callbacks})][0]
Expand Down Expand Up @@ -143,7 +143,7 @@ async def test_async_basic_chat_chain(mock_client, stream):
)
]
)
callbacks = [PosthogCallbackHandler(mock_client)]
callbacks = [CallbackHandler(mock_client)]
chain = prompt | model
if stream:
result = [m async for m in chain.astream({}, config={"callbacks": callbacks})][0]
Expand Down Expand Up @@ -178,7 +178,7 @@ async def test_async_basic_chat_chain(mock_client, stream):
)
def test_basic_llm_chain(mock_client, Model, stream):
model = Model(responses=["The Los Angeles Dodgers won the World Series in 2020."])
callbacks: list[PosthogCallbackHandler] = [PosthogCallbackHandler(mock_client)]
callbacks: list[CallbackHandler] = [CallbackHandler(mock_client)]

if stream:
result = "".join(
Expand Down Expand Up @@ -209,7 +209,7 @@ def test_basic_llm_chain(mock_client, Model, stream):
)
async def test_async_basic_llm_chain(mock_client, Model, stream):
model = Model(responses=["The Los Angeles Dodgers won the World Series in 2020."])
callbacks: list[PosthogCallbackHandler] = [PosthogCallbackHandler(mock_client)]
callbacks: list[CallbackHandler] = [CallbackHandler(mock_client)]

if stream:
result = "".join(
Expand Down Expand Up @@ -241,7 +241,7 @@ def test_trace_id_for_multiple_chains(mock_client):
]
)
model = FakeMessagesListChatModel(responses=[AIMessage(content="Bar")])
callbacks = [PosthogCallbackHandler(mock_client)]
callbacks = [CallbackHandler(mock_client)]
chain = prompt | model | RunnableLambda(lambda x: [x]) | model
result = chain.invoke({}, config={"callbacks": callbacks})

Expand Down Expand Up @@ -279,13 +279,13 @@ def test_trace_id_for_multiple_chains(mock_client):
def test_personless_mode(mock_client):
prompt = ChatPromptTemplate.from_messages([("user", "Foo")])
chain = prompt | FakeMessagesListChatModel(responses=[AIMessage(content="Bar")])
chain.invoke({}, config={"callbacks": [PosthogCallbackHandler(mock_client)]})
chain.invoke({}, config={"callbacks": [CallbackHandler(mock_client)]})
assert mock_client.capture.call_count == 1
args = mock_client.capture.call_args_list[0][1]
assert args["properties"]["$process_person_profile"] is False

id = uuid.uuid4()
chain.invoke({}, config={"callbacks": [PosthogCallbackHandler(mock_client, distinct_id=id)]})
chain.invoke({}, config={"callbacks": [CallbackHandler(mock_client, distinct_id=id)]})
assert mock_client.capture.call_count == 2
args = mock_client.capture.call_args_list[1][1]
assert "$process_person_profile" not in args["properties"]
Expand All @@ -295,7 +295,7 @@ def test_personless_mode(mock_client):
def test_personless_mode_exception(mock_client):
prompt = ChatPromptTemplate.from_messages([("user", "Foo")])
chain = prompt | ChatOpenAI(api_key="test", model="gpt-4o-mini")
callbacks = PosthogCallbackHandler(mock_client)
callbacks = CallbackHandler(mock_client)
with pytest.raises(Exception):
chain.invoke({}, config={"callbacks": [callbacks]})
assert mock_client.capture.call_count == 1
Expand All @@ -304,7 +304,7 @@ def test_personless_mode_exception(mock_client):

id = uuid.uuid4()
with pytest.raises(Exception):
chain.invoke({}, config={"callbacks": [PosthogCallbackHandler(mock_client, distinct_id=id)]})
chain.invoke({}, config={"callbacks": [CallbackHandler(mock_client, distinct_id=id)]})
assert mock_client.capture.call_count == 2
args = mock_client.capture.call_args_list[1][1]
assert "$process_person_profile" not in args["properties"]
Expand All @@ -319,7 +319,7 @@ def test_metadata(mock_client):
)
model = FakeMessagesListChatModel(responses=[AIMessage(content="Bar")])
callbacks = [
PosthogCallbackHandler(mock_client, trace_id="test-trace-id", distinct_id="test_id", properties={"foo": "bar"})
CallbackHandler(mock_client, trace_id="test-trace-id", distinct_id="test_id", properties={"foo": "bar"})
]
chain = prompt | model
result = chain.invoke({}, config={"callbacks": callbacks})
Expand All @@ -343,9 +343,7 @@ def test_metadata(mock_client):
def test_callbacks_logic(mock_client):
prompt = ChatPromptTemplate.from_messages([("user", "Foo")])
model = FakeMessagesListChatModel(responses=[AIMessage(content="Bar")])
callbacks = PosthogCallbackHandler(
mock_client, trace_id="test-trace-id", distinct_id="test_id", properties={"foo": "bar"}
)
callbacks = CallbackHandler(mock_client, trace_id="test-trace-id", distinct_id="test_id", properties={"foo": "bar"})
chain = prompt | model

chain.invoke({}, config={"callbacks": [callbacks]})
Expand All @@ -366,7 +364,7 @@ def test_exception_in_chain(mock_client):
def runnable(_):
raise ValueError("test")

callbacks = PosthogCallbackHandler(mock_client)
callbacks = CallbackHandler(mock_client)
with pytest.raises(ValueError):
RunnableLambda(runnable).invoke({}, config={"callbacks": [callbacks]})

Expand All @@ -378,7 +376,7 @@ def runnable(_):
def test_openai_error(mock_client):
prompt = ChatPromptTemplate.from_messages([("user", "Foo")])
chain = prompt | ChatOpenAI(api_key="test", model="gpt-4o-mini")
callbacks = PosthogCallbackHandler(mock_client)
callbacks = CallbackHandler(mock_client)

# 401
with pytest.raises(Exception):
Expand Down Expand Up @@ -408,9 +406,7 @@ def test_openai_chain(mock_client):
temperature=0,
max_tokens=1,
)
callbacks = PosthogCallbackHandler(
mock_client, trace_id="test-trace-id", distinct_id="test_id", properties={"foo": "bar"}
)
callbacks = CallbackHandler(mock_client, trace_id="test-trace-id", distinct_id="test_id", properties={"foo": "bar"})
start_time = time.time()
result = chain.invoke({}, config={"callbacks": [callbacks]})
approximate_latency = math.floor(time.time() - start_time)
Expand Down Expand Up @@ -475,7 +471,7 @@ def test_openai_captures_multiple_generations(mock_client):
max_tokens=1,
n=2,
)
callbacks = PosthogCallbackHandler(mock_client)
callbacks = CallbackHandler(mock_client)
result = chain.invoke({}, config={"callbacks": [callbacks]})

assert result.content == "Bar"
Expand Down Expand Up @@ -530,7 +526,7 @@ def test_openai_streaming(mock_client):
chain = prompt | ChatOpenAI(
api_key=OPENAI_API_KEY, model="gpt-4o-mini", temperature=0, max_tokens=1, stream=True, stream_usage=True
)
callbacks = PosthogCallbackHandler(mock_client)
callbacks = CallbackHandler(mock_client)
result = [m for m in chain.stream({}, config={"callbacks": [callbacks]})]
result = sum(result[1:], result[0])

Expand Down Expand Up @@ -562,7 +558,7 @@ async def test_async_openai_streaming(mock_client):
chain = prompt | ChatOpenAI(
api_key=OPENAI_API_KEY, model="gpt-4o-mini", temperature=0, max_tokens=1, stream=True, stream_usage=True
)
callbacks = PosthogCallbackHandler(mock_client)
callbacks = CallbackHandler(mock_client)
result = [m async for m in chain.astream({}, config={"callbacks": [callbacks]})]
result = sum(result[1:], result[0])

Expand Down

0 comments on commit 739c88a

Please sign in to comment.