diff --git a/src/unitxt/metrics.py b/src/unitxt/metrics.py index b6c367fc3..daafef12f 100644 --- a/src/unitxt/metrics.py +++ b/src/unitxt/metrics.py @@ -6043,9 +6043,9 @@ class GraniteGuardianWMLMetric(InstanceMetric): risk_type: RiskType = None risk_definition: Optional[str] = None - user_message_field: str = "question" - assistant_message_field: str = "answer" - contexts_field: str = "contexts" + user_message_field: str = "user" + assistant_message_field: str = "assistant" + context_field: str = "context" tools_field: str = "tools" available_risks: Dict[RiskType, List[str]] = { @@ -6078,17 +6078,17 @@ def verify_guardian_config(self, task_data): self.risk_type = RiskType.RAG if self.risk_name == "context_relevance": assert ( - self.contexts_field in task_data + self.context_field in task_data and self.user_message_field in task_data ), UnitxtError( - f'Task data must contain "{self.contexts_field}" and "{self.user_message_field}" fields' + f'Task data must contain "{self.context_field}" and "{self.user_message_field}" fields' ) elif self.risk_name == "groundedness": assert ( - self.contexts_field in task_data + self.context_field in task_data and self.assistant_message_field in task_data ), UnitxtError( - f'Task data must contain "{self.contexts_field}" and "{self.assistant_message_field}" fields' + f'Task data must contain "{self.context_field}" and "{self.assistant_message_field}" fields' ) elif self.risk_name == "answer_relevance": assert ( @@ -6139,9 +6139,9 @@ def verify_guardian_config(self, task_data): self.tools_field in task_data or self.user_message_field in task_data or self.assistant_message_field in task_data - or self.contexts_field in task_data + or self.context_field in task_data ), UnitxtError( - f'Task data must contain at least one of"{self.tools_field}", "{self.assistant_message_field}", "{self.user_message_field}" or "{self.contexts_field}" fields' + f'Task data must contain at least one of"{self.tools_field}", "{self.assistant_message_field}", "{self.user_message_field}" or "{self.context_field}" fields' ) def prepare(self): @@ -6222,11 +6222,11 @@ def process_input_fields(self, task_data): "user", task_data[self.user_message_field] ) messages += self.create_message( - "context", "\n".join(task_data[self.contexts_field]) + "context", task_data[self.context_field] ) elif self.risk_name == "groundedness": messages += self.create_message( - "context", "\n".join(task_data[self.contexts_field]) + "context", task_data[self.context_field] ) messages += self.create_message( "assistant", task_data[self.assistant_message_field] @@ -6254,9 +6254,9 @@ def process_input_fields(self, task_data): elif self.risk_type == RiskType.USER_MESSAGE: messages += self.create_message("user", task_data[self.user_message_field]) elif self.risk_type == RiskType.CUSTOM_RISK: - if self.contexts_field in task_data: + if self.context_field in task_data: messages += self.create_message( - "context", "\n".join(task_data[self.contexts_field]) + "context", task_data[self.context_field] ) if self.tools_field in task_data: messages += self.create_message(