fix treating token as a list
AlexCheema committed Jan 22, 2025
1 parent 09e12d8 commit 9954ce8
Showing 5 changed files with 38 additions and 38 deletions.
8 changes: 4 additions & 4 deletions exo/api/chatgpt_api.py
@@ -408,16 +408,16 @@ async def handle_post_chat_completions(self, request):
# Stream tokens while waiting for inference to complete
while True:
if DEBUG >= 2: print(f"[ChatGPTAPI] Waiting for token from queue: {request_id=}")
-          token, is_finished = await asyncio.wait_for(
+          tokens, is_finished = await asyncio.wait_for(
self.token_queues[request_id].get(),
timeout=self.response_timeout
)
-          if DEBUG >= 2: print(f"[ChatGPTAPI] Got token from queue: {request_id=} {token=} {is_finished=}")
+          if DEBUG >= 2: print(f"[ChatGPTAPI] Got token from queue: {request_id=} {tokens=} {is_finished=}")

finish_reason = None
eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") else getattr(tokenizer, "eos_token_id", None)

-          if token == eos_token_id:
+          if tokens[-1] == eos_token_id:
if is_finished:
finish_reason = "stop"
if is_finished and not finish_reason:
@@ -428,7 +428,7 @@ async def handle_post_chat_completions(self, request):
tokenizer,
prompt,
request_id,
-            [token],
+            tokens,
stream,
finish_reason,
"chat.completion",
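Why this matters: each item pulled from token_queues is now a list of token ids rather than a single id, so the EOS check has to look at the last element and the whole batch is forwarded downstream instead of being rewrapped as [token]. A minimal sketch of the consumer pattern, assuming an asyncio.Queue that yields (tokens, is_finished) tuples as in the hunk above (the queue wiring and eos_token_id value here are illustrative, not taken from the repo):

import asyncio

async def drain_token_queue(queue: asyncio.Queue, eos_token_id: int, timeout: float = 30.0):
    # Each queue item is a batch (list) of token ids plus a completion flag.
    while True:
        tokens, is_finished = await asyncio.wait_for(queue.get(), timeout=timeout)
        finish_reason = None
        # Only the last token of a batch can be the end-of-sequence marker.
        if tokens and tokens[-1] == eos_token_id and is_finished:
            finish_reason = "stop"
        yield tokens, finish_reason  # forward the whole batch, not [tokens]
        if is_finished:
            break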
2 changes: 1 addition & 1 deletion exo/networking/grpc/grpc_peer_handle.py
@@ -123,7 +123,7 @@ async def send_tensor(self, shard: Shard, tensor: np.ndarray, inference_state: O
request_id=request_id,
inference_state=None if inference_state is None else self.serialize_inference_state(inference_state)
)
-    response =await self.stub.SendTensor(request)
+    response = await self.stub.SendTensor(request)

if not response.tensor_data or not response.shape or not response.dtype:
return None
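For context, send_tensor treats an empty reply (no tensor_data, shape, or dtype) as None. A hedged sketch of turning a non-empty Tensor reply back into a numpy array, using only the field names visible in the snippet above (the full Tensor message definition is not shown in this diff, so the bytes/shape/dtype interpretation is an assumption):

import numpy as np

def tensor_reply_to_ndarray(response):
    # Mirrors the guard above: an empty reply means "no tensor returned".
    if not response.tensor_data or not response.shape or not response.dtype:
        return None
    # tensor_data holds raw bytes; dtype and shape say how to interpret them.
    arr = np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype))
    return arr.reshape(tuple(response.shape))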
8 changes: 4 additions & 4 deletions exo/networking/grpc/node_service.proto
@@ -3,11 +3,11 @@ syntax = "proto3";
package node_service;

service NodeService {
-  rpc SendPrompt (PromptRequest) returns (Empty) {}
-  rpc SendTensor (TensorRequest) returns (Empty) {}
+  rpc SendPrompt (PromptRequest) returns (Tensor) {}
+  rpc SendTensor (TensorRequest) returns (Tensor) {}
rpc SendExample (ExampleRequest) returns (Loss) {}
rpc CollectTopology (CollectTopologyRequest) returns (Topology) {}
-  rpc SendNewToken (SendNewTokenRequest) returns (Empty) {}
+  rpc SendResult (SendResultRequest) returns (Empty) {}
rpc SendOpaqueStatus (SendOpaqueStatusRequest) returns (Empty) {}
rpc HealthCheck (HealthCheckRequest) returns (HealthCheckResponse) {}
}
@@ -95,7 +95,7 @@ message DeviceCapabilities {
DeviceFlops flops = 4;
}

-message SendNewTokenRequest {
+message SendResultRequest {
string request_id = 1;
repeated int32 result = 2;
optional Tensor tensor = 3;
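Since the service definition changed (SendPrompt/SendTensor now return Tensor, and SendNewToken/SendNewTokenRequest become SendResult/SendResultRequest), the generated node_service_pb2.py and node_service_pb2_grpc.py below must be regenerated. One way to do that from Python with grpcio-tools, assuming node_service.proto sits in the working directory:

from grpc_tools import protoc

# Equivalent to: python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. node_service.proto
protoc.main([
    "protoc",
    "-I.",
    "--python_out=.",       # rewrites node_service_pb2.py
    "--grpc_python_out=.",  # rewrites node_service_pb2_grpc.py
    "node_service.proto",
])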
26 changes: 13 additions & 13 deletions exo/networking/grpc/node_service_pb2.py

Some generated files are not rendered by default.

32 changes: 16 additions & 16 deletions exo/networking/grpc/node_service_pb2_grpc.py
@@ -37,12 +37,12 @@ def __init__(self, channel):
self.SendPrompt = channel.unary_unary(
'/node_service.NodeService/SendPrompt',
request_serializer=node__service__pb2.PromptRequest.SerializeToString,
-        response_deserializer=node__service__pb2.Empty.FromString,
+        response_deserializer=node__service__pb2.Tensor.FromString,
_registered_method=True)
self.SendTensor = channel.unary_unary(
'/node_service.NodeService/SendTensor',
request_serializer=node__service__pb2.TensorRequest.SerializeToString,
-        response_deserializer=node__service__pb2.Empty.FromString,
+        response_deserializer=node__service__pb2.Tensor.FromString,
_registered_method=True)
self.SendExample = channel.unary_unary(
'/node_service.NodeService/SendExample',
@@ -54,9 +54,9 @@ def __init__(self, channel):
request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
response_deserializer=node__service__pb2.Topology.FromString,
_registered_method=True)
-    self.SendNewToken = channel.unary_unary(
-        '/node_service.NodeService/SendNewToken',
-        request_serializer=node__service__pb2.SendNewTokenRequest.SerializeToString,
+    self.SendResult = channel.unary_unary(
+        '/node_service.NodeService/SendResult',
+        request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
response_deserializer=node__service__pb2.Empty.FromString,
_registered_method=True)
self.SendOpaqueStatus = channel.unary_unary(
@@ -98,7 +98,7 @@ def CollectTopology(self, request, context):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

-  def SendNewToken(self, request, context):
+  def SendResult(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
@@ -122,12 +122,12 @@ def add_NodeServiceServicer_to_server(servicer, server):
'SendPrompt': grpc.unary_unary_rpc_method_handler(
servicer.SendPrompt,
request_deserializer=node__service__pb2.PromptRequest.FromString,
-          response_serializer=node__service__pb2.Empty.SerializeToString,
+          response_serializer=node__service__pb2.Tensor.SerializeToString,
),
'SendTensor': grpc.unary_unary_rpc_method_handler(
servicer.SendTensor,
request_deserializer=node__service__pb2.TensorRequest.FromString,
-          response_serializer=node__service__pb2.Empty.SerializeToString,
+          response_serializer=node__service__pb2.Tensor.SerializeToString,
),
'SendExample': grpc.unary_unary_rpc_method_handler(
servicer.SendExample,
@@ -139,9 +139,9 @@ def add_NodeServiceServicer_to_server(servicer, server):
request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
response_serializer=node__service__pb2.Topology.SerializeToString,
),
-      'SendNewToken': grpc.unary_unary_rpc_method_handler(
-          servicer.SendNewToken,
-          request_deserializer=node__service__pb2.SendNewTokenRequest.FromString,
+      'SendResult': grpc.unary_unary_rpc_method_handler(
+          servicer.SendResult,
+          request_deserializer=node__service__pb2.SendResultRequest.FromString,
response_serializer=node__service__pb2.Empty.SerializeToString,
),
'SendOpaqueStatus': grpc.unary_unary_rpc_method_handler(
@@ -181,7 +181,7 @@ def SendPrompt(request,
target,
'/node_service.NodeService/SendPrompt',
node__service__pb2.PromptRequest.SerializeToString,
-        node__service__pb2.Empty.FromString,
+        node__service__pb2.Tensor.FromString,
options,
channel_credentials,
insecure,
@@ -208,7 +208,7 @@ def SendTensor(request,
target,
'/node_service.NodeService/SendTensor',
node__service__pb2.TensorRequest.SerializeToString,
-        node__service__pb2.Empty.FromString,
+        node__service__pb2.Tensor.FromString,
options,
channel_credentials,
insecure,
@@ -274,7 +274,7 @@ def CollectTopology(request,
_registered_method=True)

@staticmethod
-    def SendNewToken(request,
+    def SendResult(request,
target,
options=(),
channel_credentials=None,
@@ -287,8 +287,8 @@ def SendNewToken(request,
return grpc.experimental.unary_unary(
request,
target,
-        '/node_service.NodeService/SendNewToken',
-        node__service__pb2.SendNewTokenRequest.SerializeToString,
+        '/node_service.NodeService/SendResult',
+        node__service__pb2.SendResultRequest.SerializeToString,
node__service__pb2.Empty.FromString,
options,
channel_credentials,
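To exercise the renamed RPC end to end: a minimal client-side sketch that builds a SendResultRequest (request_id plus repeated int32 result, per the .proto above) and calls SendResult through the regenerated stub. The target address and import path are assumptions for illustration:

import grpc
from exo.networking.grpc import node_service_pb2, node_service_pb2_grpc

def send_result(request_id: str, token_ids: list, target: str = "localhost:50051"):
    # Opens a plaintext channel and sends one batch of result token ids.
    with grpc.insecure_channel(target) as channel:
        stub = node_service_pb2_grpc.NodeServiceStub(channel)
        request = node_service_pb2.SendResultRequest(
            request_id=request_id,
            result=token_ids,  # repeated int32 in the proto
        )
        return stub.SendResult(request)  # returns an Empty message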
