From ee1fd0ecfc4616c481a3b81e3e2eebf5858e5d8a Mon Sep 17 00:00:00 2001
From: Roger Yang <80478925+RogerHYang@users.noreply.github.com>
Date: Tue, 16 Jul 2024 21:00:01 -0700
Subject: [PATCH] fix(openai): missing span when stream response is used as
 context manager (#591)

---
 .../instrumentation/openai/_stream.py       | 24 +++++++++++++++++
 .../openai/test_instrumentor.py             | 27 ++++++++++++++-----
 2 files changed, 45 insertions(+), 6 deletions(-)

diff --git a/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_stream.py b/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_stream.py
index 39c5b9815..6db1788ee 100644
--- a/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_stream.py
+++ b/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_stream.py
@@ -110,6 +110,30 @@ async def __anext__(self) -> Any:
         self._process_chunk(chunk)
         return chunk
 
+    def __enter__(self) -> Any:
+        # The stream response can be used as a context manager, e.g. in LangChain:
+        # https://github.com/langchain-ai/langchain/blob/dc42279eb55fbb8ec5175d24c7b30fe7b502b6d1/libs/partners/openai/langchain_openai/chat_models/base.py#L513  # noqa: E501
+        # When that happens, __enter__ is called on the wrapped object and the
+        # stream object escapes our wrapper. See
+        # https://github.com/openai/openai-python/blob/435a5805ccbd5939a68f7f359ab72e937ef86e59/src/openai/_streaming.py#L103-L104  # noqa: E501
+        # We override __enter__ so the wrapped object does not escape.
+        obj = self.__wrapped__.__enter__()
+        if obj is self.__wrapped__:
+            return self
+        return obj
+
+    def __exit__(self, *args: Any, **kwargs: Any) -> None:
+        self.__wrapped__.__exit__(*args, **kwargs)
+
+    async def __aenter__(self) -> Any:
+        obj = await self.__wrapped__.__aenter__()
+        if obj is self.__wrapped__:
+            return self
+        return obj
+
+    async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
+        await self.__wrapped__.__aexit__(*args, **kwargs)
+
     def _process_chunk(self, chunk: Any) -> None:
         if not self._self_iteration_count:
             try:
diff --git a/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/test_instrumentor.py b/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/test_instrumentor.py
index a6dbeac86..6d12cd797 100644
--- a/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/test_instrumentor.py
+++ b/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/test_instrumentor.py
@@ -117,8 +117,13 @@ async def task() -> None:
         response = await create(**create_kwargs)
         response = response.parse() if is_raw else response
         if is_stream:
-            async for _ in response:
-                pass
+            if _openai_version() >= (1, 6, 0):
+                async with response as iterator:
+                    async for _ in iterator:
+                        pass
+            else:
+                async for _ in response:
+                    pass
 
     with suppress(openai.BadRequestError):
         if use_context_attributes:
@@ -137,8 +142,13 @@ async def task() -> None:
                     response = create(**create_kwargs)
                     response = response.parse() if is_raw else response
                     if is_stream:
-                        for _ in response:
-                            pass
+                        if _openai_version() >= (1, 6, 0):
+                            with response as iterator:
+                                for _ in iterator:
+                                    pass
+                        else:
+                            for _ in response:
+                                pass
         else:
             if is_async:
                 asyncio.run(task())
@@ -146,8 +156,13 @@ async def task() -> None:
                 response = create(**create_kwargs)
                 response = response.parse() if is_raw else response
                 if is_stream:
-                    for _ in response:
-                        pass
+                    if _openai_version() >= (1, 6, 0):
+                        with response as iterator:
+                            for _ in iterator:
+                                pass
+                    else:
+                        for _ in response:
+                            pass
     spans = in_memory_span_exporter.get_finished_spans()
     assert len(spans) == 2  # first span should be from the httpx instrumentor
     span: ReadableSpan = spans[1]
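
For context, a minimal sketch of the escape the first hunk guards against. The only assumption carried over from the patch is that the wrapped stream's __enter__ returns the stream itself, as openai's Stream does per the linked _streaming.py lines. ToyStream and TracingWrapper are illustrative names, not classes from openai or openinference.

class ToyStream:
    """Stands in for openai.Stream: its __enter__ returns the stream itself."""

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        return None

    def __iter__(self):
        yield from ("chunk-1", "chunk-2")


class TracingWrapper:
    """Stands in for the instrumentation proxy around the stream."""

    def __init__(self, wrapped):
        self.__wrapped__ = wrapped

    def __iter__(self):
        for chunk in self.__wrapped__:
            print("traced:", chunk)  # the real proxy records chunks on a span here
            yield chunk

    def __enter__(self):
        obj = self.__wrapped__.__enter__()
        # The identity check from the patch: if the wrapped stream returned
        # itself, hand back the wrapper so iteration stays instrumented.
        return self if obj is self.__wrapped__ else obj

    def __exit__(self, *args):
        return self.__wrapped__.__exit__(*args)


with TracingWrapper(ToyStream()) as stream:
    for _ in stream:  # chunks pass through TracingWrapper.__iter__
        pass

Without the identity check, the with statement would hand back the raw stream, iteration would bypass the wrapper's __iter__, and no chunks (hence no span) would ever be recorded; that is the missing-span bug the patch title describes.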