From 1fa641b380534ee1616060d93f744085c584a28c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mart=C3=ADn=20Santill=C3=A1n=20Cooper?= Date: Fri, 7 Feb 2025 10:18:16 -0300 Subject: [PATCH] Modularize getting the prompt MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Martín Santillán Cooper --- src/unitxt/metrics.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/src/unitxt/metrics.py b/src/unitxt/metrics.py index a055569f9..b6c367fc3 100644 --- a/src/unitxt/metrics.py +++ b/src/unitxt/metrics.py @@ -6163,6 +6163,20 @@ def set_main_score(self): ) self.reduction_map = {"mean": [self.main_score]} + def get_prompt(self, messages): + guardian_config = {"risk_name": self.risk_name} + if self.risk_type == RiskType.CUSTOM_RISK: + guardian_config["risk_definition"] = self.risk_definition + + processed_input = self._tokenizer.apply_chat_template( + messages, + guardian_config=guardian_config, + tokenize=False, + add_generation_prompt=True, + ) + + return processed_input + def compute(self, references: List[Any], prediction: Any, task_data: Dict) -> dict: from transformers import AutoTokenizer @@ -6187,19 +6201,8 @@ def compute(self, references: List[Any], prediction: Any, task_data: Dict) -> di f'Risk type is "{self.risk_type}" and risk name is "{self.risk_name}"' ) messages = self.process_input_fields(task_data) - - guardian_config = {"risk_name": self.risk_name} - if self.risk_type == RiskType.CUSTOM_RISK: - guardian_config["risk_definition"] = self.risk_definition - - processed_input = self._tokenizer.apply_chat_template( - messages, - guardian_config=guardian_config, - tokenize=False, - add_generation_prompt=True, - ) - - result = self.inference_engine.infer_log_probs([{"source": processed_input}]) + prompt = self.get_prompt(messages) + result = self.inference_engine.infer_log_probs([{"source": prompt}]) generated_tokens_list = result[0] label, prob_of_risk = self.parse_output(generated_tokens_list) score = prob_of_risk if label is not None else np.nan