Skip to content

Commit

Permalink
Fix copilot errors which cause client to hangup during FIM
Browse files Browse the repository at this point in the history
  • Loading branch information
ptelang committed Jan 7, 2025
1 parent eb5cd04 commit db4d40b
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
25 changes: 21 additions & 4 deletions src/codegate/providers/copilot/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,7 @@ def __init__(self, proxy: CopilotProvider):
def connection_made(self, transport: asyncio.Transport) -> None:
"""Handle successful connection to target"""
self.transport = transport
logger.debug(f"Target transport peer: {transport.get_extra_info('peername')}")
self.proxy.target_transport = transport

def _ensure_output_processor(self) -> None:
Expand Down Expand Up @@ -703,9 +704,10 @@ async def stream_iterator():
streaming_choices.append(
StreamingChoices(
finish_reason=choice.get("finish_reason", None),
index=0,
index=choice.get("index", 0),
delta=Delta(content=content, role="assistant"),
logprobs=None,
logprobs=choice.get("logprobs", None),
p=choice.get("p", None),
)
)

Expand All @@ -716,12 +718,13 @@ async def stream_iterator():
created=record_content.get("created", 0),
model=record_content.get("model", ""),
object="chat.completion.chunk",
stream=True,
)
yield mr

async for record in self.output_pipeline_instance.process_stream(stream_iterator()):
chunk = record.model_dump_json(exclude_none=True, exclude_unset=True)
sse_data = f"data:{chunk}\n\n".encode("utf-8")
sse_data = f"data: {chunk}\n\n".encode("utf-8")
chunk_size = hex(len(sse_data))[2:] + "\r\n"
self._proxy_transport_write(chunk_size.encode())
self._proxy_transport_write(sse_data)
Expand Down Expand Up @@ -764,6 +767,10 @@ def _proxy_transport_write(self, data: bytes):
logger.error("Proxy transport not available")
return
self.proxy.transport.write(data)
# print("DEBUG =================================")
# print(data)
# print("DEBUG =================================")


def data_received(self, data: bytes) -> None:
"""Handle data received from target"""
Expand All @@ -781,11 +788,21 @@ def data_received(self, data: bytes) -> None:
if header_end != -1:
self.headers_sent = True
# Send headers first
headers = data[: header_end + 4]
headers = data[: header_end]

# If Transfer-Encoding is not present, add it
if b"Transfer-Encoding:" not in headers:
headers = headers + b"\r\nTransfer-Encoding: chunked"

headers = headers + b"\r\n\r\n"

self._proxy_transport_write(headers)
logger.debug(f"Headers sent: {headers}")

data = data[header_end + 4 :]
# print("DEBUG =================================")
# print(data)
# print("DEBUG =================================")

self._process_chunk(data)

Expand Down
2 changes: 1 addition & 1 deletion src/codegate/providers/copilot/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def process_chunk(self, chunk: bytes) -> list:
data = json.loads(data_content)
records.append({"type": "data", "content": data})
except json.JSONDecodeError:
print(f"Failed to parse JSON: {data_content}")
logger.debug(f"Failed to parse JSON: {data_content}")

return records

Expand Down

0 comments on commit db4d40b

Please sign in to comment.