Add parallelization to RITS inference (#1441)
Signed-off-by: Ariel Gera <[email protected]>
arielge authored Jan 5, 2025
1 parent 511e278 commit 63c5310
Showing 2 changed files with 69 additions and 36 deletions.
src/unitxt/inference.py (99 changes: 66 additions & 33 deletions)
@@ -9,6 +9,7 @@
 import time
 import uuid
 from collections import Counter
+from multiprocessing.pool import ThreadPool
 from typing import (
     Any,
     Dict,

@@ -1469,6 +1470,13 @@ class OpenAiInferenceEngineParams(Artifact):
     service_tier: Optional[Literal["auto", "default"]] = None
 
 
+def run_with_imap(func):
+    def inner(self, args):
+        return func(self, *args)
+
+    return inner
+
+
 class OpenAiInferenceEngine(
     InferenceEngine,
     LogProbInferenceEngine,

@@ -1485,6 +1493,7 @@ class OpenAiInferenceEngine(
     base_url: Optional[str] = None
     default_headers: Dict[str, str] = {}
     credentials: CredentialsOpenAi = {}
+    num_parallel_requests: int = 20
 
     def get_engine_id(self) -> str:
         return get_model_and_label_id(self.model_name, self.label)

@@ -1528,52 +1537,76 @@ def _get_completion_kwargs(self):
             if v is not None
         }
 
-    def _infer(
+    def _parallel_infer(
         self,
         dataset: Union[List[Dict[str, Any]], Dataset],
+        infer_func,
         return_meta_data: bool = False,
     ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
+        inputs = [(instance, return_meta_data) for instance in dataset]
         outputs = []
-        for instance in tqdm(dataset, desc="Inferring with openAI API"):
-            messages = self.to_messages(instance)
-            response = self.client.chat.completions.create(
-                messages=messages,
-                model=self.model_name,
-                **self._get_completion_kwargs(),
-            )
-            prediction = response.choices[0].message.content
-            output = self.get_return_object(prediction, response, return_meta_data)
-
-            outputs.append(output)
+        with ThreadPool(processes=self.num_parallel_requests) as pool:
+            for output in tqdm(
+                pool.imap(infer_func, inputs),
+                total=len(inputs),
+                desc=f"Inferring with {self.__class__.__name__}",
+            ):
+                outputs.append(output)
 
         return outputs
 
+    def _infer(
+        self,
+        dataset: Union[List[Dict[str, Any]], Dataset],
+        return_meta_data: bool = False,
+    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
+        return self._parallel_infer(
+            dataset=dataset,
+            return_meta_data=return_meta_data,
+            infer_func=self._get_chat_completion,
+        )
+
     def _infer_log_probs(
         self,
         dataset: Union[List[Dict[str, Any]], Dataset],
         return_meta_data: bool = False,
     ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
-        outputs = []
-        for instance in tqdm(dataset, desc="Inferring with openAI API"):
-            messages = self.to_messages(instance)
-            response = self.client.chat.completions.create(
-                messages=messages,
-                model=self.model_name,
-                **self._get_completion_kwargs(),
-            )
-            top_logprobs_response = response.choices[0].logprobs.content
-            pred_output = [
-                {
-                    "top_tokens": [
-                        {"text": obj.token, "logprob": obj.logprob}
-                        for obj in generated_token.top_logprobs
-                    ]
-                }
-                for generated_token in top_logprobs_response
-            ]
-            output = self.get_return_object(pred_output, response, return_meta_data)
-            outputs.append(output)
-        return outputs
+        return self._parallel_infer(
+            dataset=dataset,
+            return_meta_data=return_meta_data,
+            infer_func=self._get_logprobs,
+        )
+
+    @run_with_imap
+    def _get_chat_completion(self, instance, return_meta_data):
+        messages = self.to_messages(instance)
+        response = self.client.chat.completions.create(
+            messages=messages,
+            model=self.model_name,
+            **self._get_completion_kwargs(),
+        )
+        prediction = response.choices[0].message.content
+        return self.get_return_object(prediction, response, return_meta_data)
+
+    @run_with_imap
+    def _get_logprobs(self, instance, return_meta_data):
+        messages = self.to_messages(instance)
+        response = self.client.chat.completions.create(
+            messages=messages,
+            model=self.model_name,
+            **self._get_completion_kwargs(),
+        )
+        top_logprobs_response = response.choices[0].logprobs.content
+        pred_output = [
+            {
+                "top_tokens": [
+                    {"text": obj.token, "logprob": obj.logprob}
+                    for obj in generated_token.top_logprobs
+                ]
+            }
+            for generated_token in top_logprobs_response
+        ]
+        return self.get_return_object(pred_output, response, return_meta_data)
 
     def get_return_object(self, predict_result, response, return_meta_data):
         if return_meta_data:
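A side note on the pattern used above (an explanatory sketch, not part of the commit): Pool.imap hands each work item to the worker as a single argument, which is why the run_with_imap decorator exists; it unpacks the (instance, return_meta_data) tuple into positional arguments so the worker can remain an ordinary method. Below is a minimal, runnable sketch of the same pattern, with the hypothetical ToyEngine and _slow_call standing in for the real engine class and API call:

    # Illustrative sketch only; ToyEngine and _slow_call are hypothetical stand-ins.
    import time
    from multiprocessing.pool import ThreadPool

    from tqdm import tqdm


    def run_with_imap(func):
        # Pool.imap passes each work item as one argument; this wrapper
        # unpacks the (instance, return_meta_data) tuple for the worker.
        def inner(self, args):
            return func(self, *args)

        return inner


    class ToyEngine:
        num_parallel_requests = 20

        @run_with_imap
        def _slow_call(self, instance, return_meta_data):
            time.sleep(0.1)  # stand-in for a blocking HTTP request
            return (instance, {"meta": True}) if return_meta_data else instance

        def infer(self, dataset, return_meta_data=False):
            inputs = [(instance, return_meta_data) for instance in dataset]
            outputs = []
            with ThreadPool(processes=self.num_parallel_requests) as pool:
                # imap yields results in input order while up to
                # num_parallel_requests calls are in flight at once.
                for output in tqdm(pool.imap(self._slow_call, inputs), total=len(inputs)):
                    outputs.append(output)
            return outputs


    if __name__ == "__main__":
        print(ToyEngine().infer(["a", "b", "c"]))

Because the engine's work is I/O-bound API calls, a thread pool suffices: the GIL is released while each request waits on the network, and imap keeps outputs aligned with the input dataset.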
utils/.secrets.baseline (6 changes: 3 additions & 3 deletions)
@@ -133,15 +133,15 @@
         "filename": "src/unitxt/inference.py",
         "hashed_secret": "aa6cd2a77de22303be80e1f632195d62d211a729",
         "is_verified": false,
-        "line_number": 1249,
+        "line_number": 1250,
         "is_secret": false
       },
       {
         "type": "Secret Keyword",
         "filename": "src/unitxt/inference.py",
         "hashed_secret": "c8f16a194efc59559549c7bd69f7bea038742e79",
         "is_verified": false,
-        "line_number": 1663,
+        "line_number": 1696,
         "is_secret": false
       }
     ],

@@ -184,5 +184,5 @@
       }
     ]
   },
-  "generated_at": "2024-12-24T18:00:14Z"
+  "generated_at": "2025-01-05T08:58:59Z"
 }
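For completeness, a hedged usage sketch: after this change, OpenAiInferenceEngine (and, per the commit title, the RITS engine built on top of it) carries a num_parallel_requests field defaulting to 20. The constructor arguments below are illustrative assumptions rather than a verified configuration:

    # Hypothetical usage; model_name appears in the class shown in the diff,
    # num_parallel_requests is the field this commit adds (default 20).
    from unitxt.inference import OpenAiInferenceEngine

    engine = OpenAiInferenceEngine(
        model_name="gpt-4o-mini",   # illustrative model choice
        num_parallel_requests=8,    # throttles concurrent API calls
    )
    predictions = engine.infer(dataset)  # dataset: a list of instance dicts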
