Skip to content

Commit

Permalink
Change default names to what Granite Guardian expects by default
Browse files Browse the repository at this point in the history
Signed-off-by: Martín Santillán Cooper <[email protected]>
  • Loading branch information
martinscooper committed Feb 7, 2025
1 parent 1fa641b commit 612055e
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions src/unitxt/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6043,9 +6043,9 @@ class GraniteGuardianWMLMetric(InstanceMetric):
risk_type: RiskType = None
risk_definition: Optional[str] = None

user_message_field: str = "question"
assistant_message_field: str = "answer"
contexts_field: str = "contexts"
user_message_field: str = "user"
assistant_message_field: str = "assistant"
context_field: str = "context"
tools_field: str = "tools"

available_risks: Dict[RiskType, List[str]] = {
Expand Down Expand Up @@ -6078,17 +6078,17 @@ def verify_guardian_config(self, task_data):
self.risk_type = RiskType.RAG
if self.risk_name == "context_relevance":
assert (
self.contexts_field in task_data
self.context_field in task_data
and self.user_message_field in task_data
), UnitxtError(
f'Task data must contain "{self.contexts_field}" and "{self.user_message_field}" fields'
f'Task data must contain "{self.context_field}" and "{self.user_message_field}" fields'
)
elif self.risk_name == "groundedness":
assert (
self.contexts_field in task_data
self.context_field in task_data
and self.assistant_message_field in task_data
), UnitxtError(
f'Task data must contain "{self.contexts_field}" and "{self.assistant_message_field}" fields'
f'Task data must contain "{self.context_field}" and "{self.assistant_message_field}" fields'
)
elif self.risk_name == "answer_relevance":
assert (
Expand Down Expand Up @@ -6139,9 +6139,9 @@ def verify_guardian_config(self, task_data):
self.tools_field in task_data
or self.user_message_field in task_data
or self.assistant_message_field in task_data
or self.contexts_field in task_data
or self.context_field in task_data
), UnitxtError(
f'Task data must contain at least one of"{self.tools_field}", "{self.assistant_message_field}", "{self.user_message_field}" or "{self.contexts_field}" fields'
f'Task data must contain at least one of"{self.tools_field}", "{self.assistant_message_field}", "{self.user_message_field}" or "{self.context_field}" fields'
)

def prepare(self):
Expand Down Expand Up @@ -6222,11 +6222,11 @@ def process_input_fields(self, task_data):
"user", task_data[self.user_message_field]
)
messages += self.create_message(
"context", "\n".join(task_data[self.contexts_field])
"context", task_data[self.context_field]
)
elif self.risk_name == "groundedness":
messages += self.create_message(
"context", "\n".join(task_data[self.contexts_field])
"context", task_data[self.context_field]
)
messages += self.create_message(
"assistant", task_data[self.assistant_message_field]
Expand Down Expand Up @@ -6254,9 +6254,9 @@ def process_input_fields(self, task_data):
elif self.risk_type == RiskType.USER_MESSAGE:
messages += self.create_message("user", task_data[self.user_message_field])
elif self.risk_type == RiskType.CUSTOM_RISK:
if self.contexts_field in task_data:
if self.context_field in task_data:
messages += self.create_message(
"context", "\n".join(task_data[self.contexts_field])
"context", task_data[self.context_field]
)
if self.tools_field in task_data:
messages += self.create_message(
Expand Down

0 comments on commit 612055e

Please sign in to comment.