diff --git a/examples/evaluate_granite_guardian_assistant_message_risks.py b/examples/evaluate_granite_guardian_assistant_message_risks.py index 953b9c7ee..32dc38e00 100644 --- a/examples/evaluate_granite_guardian_assistant_message_risks.py +++ b/examples/evaluate_granite_guardian_assistant_message_risks.py @@ -3,7 +3,7 @@ from unitxt import evaluate from unitxt.api import create_dataset from unitxt.blocks import Task -from unitxt.metrics import GraniteGuardianWMLMetric, RiskType +from unitxt.metrics import GraniteGuardianAssistantRisk, RiskType from unitxt.templates import NullTemplate print("Assistant response risks") @@ -40,9 +40,7 @@ risks = [ f"metrics.granite_guardian.assistant_risk.{assistant_risk}" - for assistant_risk in GraniteGuardianWMLMetric.available_risks[ - RiskType.ASSISTANT_MESSAGE - ] + for assistant_risk in GraniteGuardianAssistantRisk.get_available_risk_names() ] print( diff --git a/examples/evaluate_granite_guardian_custom_risks.py b/examples/evaluate_granite_guardian_custom_risks.py index a44d9f133..dea381b0c 100644 --- a/examples/evaluate_granite_guardian_custom_risks.py +++ b/examples/evaluate_granite_guardian_custom_risks.py @@ -1,7 +1,7 @@ from unitxt import evaluate from unitxt.api import create_dataset from unitxt.blocks import Task -from unitxt.metrics import GraniteGuardianWMLMetric, RiskType +from unitxt.metrics import GraniteGuardianCustomRisk, RiskType from unitxt.templates import NullTemplate print("Bring your own risk") @@ -20,10 +20,9 @@ default_template=NullTemplate(), # metrics=["metrics.granite_guardian.agentic.function_call[user_message_field=user_prompt, risk_definition=User message contains personal information or sensitive personal information that is included as a part of a prompt.]"], metrics=[ - GraniteGuardianWMLMetric( + GraniteGuardianCustomRisk( risk_name="personal_information", risk_definition="User message contains personal information or sensitive personal information that is included as a part of a prompt.", - 
risk_type=RiskType.RAG.name, user_message_field="user_prompt", ) ], diff --git a/examples/evaluate_granite_guardian_user_message_risks.py b/examples/evaluate_granite_guardian_user_message_risks.py index 14154e8d8..1a3e901fc 100644 --- a/examples/evaluate_granite_guardian_user_message_risks.py +++ b/examples/evaluate_granite_guardian_user_message_risks.py @@ -3,7 +3,7 @@ from unitxt import evaluate from unitxt.api import create_dataset from unitxt.blocks import Task -from unitxt.metrics import GraniteGuardianMetric, RiskType +from unitxt.metrics import GraniteGuardianUserRisk, RiskType from unitxt.templates import NullTemplate print("User prompt risks") @@ -23,7 +23,7 @@ risks = [ f"metrics.granite_guardian.user_risk.{user_risk}" - for user_risk in GraniteGuardianMetric.available_risks[RiskType.USER_MESSAGE] + for user_risk in GraniteGuardianUserRisk.get_available_risk_names() ] print( diff --git a/prepare/metrics/granite_guardian.py b/prepare/metrics/granite_guardian.py index 33e7b9f74..d7d5f1001 100644 --- a/prepare/metrics/granite_guardian.py +++ b/prepare/metrics/granite_guardian.py @@ -1,8 +1,8 @@ from unitxt import add_to_catalog -from unitxt.metrics import GraniteGuardianMetric +from unitxt.metrics import GraniteGuardianBase, RISK_TYPE_TO_CLASS -for risk_type, risk_names in GraniteGuardianMetric.available_risks.items(): +for risk_type, risk_names in GraniteGuardianBase.available_risks.items(): for risk_name in risk_names: metric_name = f"""granite_guardian.{risk_type.value}.{risk_name}""" - metric = GraniteGuardianMetric(risk_name=risk_name, risk_type=risk_type.name) + metric = RISK_TYPE_TO_CLASS[risk_type](risk_name=risk_name) add_to_catalog(metric, name=f"metrics.{metric_name}", overwrite=True) diff --git a/prepare/metrics/rag_granite_guardian.py b/prepare/metrics/rag_granite_guardian.py index 51dfe50f7..9d0646550 100644 --- a/prepare/metrics/rag_granite_guardian.py +++ b/prepare/metrics/rag_granite_guardian.py @@ -1,5 +1,5 @@ from unitxt import 
add_to_catalog -from unitxt.metrics import GraniteGuardianMetric, MetricPipeline +from unitxt.metrics import GraniteGuardianRagRisk, MetricPipeline from unitxt.operators import Copy, Set from unitxt.string_operators import Join @@ -22,7 +22,7 @@ for granite_risk_name, pred_field in risk_names_to_pred_field.items(): metric_name = f"""granite_guardian_{granite_risk_name}""" - metric = GraniteGuardianMetric( + metric = GraniteGuardianRagRisk( main_score=metric_name, risk_name=granite_risk_name, user_message_field="question", diff --git a/src/unitxt/catalog/metrics/granite_guardian/agentic_risk/function_call.json b/src/unitxt/catalog/metrics/granite_guardian/agentic_risk/function_call.json index 56a01866d..c4c95e51b 100644 --- a/src/unitxt/catalog/metrics/granite_guardian/agentic_risk/function_call.json +++ b/src/unitxt/catalog/metrics/granite_guardian/agentic_risk/function_call.json @@ -1,5 +1,4 @@ { - "__type__": "granite_guardian_metric", - "risk_name": "function_call", - "risk_type": "AGENTIC" + "__type__": "granite_guardian_agentic_risk", + "risk_name": "function_call" } diff --git a/src/unitxt/catalog/metrics/granite_guardian/assistant_risk/harm.json b/src/unitxt/catalog/metrics/granite_guardian/assistant_risk/harm.json index b3f180782..98b726aff 100644 --- a/src/unitxt/catalog/metrics/granite_guardian/assistant_risk/harm.json +++ b/src/unitxt/catalog/metrics/granite_guardian/assistant_risk/harm.json @@ -1,5 +1,4 @@ { - "__type__": "granite_guardian_metric", - "risk_name": "harm", - "risk_type": "ASSISTANT_MESSAGE" + "__type__": "granite_guardian_assistant_risk", + "risk_name": "harm" } diff --git a/src/unitxt/catalog/metrics/granite_guardian/assistant_risk/profanity.json b/src/unitxt/catalog/metrics/granite_guardian/assistant_risk/profanity.json index ee32e1b5d..38a25ea59 100644 --- a/src/unitxt/catalog/metrics/granite_guardian/assistant_risk/profanity.json +++ b/src/unitxt/catalog/metrics/granite_guardian/assistant_risk/profanity.json @@ -1,5 +1,4 @@ { - 
"__type__": "granite_guardian_metric", - "risk_name": "profanity", - "risk_type": "ASSISTANT_MESSAGE" + "__type__": "granite_guardian_assistant_risk", + "risk_name": "profanity" } diff --git a/src/unitxt/catalog/metrics/granite_guardian/assistant_risk/social_bias.json b/src/unitxt/catalog/metrics/granite_guardian/assistant_risk/social_bias.json index 19856bf5a..89a17c66f 100644 --- a/src/unitxt/catalog/metrics/granite_guardian/assistant_risk/social_bias.json +++ b/src/unitxt/catalog/metrics/granite_guardian/assistant_risk/social_bias.json @@ -1,5 +1,4 @@ { - "__type__": "granite_guardian_metric", - "risk_name": "social_bias", - "risk_type": "ASSISTANT_MESSAGE" + "__type__": "granite_guardian_assistant_risk", + "risk_name": "social_bias" } diff --git a/src/unitxt/catalog/metrics/granite_guardian/assistant_risk/unethical_behavior.json b/src/unitxt/catalog/metrics/granite_guardian/assistant_risk/unethical_behavior.json index 598ba312f..5e4b6b0cc 100644 --- a/src/unitxt/catalog/metrics/granite_guardian/assistant_risk/unethical_behavior.json +++ b/src/unitxt/catalog/metrics/granite_guardian/assistant_risk/unethical_behavior.json @@ -1,5 +1,4 @@ { - "__type__": "granite_guardian_metric", - "risk_name": "unethical_behavior", - "risk_type": "ASSISTANT_MESSAGE" + "__type__": "granite_guardian_assistant_risk", + "risk_name": "unethical_behavior" } diff --git a/src/unitxt/catalog/metrics/granite_guardian/assistant_risk/violence.json b/src/unitxt/catalog/metrics/granite_guardian/assistant_risk/violence.json index 4a0847887..1a62aa18c 100644 --- a/src/unitxt/catalog/metrics/granite_guardian/assistant_risk/violence.json +++ b/src/unitxt/catalog/metrics/granite_guardian/assistant_risk/violence.json @@ -1,5 +1,4 @@ { - "__type__": "granite_guardian_metric", - "risk_name": "violence", - "risk_type": "ASSISTANT_MESSAGE" + "__type__": "granite_guardian_assistant_risk", + "risk_name": "violence" } diff --git a/src/unitxt/catalog/metrics/granite_guardian/rag_risk/answer_relevance.json 
b/src/unitxt/catalog/metrics/granite_guardian/rag_risk/answer_relevance.json index 27641b6b3..c3eed9a23 100644 --- a/src/unitxt/catalog/metrics/granite_guardian/rag_risk/answer_relevance.json +++ b/src/unitxt/catalog/metrics/granite_guardian/rag_risk/answer_relevance.json @@ -1,5 +1,4 @@ { - "__type__": "granite_guardian_metric", - "risk_name": "answer_relevance", - "risk_type": "RAG" + "__type__": "granite_guardian_rag_risk", + "risk_name": "answer_relevance" } diff --git a/src/unitxt/catalog/metrics/granite_guardian/rag_risk/context_relevance.json b/src/unitxt/catalog/metrics/granite_guardian/rag_risk/context_relevance.json index e229b0b4c..1b6868445 100644 --- a/src/unitxt/catalog/metrics/granite_guardian/rag_risk/context_relevance.json +++ b/src/unitxt/catalog/metrics/granite_guardian/rag_risk/context_relevance.json @@ -1,5 +1,4 @@ { - "__type__": "granite_guardian_metric", - "risk_name": "context_relevance", - "risk_type": "RAG" + "__type__": "granite_guardian_rag_risk", + "risk_name": "context_relevance" } diff --git a/src/unitxt/catalog/metrics/granite_guardian/rag_risk/groundedness.json b/src/unitxt/catalog/metrics/granite_guardian/rag_risk/groundedness.json index 9efb58823..67f45ff85 100644 --- a/src/unitxt/catalog/metrics/granite_guardian/rag_risk/groundedness.json +++ b/src/unitxt/catalog/metrics/granite_guardian/rag_risk/groundedness.json @@ -1,5 +1,4 @@ { - "__type__": "granite_guardian_metric", - "risk_name": "groundedness", - "risk_type": "RAG" + "__type__": "granite_guardian_rag_risk", + "risk_name": "groundedness" } diff --git a/src/unitxt/catalog/metrics/granite_guardian/user_risk/harm.json b/src/unitxt/catalog/metrics/granite_guardian/user_risk/harm.json index bbd48a406..f991c87c7 100644 --- a/src/unitxt/catalog/metrics/granite_guardian/user_risk/harm.json +++ b/src/unitxt/catalog/metrics/granite_guardian/user_risk/harm.json @@ -1,5 +1,4 @@ { - "__type__": "granite_guardian_metric", - "risk_name": "harm", - "risk_type": "USER_MESSAGE" + 
"__type__": "granite_guardian_user_risk", + "risk_name": "harm" } diff --git a/src/unitxt/catalog/metrics/granite_guardian/user_risk/jailbreak.json b/src/unitxt/catalog/metrics/granite_guardian/user_risk/jailbreak.json index 0a3ce9246..d59ea1601 100644 --- a/src/unitxt/catalog/metrics/granite_guardian/user_risk/jailbreak.json +++ b/src/unitxt/catalog/metrics/granite_guardian/user_risk/jailbreak.json @@ -1,5 +1,4 @@ { - "__type__": "granite_guardian_metric", - "risk_name": "jailbreak", - "risk_type": "USER_MESSAGE" + "__type__": "granite_guardian_user_risk", + "risk_name": "jailbreak" } diff --git a/src/unitxt/catalog/metrics/granite_guardian/user_risk/profanity.json b/src/unitxt/catalog/metrics/granite_guardian/user_risk/profanity.json index 4156e55fd..01ffa67a5 100644 --- a/src/unitxt/catalog/metrics/granite_guardian/user_risk/profanity.json +++ b/src/unitxt/catalog/metrics/granite_guardian/user_risk/profanity.json @@ -1,5 +1,4 @@ { - "__type__": "granite_guardian_metric", - "risk_name": "profanity", - "risk_type": "USER_MESSAGE" + "__type__": "granite_guardian_user_risk", + "risk_name": "profanity" } diff --git a/src/unitxt/catalog/metrics/granite_guardian/user_risk/social_bias.json b/src/unitxt/catalog/metrics/granite_guardian/user_risk/social_bias.json index 0067c7faa..f1e3f4b44 100644 --- a/src/unitxt/catalog/metrics/granite_guardian/user_risk/social_bias.json +++ b/src/unitxt/catalog/metrics/granite_guardian/user_risk/social_bias.json @@ -1,5 +1,4 @@ { - "__type__": "granite_guardian_metric", - "risk_name": "social_bias", - "risk_type": "USER_MESSAGE" + "__type__": "granite_guardian_user_risk", + "risk_name": "social_bias" } diff --git a/src/unitxt/catalog/metrics/granite_guardian/user_risk/unethical_behavior.json b/src/unitxt/catalog/metrics/granite_guardian/user_risk/unethical_behavior.json index 238e7a51c..18c4eaffc 100644 --- a/src/unitxt/catalog/metrics/granite_guardian/user_risk/unethical_behavior.json +++ 
b/src/unitxt/catalog/metrics/granite_guardian/user_risk/unethical_behavior.json @@ -1,5 +1,4 @@ { - "__type__": "granite_guardian_metric", - "risk_name": "unethical_behavior", - "risk_type": "USER_MESSAGE" + "__type__": "granite_guardian_user_risk", + "risk_name": "unethical_behavior" } diff --git a/src/unitxt/catalog/metrics/granite_guardian/user_risk/violence.json b/src/unitxt/catalog/metrics/granite_guardian/user_risk/violence.json index c8b0d90b9..0421eab6a 100644 --- a/src/unitxt/catalog/metrics/granite_guardian/user_risk/violence.json +++ b/src/unitxt/catalog/metrics/granite_guardian/user_risk/violence.json @@ -1,5 +1,4 @@ { - "__type__": "granite_guardian_metric", - "risk_name": "violence", - "risk_type": "USER_MESSAGE" + "__type__": "granite_guardian_user_risk", + "risk_name": "violence" } diff --git a/src/unitxt/catalog/metrics/rag/granite_guardian_answer_relevance.json b/src/unitxt/catalog/metrics/rag/granite_guardian_answer_relevance.json index 7e4595af7..349f8995f 100644 --- a/src/unitxt/catalog/metrics/rag/granite_guardian_answer_relevance.json +++ b/src/unitxt/catalog/metrics/rag/granite_guardian_answer_relevance.json @@ -2,7 +2,7 @@ "__type__": "metric_pipeline", "main_score": "granite_guardian_answer_relevance", "metric": { - "__type__": "granite_guardian_metric", + "__type__": "granite_guardian_rag_risk", "main_score": "granite_guardian_answer_relevance", "risk_name": "answer_relevance", "user_message_field": "question", diff --git a/src/unitxt/catalog/metrics/rag/granite_guardian_context_relevance.json b/src/unitxt/catalog/metrics/rag/granite_guardian_context_relevance.json index 5d66095d4..921131509 100644 --- a/src/unitxt/catalog/metrics/rag/granite_guardian_context_relevance.json +++ b/src/unitxt/catalog/metrics/rag/granite_guardian_context_relevance.json @@ -2,7 +2,7 @@ "__type__": "metric_pipeline", "main_score": "granite_guardian_context_relevance", "metric": { - "__type__": "granite_guardian_metric", + "__type__": 
"granite_guardian_rag_risk", "main_score": "granite_guardian_context_relevance", "risk_name": "context_relevance", "user_message_field": "question", diff --git a/src/unitxt/catalog/metrics/rag/granite_guardian_groundedness.json b/src/unitxt/catalog/metrics/rag/granite_guardian_groundedness.json index fd108ccd5..35c74ff6d 100644 --- a/src/unitxt/catalog/metrics/rag/granite_guardian_groundedness.json +++ b/src/unitxt/catalog/metrics/rag/granite_guardian_groundedness.json @@ -2,7 +2,7 @@ "__type__": "metric_pipeline", "main_score": "granite_guardian_groundedness", "metric": { - "__type__": "granite_guardian_metric", + "__type__": "granite_guardian_rag_risk", "main_score": "granite_guardian_groundedness", "risk_name": "groundedness", "user_message_field": "question", diff --git a/src/unitxt/metrics.py b/src/unitxt/metrics.py index 4d58079f8..0eaf8faf3 100644 --- a/src/unitxt/metrics.py +++ b/src/unitxt/metrics.py @@ -5860,8 +5860,7 @@ class RiskType(str, Enum): AGENTIC = "agentic_risk" CUSTOM_RISK = "custom_risk" - -class GraniteGuardianMetric(InstanceMetric): +class GraniteGuardianBase(InstanceMetric): """Return metric for different kinds of "risk" from the Granite-3.0 Guardian model.""" reduction_map: Dict[str, List[str]] = None @@ -5919,84 +5918,28 @@ class GraniteGuardianMetric(InstanceMetric): _requirements_list: List[str] = ["torch", "transformers"] - def verify_guardian_config(self, task_data): - if ( - self.risk_name == RiskType.RAG - or self.risk_name in self.available_risks[RiskType.RAG] - ): - self.risk_type = RiskType.RAG - if self.risk_name == "context_relevance": - assert ( - self.context_field in task_data - and self.user_message_field in task_data - ), UnitxtError( - f'Task data must contain "{self.context_field}" and "{self.user_message_field}" fields' - ) - elif self.risk_name == "groundedness": - assert ( - self.context_field in task_data - and self.assistant_message_field in task_data - ), UnitxtError( - f'Task data must contain 
"{self.context_field}" and "{self.assistant_message_field}" fields' - ) - elif self.risk_name == "answer_relevance": - assert ( - self.user_message_field in task_data - and self.assistant_message_field in task_data - ), UnitxtError( - f'Task data must contain "{self.user_message_field}" and "{self.assistant_message_field}" fields' - ) - elif self.risk_name == RiskType.USER_MESSAGE or ( - self.risk_name in self.available_risks[RiskType.USER_MESSAGE] - and self.assistant_message_field not in task_data - ): - # User message risks only require the user message field and are the same as the assistant message risks, except for jailbreak - self.risk_type = RiskType.USER_MESSAGE - assert self.user_message_field in task_data, UnitxtError( - f'Task data must contain "{self.user_message_field}" field' - ) - elif ( - self.risk_name == RiskType.ASSISTANT_MESSAGE - or self.risk_name in self.available_risks[RiskType.ASSISTANT_MESSAGE] - ): - self.risk_type = RiskType.ASSISTANT_MESSAGE - assert ( - self.assistant_message_field in task_data - and self.user_message_field in task_data - ), UnitxtError( - f'Task data must contain "{self.assistant_message_field}" and "{self.user_message_field}" fields' - ) - elif ( - self.risk_name == RiskType.AGENTIC - or self.risk_name in self.available_risks[RiskType.AGENTIC] - ): - self.risk_type = RiskType.AGENTIC - assert ( - self.tools_field in task_data - and self.user_message_field in task_data - and self.assistant_message_field in task_data - ), UnitxtError( - f'Task data must contain "{self.tools_field}", "{self.assistant_message_field}" and "{self.user_message_field}" fields' - ) - else: - # even though this is a custom risks, we will limit the - # message roles to be a subset of the roles Granite Guardian - # was trained with: user, assistant, context & tools. 
# Methods of GraniteGuardianBase; the class header lies outside this span.


def prepare(self):
    """Normalize ``risk_type`` to a ``RiskType`` member.

    A value that is not already a ``RiskType`` is used as a key into
    ``RiskType`` (lookup by member name).
    """
    if not isinstance(self.risk_type, RiskType):
        self.risk_type = RiskType[self.risk_type]


def verify(self):
    """Check that ``risk_name`` is listed under this metric's ``risk_type``.

    Custom risks skip the lookup, since their names are user-defined.
    """
    super().verify()
    assert (
        self.risk_type == RiskType.CUSTOM_RISK
        or self.risk_name in self.available_risks[self.risk_type]
    ), UnitxtError(
        f"The risk '{self.risk_name}' is not a valid '{' '.join([word[0].upper() + word[1:] for word in self.risk_type.split('_')])}'"
    )


@abstractmethod
def verify_granite_guardian_config(self, task_data):
    """Validate ``task_data`` before scoring; called first by ``compute``."""
    pass


@abstractmethod
def process_input_fields(self, task_data):
    """Build the message list passed to ``get_prompt`` from ``task_data``."""
    pass


@classmethod
def get_available_risk_names(cls):
    """Return the entry of ``available_risks`` for this class's ``risk_type``."""
    # The debug prints of cls.risk_type / cls.available_risks were removed:
    # callers use this in comprehensions and should not get stdout noise.
    return cls.available_risks[cls.risk_type]
class GraniteGuardianUserRisk(GraniteGuardianBase):
    """Granite Guardian metric for risks assessed on the user message only."""

    risk_type = RiskType.USER_MESSAGE

    def verify_granite_guardian_config(self, task_data):
        """Require the user message field in ``task_data``."""
        # User message risks only require the user message field and are the
        # same as the assistant message risks, except for jailbreak
        assert self.user_message_field in task_data, UnitxtError(
            f'Task data must contain "{self.user_message_field}" field'
        )

    def process_input_fields(self, task_data):
        """Return a single ``user`` message."""
        messages = []
        messages += self.create_message("user", task_data[self.user_message_field])
        return messages


class GraniteGuardianAssistantRisk(GraniteGuardianBase):
    """Granite Guardian metric for risks assessed on the assistant reply."""

    risk_type = RiskType.ASSISTANT_MESSAGE

    def verify_granite_guardian_config(self, task_data):
        """Require both the assistant and the user message fields."""
        assert (
            self.assistant_message_field in task_data
            and self.user_message_field in task_data
        ), UnitxtError(
            f'Task data must contain "{self.assistant_message_field}" and "{self.user_message_field}" fields'
        )

    def process_input_fields(self, task_data):
        """Return the ``user`` message followed by the ``assistant`` message."""
        messages = []
        messages += self.create_message("user", task_data[self.user_message_field])
        messages += self.create_message(
            "assistant", task_data[self.assistant_message_field]
        )
        return messages


class GraniteGuardianRagRisk(GraniteGuardianBase):
    """Granite Guardian metric for the RAG risks.

    Handles ``context_relevance``, ``groundedness`` and ``answer_relevance``;
    each one reads a different pair of task-data fields.
    """

    risk_type = RiskType.RAG

    def verify_granite_guardian_config(self, task_data):
        """Require the pair of fields the configured ``risk_name`` reads."""
        if self.risk_name == "context_relevance":
            assert (
                self.context_field in task_data
                and self.user_message_field in task_data
            ), UnitxtError(
                f'Task data must contain "{self.context_field}" and "{self.user_message_field}" fields'
            )
        elif self.risk_name == "groundedness":
            assert (
                self.context_field in task_data
                and self.assistant_message_field in task_data
            ), UnitxtError(
                f'Task data must contain "{self.context_field}" and "{self.assistant_message_field}" fields'
            )
        elif self.risk_name == "answer_relevance":
            assert (
                self.user_message_field in task_data
                and self.assistant_message_field in task_data
            ), UnitxtError(
                f'Task data must contain "{self.user_message_field}" and "{self.assistant_message_field}" fields'
            )

    def process_input_fields(self, task_data):
        """Return the two messages for the configured ``risk_name``.

        Any other ``risk_name`` yields an empty list.
        """
        messages = []
        if self.risk_name == "context_relevance":
            messages += self.create_message(
                "user", task_data[self.user_message_field]
            )
            messages += self.create_message("context", task_data[self.context_field])
        elif self.risk_name == "groundedness":
            messages += self.create_message("context", task_data[self.context_field])
            messages += self.create_message(
                "assistant", task_data[self.assistant_message_field]
            )
        elif self.risk_name == "answer_relevance":
            messages += self.create_message(
                "user", task_data[self.user_message_field]
            )
            messages += self.create_message(
                "assistant", task_data[self.assistant_message_field]
            )
        return messages


class GraniteGuardianAgenticRisk(GraniteGuardianBase):
    """Granite Guardian metric for agentic (tool-calling) risks."""

    risk_type = RiskType.AGENTIC

    def verify_granite_guardian_config(self, task_data):
        """Require the tools, user message and assistant message fields."""
        assert (
            self.tools_field in task_data
            and self.user_message_field in task_data
            and self.assistant_message_field in task_data
        ), UnitxtError(
            f'Task data must contain "{self.tools_field}", "{self.assistant_message_field}" and "{self.user_message_field}" fields'
        )

    def process_input_fields(self, task_data):
        """Return ``tools``, ``user`` and ``assistant`` messages, in that order.

        The tools field is decoded with ``json.loads`` before being sent.
        """
        messages = []
        messages += self.create_message(
            "tools", json.loads(task_data[self.tools_field])
        )
        messages += self.create_message("user", task_data[self.user_message_field])
        messages += self.create_message(
            "assistant", task_data[self.assistant_message_field]
        )
        return messages


class GraniteGuardianCustomRisk(GraniteGuardianBase):
    """Granite Guardian metric for a user-defined risk."""

    risk_type = RiskType.CUSTOM_RISK

    def verify(self):
        """Run the base checks, then require that ``risk_type`` is set."""
        super().verify()
        assert self.risk_type is not None, UnitxtError(
            "In a custom risk, risk_type must be defined"
        )

    def verify_granite_guardian_config(self, task_data):
        """Require at least one of the four supported fields."""
        # even though this is a custom risk, we limit the message roles to a
        # subset of the roles Granite Guardian was trained with: user,
        # assistant, context & tools. Here we check that at least one of them
        # is provided.
        assert (
            self.tools_field in task_data
            or self.user_message_field in task_data
            or self.assistant_message_field in task_data
            or self.context_field in task_data
        ), UnitxtError(
            f'Task data must contain at least one of "{self.tools_field}", "{self.assistant_message_field}", "{self.user_message_field}" or "{self.context_field}" fields'
        )

    def process_input_fields(self, task_data):
        """Return one message per field present in ``task_data``.

        Order is context, tools, user, assistant; absent fields are skipped.
        """
        messages = []
        if self.context_field in task_data:
            messages += self.create_message("context", task_data[self.context_field])
        if self.tools_field in task_data:
            messages += self.create_message(
                "tools", json.loads(task_data[self.tools_field])
            )
        if self.user_message_field in task_data:
            messages += self.create_message(
                "user", task_data[self.user_message_field]
            )
        if self.assistant_message_field in task_data:
            messages += self.create_message(
                "assistant", task_data[self.assistant_message_field]
            )
        return messages


# Maps each built-in risk type to the metric class that implements it.
# Values are GraniteGuardianBase subclasses (classes, not instances).
# RiskType.CUSTOM_RISK has no entry here.
RISK_TYPE_TO_CLASS: Dict[RiskType, type] = {
    RiskType.USER_MESSAGE: GraniteGuardianUserRisk,
    RiskType.ASSISTANT_MESSAGE: GraniteGuardianAssistantRisk,
    RiskType.RAG: GraniteGuardianRagRisk,
    RiskType.AGENTIC: GraniteGuardianAgenticRisk,
}