Skip to content

Commit

Permalink
fix(openai): missing span when stream response is used as context manager (#591)

Browse files Browse the repository at this point in the history
  • Loading branch information
RogerHYang authored Jul 17, 2024
1 parent 1d521ea commit ee1fd0e
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,30 @@ async def __anext__(self) -> Any:
self._process_chunk(chunk)
return chunk

def __enter__(self) -> Any:
    # Stream responses may be driven through the context-manager protocol
    # (LangChain does `with ...create(stream=True) as response:`; see
    # https://github.com/langchain-ai/langchain/blob/dc42279eb55fbb8ec5175d24c7b30fe7b502b6d1/libs/partners/openai/langchain_openai/chat_models/base.py#L513 ).  # noqa E501
    # The underlying openai Stream.__enter__ returns the raw stream object
    # (https://github.com/openai/openai-python/blob/435a5805ccbd5939a68f7f359ab72e937ef86e59/src/openai/_streaming.py#L103-L104 ),  # noqa E501
    # which would let it escape this proxy and drop our span bookkeeping.
    # Whenever the wrapped __enter__ hands back the wrapped instance itself,
    # substitute this proxy; any other return value passes through untouched.
    entered = self.__wrapped__.__enter__()
    return self if entered is self.__wrapped__ else entered

def __exit__(self, *args: Any, **kwargs: Any) -> Any:
    # Delegate context-manager exit to the wrapped stream AND propagate its
    # return value.  Under the context-manager protocol a truthy return from
    # __exit__ suppresses the in-flight exception; discarding it (as the
    # previous version did) would silently disable any suppression the
    # wrapped object requested.  The current openai Stream.__exit__ returns
    # None, so this change is backward compatible for existing callers.
    return self.__wrapped__.__exit__(*args, **kwargs)

async def __aenter__(self) -> Any:
    # Async twin of __enter__: keep this proxy from being bypassed when the
    # stream is entered via `async with`.  If the wrapped __aenter__ returns
    # the wrapped object itself, hand back the proxy instead so chunk
    # processing still flows through our instrumentation; otherwise forward
    # whatever it produced.
    entered = await self.__wrapped__.__aenter__()
    return self if entered is self.__wrapped__ else entered

async def __aexit__(self, *args: Any, **kwargs: Any) -> Any:
    # Delegate async context-manager exit to the wrapped stream AND propagate
    # its return value.  A truthy return from __aexit__ suppresses the
    # in-flight exception; the previous version discarded it, which would
    # silently disable any suppression the wrapped object requested.  The
    # current openai AsyncStream.__aexit__ returns None, so behavior for
    # existing callers is unchanged.
    return await self.__wrapped__.__aexit__(*args, **kwargs)

def _process_chunk(self, chunk: Any) -> None:
if not self._self_iteration_count:
try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,13 @@ async def task() -> None:
response = await create(**create_kwargs)
response = response.parse() if is_raw else response
if is_stream:
async for _ in response:
pass
if _openai_version() >= (1, 6, 0):
async with response as iterator:
async for _ in iterator:
pass
else:
async for _ in response:
pass

with suppress(openai.BadRequestError):
if use_context_attributes:
Expand All @@ -137,17 +142,27 @@ async def task() -> None:
response = create(**create_kwargs)
response = response.parse() if is_raw else response
if is_stream:
for _ in response:
pass
if _openai_version() >= (1, 6, 0):
with response as iterator:
for _ in iterator:
pass
else:
for _ in response:
pass
else:
if is_async:
asyncio.run(task())
else:
response = create(**create_kwargs)
response = response.parse() if is_raw else response
if is_stream:
for _ in response:
pass
if _openai_version() >= (1, 6, 0):
with response as iterator:
for _ in iterator:
pass
else:
for _ in response:
pass
spans = in_memory_span_exporter.get_finished_spans()
assert len(spans) == 2 # first span should be from the httpx instrumentor
span: ReadableSpan = spans[1]
Expand Down

0 comments on commit ee1fd0e

Please sign in to comment.