Commit

Modularize getting the prompt
Signed-off-by: Martín Santillán Cooper <[email protected]>
martinscooper committed Feb 7, 2025
1 parent 0c8f9fa commit 1fa641b
Showing 1 changed file with 16 additions and 13 deletions.
29 changes: 16 additions & 13 deletions src/unitxt/metrics.py
@@ -6163,6 +6163,20 @@ def set_main_score(self):
         )
         self.reduction_map = {"mean": [self.main_score]}
 
+    def get_prompt(self, messages):
+        guardian_config = {"risk_name": self.risk_name}
+        if self.risk_type == RiskType.CUSTOM_RISK:
+            guardian_config["risk_definition"] = self.risk_definition
+
+        processed_input = self._tokenizer.apply_chat_template(
+            messages,
+            guardian_config=guardian_config,
+            tokenize=False,
+            add_generation_prompt=True,
+        )
+
+        return processed_input
+
     def compute(self, references: List[Any], prediction: Any, task_data: Dict) -> dict:
         from transformers import AutoTokenizer
 
@@ -6187,19 +6201,8 @@ def compute(self, references: List[Any], prediction: Any, task_data: Dict) -> dict:
             f'Risk type is "{self.risk_type}" and risk name is "{self.risk_name}"'
         )
         messages = self.process_input_fields(task_data)
-
-        guardian_config = {"risk_name": self.risk_name}
-        if self.risk_type == RiskType.CUSTOM_RISK:
-            guardian_config["risk_definition"] = self.risk_definition
-
-        processed_input = self._tokenizer.apply_chat_template(
-            messages,
-            guardian_config=guardian_config,
-            tokenize=False,
-            add_generation_prompt=True,
-        )
-
-        result = self.inference_engine.infer_log_probs([{"source": processed_input}])
+        prompt = self.get_prompt(messages)
+        result = self.inference_engine.infer_log_probs([{"source": prompt}])
         generated_tokens_list = result[0]
         label, prob_of_risk = self.parse_output(generated_tokens_list)
         score = prob_of_risk if label is not None else np.nan
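For context, the helper extracted in this commit renders a Granite Guardian prompt by passing a guardian_config dict through the tokenizer's chat template. Below is a minimal standalone sketch of the same pattern; the checkpoint name and the example message are illustrative assumptions, not part of this commit:

from transformers import AutoTokenizer

# Standalone sketch of the get_prompt helper added above. The checkpoint is an
# assumption for illustration; any Granite Guardian tokenizer whose chat
# template understands guardian_config should behave the same way.
tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-guardian-3.0-2b")

def get_prompt(messages, risk_name, risk_definition=None):
    guardian_config = {"risk_name": risk_name}
    if risk_definition is not None:
        # Mirrors the RiskType.CUSTOM_RISK branch: a custom risk supplies
        # its own definition alongside the risk name.
        guardian_config["risk_definition"] = risk_definition
    return tokenizer.apply_chat_template(
        messages,
        guardian_config=guardian_config,
        tokenize=False,              # return the rendered prompt string
        add_generation_prompt=True,  # end with the assistant turn header
    )

prompt = get_prompt(
    [{"role": "user", "content": "Is it safe to share my password?"}],
    risk_name="harm",
)

Because tokenize=False, the helper returns a plain string rather than token ids, which is why compute() can hand the result directly to the inference engine as the "source" of the request.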
