diff --git a/src/unitxt/metrics.py b/src/unitxt/metrics.py
index a055569f9..b6c367fc3 100644
--- a/src/unitxt/metrics.py
+++ b/src/unitxt/metrics.py
@@ -6163,6 +6163,20 @@ def set_main_score(self):
         )
         self.reduction_map = {"mean": [self.main_score]}
 
+    def get_prompt(self, messages):
+        guardian_config = {"risk_name": self.risk_name}
+        if self.risk_type == RiskType.CUSTOM_RISK:
+            guardian_config["risk_definition"] = self.risk_definition
+
+        processed_input = self._tokenizer.apply_chat_template(
+            messages,
+            guardian_config=guardian_config,
+            tokenize=False,
+            add_generation_prompt=True,
+        )
+
+        return processed_input
+
     def compute(self, references: List[Any], prediction: Any, task_data: Dict) -> dict:
         from transformers import AutoTokenizer
 
@@ -6187,19 +6201,8 @@ def compute(self, references: List[Any], prediction: Any, task_data: Dict) -> dict:
             f'Risk type is "{self.risk_type}" and risk name is "{self.risk_name}"'
         )
         messages = self.process_input_fields(task_data)
-
-        guardian_config = {"risk_name": self.risk_name}
-        if self.risk_type == RiskType.CUSTOM_RISK:
-            guardian_config["risk_definition"] = self.risk_definition
-
-        processed_input = self._tokenizer.apply_chat_template(
-            messages,
-            guardian_config=guardian_config,
-            tokenize=False,
-            add_generation_prompt=True,
-        )
-
-        result = self.inference_engine.infer_log_probs([{"source": processed_input}])
+        prompt = self.get_prompt(messages)
+        result = self.inference_engine.infer_log_probs([{"source": prompt}])
         generated_tokens_list = result[0]
         label, prob_of_risk = self.parse_output(generated_tokens_list)
         score = prob_of_risk if label is not None else np.nan
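
The extracted get_prompt helper isolates the Granite Guardian prompt construction, so it can be inspected or tested without running the inference engine. Below is a minimal standalone sketch of that construction, assuming (as the diff shows) that the guardian tokenizer's chat template accepts a guardian_config mapping; the model id, risk name, and messages are illustrative assumptions, not values taken from unitxt.

    # Standalone sketch of the prompt rendering performed by get_prompt().
    # Model id, risk name, and messages are illustrative assumptions.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-guardian-3.0-2b")

    messages = [
        {"role": "user", "content": "How do I pick a lock?"},
    ]

    # A predefined risk only needs "risk_name"; a custom risk would also set
    # "risk_definition", mirroring the RiskType.CUSTOM_RISK branch in get_prompt().
    guardian_config = {"risk_name": "harm"}

    prompt = tokenizer.apply_chat_template(
        messages,
        guardian_config=guardian_config,
        tokenize=False,
        add_generation_prompt=True,
    )
    print(prompt)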