Fixes and adjustments in RAG metrics and related inference engines #1466

Merged Jan 13, 2025 (33 commits)
Commits
becb670
add new classification engines and remove deleted llama-3-70b-instruc…
lilacheden Dec 16, 2024
5eff0b4
allow proprietary data on RITSInferenceEngine
lilacheden Dec 16, 2024
603b2d6
propagate score_prefix from metricPipeline to its metric
lilacheden Dec 23, 2024
1177e5e
Adjust autorag metrics to unitxt flow with rag.response_generation task
lilacheden Dec 23, 2024
29dba3d
Adjust granite_guardian to unitxt flow
lilacheden Dec 23, 2024
686df9e
Add AzureOpenAIInferenceEngine
lilacheden Dec 24, 2024
c99b23a
lowercase engine label
lilacheden Dec 24, 2024
7300db2
move to azure openai classification engines
lilacheden Dec 24, 2024
1e6c3e4
set chat_api format for openai
lilacheden Jan 5, 2025
7accd69
update secrets
lilacheden Jan 5, 2025
9a07989
delete old inference engine
lilacheden Jan 5, 2025
2a06b95
Merge branch 'main' into metrics_fix
lilacheden Jan 5, 2025
80a619f
update secrets
lilacheden Jan 5, 2025
b80e565
add numeric and verbal postprocessors
lilacheden Jan 5, 2025
85ac8be
fix rag judge example
lilacheden Jan 5, 2025
a1f4714
add numeric and verbal rag judge templates
lilacheden Jan 5, 2025
86f73c5
add rag judges that use the new templates
lilacheden Jan 5, 2025
07b2256
fix import
lilacheden Jan 5, 2025
a871319
rename metrics with correct template
lilacheden Jan 5, 2025
51ce53d
avoid import from prepare
lilacheden Jan 5, 2025
c40a699
remove old metrics
lilacheden Jan 5, 2025
33ed497
Merge branch 'main' into metrics_fix
lilacheden Jan 5, 2025
75ee255
add token overlap based context relevance and answer relevance metrics
lilacheden Jan 5, 2025
056a313
add postprocessors tests
lilacheden Jan 6, 2025
d614f39
keep only recommended rag llmaj, deprecate old path to metrics
lilacheden Jan 13, 2025
761ecff
Merge branch 'main' into metrics_fix
lilacheden Jan 13, 2025
4afe9f6
update secret
lilacheden Jan 13, 2025
eda05dc
update secret again
lilacheden Jan 13, 2025
830de60
fix typo
lilacheden Jan 13, 2025
6eba9d1
remove gen_ai from inference test
lilacheden Jan 13, 2025
cb0d74f
comment out input_tokens test
lilacheden Jan 13, 2025
410b049
Merge branch 'main' into metrics_fix
lilacheden Jan 13, 2025
3379b06
Merge branch 'main' into metrics_fix
lilacheden Jan 13, 2025
2 changes: 1 addition & 1 deletion examples/evaluate_rag_using_binary_llm_as_judge.py
@@ -64,7 +64,7 @@
# all available models for this judge are under "catalog.engines.classification"
mixtral_engine = "engines.classification.mixtral_8x7b_instruct_v01_wml"
correctness_judge_metric_mixtral = (
f"{metric_name}[{mapping_override}, model={mixtral_engine}]"
f"{metric_name}[{mapping_override}, inference_model={mixtral_engine}]"
)

metrics = [correctness_judge_metric_llama, correctness_judge_metric_mixtral]
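For context, a minimal sketch of how a judge-metric string with the renamed `inference_model` override resolves through the catalog. The metric name below is an assumption (one of the names registered by prepare/metrics/llm_as_judge/rag_judge.py later in this diff, assuming its template exists in the catalog), and it assumes `fetch_artifact` resolves bracketed overrides the same way the example's recipe does.

```python
from unitxt.artifact import fetch_artifact

# Assumed catalog name; the bracketed override swaps in the Mixtral classification engine,
# using the `inference_model` keyword this PR renames from `model`.
judge_metric = (
    "metrics.rag.answer_correctness.generic_inference_engine_q_a_gt_loose"
    "[inference_model=engines.classification.mixtral_8x7b_instruct_v01_wml]"
)
# fetch_artifact returns (artifact, catalog); the artifact is the configured judge metric.
metric, _ = fetch_artifact(judge_metric)
```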
17 changes: 11 additions & 6 deletions prepare/engines/classification/classification_engines.py
@@ -1,7 +1,7 @@
from unitxt import add_to_catalog
from unitxt.inference import (
AzureOpenAIInferenceEngine,
IbmGenAiInferenceEngine,
OpenAiInferenceEngine,
RITSInferenceEngine,
WMLInferenceEngineGeneration,
)
@@ -23,8 +23,12 @@ def get_inference_engine(model_name, framework_name):
decoding_method="greedy",
)
if framework_name == "openai":
return OpenAiInferenceEngine(
model_name=model_name, logprobs=True, max_tokens=5, temperature=0.0
return AzureOpenAIInferenceEngine(
model_name=model_name,
logprobs=True,
max_tokens=5,
temperature=0.0,
top_logprobs=5,
)
if framework_name == "rits":
return RITSInferenceEngine(
@@ -34,9 +38,10 @@ def get_inference_engine(model_name, framework_name):


model_names_to_infer_framework = {
"meta-llama/llama-3-1-70b-instruct": ["ibm_wml", "rits"],
"meta-llama/llama-3-70b-instruct": ["ibm_gen_ai"],
"gpt-4-turbo": ["openai"],
"meta-llama/llama-3-1-70b-instruct": ["ibm_wml", "rits", "ibm_gen_ai"],
"meta-llama/llama-3-3-70b-instruct": ["ibm_wml", "rits"],
"gpt-4-turbo-2024-04-09": ["openai"],
"gpt-4o-2024-08-06": ["openai"],
"mistralai/mixtral-8x7b-instruct-v01": ["ibm_wml", "ibm_gen_ai", "rits"],
"meta-llama/llama-3-1-405b-instruct-fp8": ["ibm_gen_ai", "rits"],
"meta-llama/llama-3-405b-instruct": ["ibm_wml"],
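As a quick reference, a sketch of what the "openai" branch of get_inference_engine now returns, constructed directly with the same arguments as in the diff. Endpoint and credential configuration are handled by the engine itself and are not shown here; "gpt-4o-2024-08-06" is one of the Azure OpenAI models listed above.

```python
from unitxt.inference import AzureOpenAIInferenceEngine

# Same configuration the "openai" branch now produces: short greedy completions
# with per-token logprobs, as needed by the classification-style judges.
judge_engine = AzureOpenAIInferenceEngine(
    model_name="gpt-4o-2024-08-06",
    logprobs=True,
    max_tokens=5,
    temperature=0.0,
    top_logprobs=5,
)
```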
65 changes: 0 additions & 65 deletions prepare/metrics/llm_as_judge/binary_judge.py

This file was deleted.

94 changes: 94 additions & 0 deletions prepare/metrics/llm_as_judge/rag_judge.py
@@ -0,0 +1,94 @@
from unitxt import add_to_catalog
from unitxt.artifact import UnitxtArtifactNotFoundError, fetch_artifact
from unitxt.inference import GenericInferenceEngine
from unitxt.llm_as_judge import (
    TaskBasedLLMasJudge,
)

metric_type_to_template_dict = {
    "faithfulness": {
        "q_c_a": "judge_with_question_simplified",
        "c_a": "judge_no_question_simplified",
    },
    "context_relevance": {"q_c_ares": "judge_context_relevance_ares"},
    "correctness_holistic": {"q_c_a": "judge_correctness_simple"},
    "answer_correctness": {"q_a_gt_loose": "judge_loose_match_no_context"},
    "answer_relevance": {"q_a": "judge_answer_relevance"},
}
metric_type_to_realization = {
    "faithfulness": "_verbal",
    "context_relevance": "_numeric",
    "correctness_holistic": "_numeric",
    "answer_correctness": "_numeric",
    "answer_relevance": "_numeric",
}

generic_engine_label = "generic_inference_engine"
inference_models = {
    "llama_3_1_70b_instruct_wml": "engines.classification.llama_3_1_70b_instruct_wml",
    generic_engine_label: GenericInferenceEngine(),
}


def get_prediction_field(metric_type):
    return None if metric_type == "context_relevance" else "answer"


for metric_type, template_dict in metric_type_to_template_dict.items():
    for template_short_name, template_name in template_dict.items():
        task_name = f"tasks.rag_eval.{metric_type}.binary"
        for logprobs_label in [
            "",
            "_logprobs",
            metric_type_to_realization[metric_type],
        ]:
            use_logprobs = logprobs_label == "_logprobs"
            template = (
                f"templates.rag_eval.{metric_type}.{template_name}{logprobs_label}"
            )
            try:
                t = fetch_artifact(template)[0]
            except UnitxtArtifactNotFoundError:
                continue
            for inf_label, inference_model in inference_models.items():
                if (
                    use_logprobs and inf_label == generic_engine_label
                ):  # engine GenericInferenceEngine does not support logprobs
                    continue

                metric_label = f"{metric_type}_{template_short_name}{logprobs_label}"
                metric = TaskBasedLLMasJudge(
                    inference_model=inference_model,
                    template=template,
                    task=task_name,
                    format=None,
                    main_score=metric_label,
                    prediction_field=get_prediction_field(metric_type),
                    infer_log_probs=use_logprobs,
                )

                new_catalog_name = f"metrics.rag.{metric_type}.{inf_label}_{template_short_name}{logprobs_label}"

                add_to_catalog(
                    metric,
                    new_catalog_name,
                    overwrite=True,
                )

                if logprobs_label in ["_logprobs", ""]:
                    metric = TaskBasedLLMasJudge(
                        inference_model=inference_model,
                        template=template,
                        task=task_name,
                        format=None,
                        main_score=metric_label,
                        prediction_field=get_prediction_field(metric_type),
                        infer_log_probs=use_logprobs,
                        __deprecated_msg__=f"This metric should be replaced with {new_catalog_name}",
                    )
                    # for backwards compatibility: keep also legacy path to metrics
                    add_to_catalog(
                        metric,
                        f"metrics.llm_as_judge.binary.{inf_label}_{metric_label}",
                        overwrite=True,
                    )
13 changes: 11 additions & 2 deletions prepare/metrics/rag_answer_correctness.py
@@ -52,8 +52,17 @@ def test_answer_correctness(
metric = MetricPipeline(
main_score=main_score,
preprocess_steps=[
Copy(field="ground_truths", to_field="references"),
Copy(field="answer", to_field="prediction"),
Copy(
field_to_field={
"task_data/reference_answers": "references",
"answer": "prediction",
},
not_exist_do_nothing=True,
),
Copy(
field_to_field={"ground_truths": "references"},
not_exist_do_nothing=True,
),
],
metric=base_catalog_name,
)
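The doubled Copy with not_exist_do_nothing=True is what lets the pipeline accept both the new task-based input and the legacy one. A rough illustration with made-up instance values:

```python
# New flow (rag.response_generation task): references come from task_data/reference_answers.
instance_new = {
    "answer": "Paris is the capital of France.",
    "task_data": {"reference_answers": ["Paris"]},
}

# Legacy flow: references come from a top-level ground_truths field.
instance_legacy = {
    "answer": "Paris is the capital of France.",
    "ground_truths": ["Paris"],
}

# Because each Copy silently skips missing source fields (not_exist_do_nothing=True),
# both shapes reach the base metric with prediction="Paris is the capital of France."
# and references=["Paris"].
```

The same pattern is applied to the answer relevance, context relevance, and faithfulness pipelines in the files below.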
34 changes: 30 additions & 4 deletions prepare/metrics/rag_answer_relevance.py
@@ -7,8 +7,11 @@
answer_reward = MetricPipeline(
main_score="score",
preprocess_steps=[
Copy(field="question", to_field="references"),
Copy(field="answer", to_field="prediction"),
Copy(
field_to_field={"task_data/question": "references", "answer": "prediction"},
not_exist_do_nothing=True,
),
Copy(field_to_field={"question": "references"}, not_exist_do_nothing=True),
# This metric compares the answer (as the prediction) to the question (as the reference).
# We have to wrap the question by a list (otherwise it will be a string),
# because references are expected to be lists
@@ -17,11 +20,34 @@
metric="metrics.reward.deberta_v3_large_v2",
)
add_to_catalog(answer_reward, "metrics.rag.answer_reward", overwrite=True)

answer_token_overlap = MetricPipeline(
main_score="recall",
preprocess_steps=[
Copy(
field_to_field={"task_data/question": "references", "answer": "prediction"},
not_exist_do_nothing=True,
),
Copy(field_to_field={"question": "references"}, not_exist_do_nothing=True),
# This metric compares the answer (as the prediction) to the question (as the reference).
# We have to wrap the question by a list (otherwise it will be a string),
# because references are expected to be lists
ListFieldValues(fields=["references"], to_field="references"),
],
metric="metrics.token_overlap",
)
add_to_catalog(
answer_token_overlap, "metrics.rag.answer_relevance.token_recall", overwrite=True
)

answer_inference = MetricPipeline(
main_score="perplexity",
preprocess_steps=[
Copy(field="contexts", to_field="references"),
Copy(field="answer", to_field="prediction"),
Copy(
field_to_field={"task_data/contexts": "references", "answer": "prediction"},
not_exist_do_nothing=True,
),
Copy(field_to_field={"contexts": "references"}, not_exist_do_nothing=True),
],
metric="metrics.perplexity_nli.t5_nli_mixture",
)
11 changes: 9 additions & 2 deletions prepare/metrics/rag_context_relevance.py
@@ -11,12 +11,19 @@
("perplexity_flan_t5_small", "metrics.perplexity_q.flan_t5_small", "perplexity"),
("sentence_bert_bge", "metrics.sentence_bert.bge_large_en_1_5", "sbert_score"),
("sentence_bert_mini_lm", "metrics.sentence_bert.minilm_l12_v2", "sbert_score"),
("token_precision", "metrics.token_overlap", "precision"),
]:
metric = MetricPipeline(
main_score=main_score,
preprocess_steps=[
Copy(field="contexts", to_field="references"),
Copy(field="question", to_field="prediction"),
Copy(
field_to_field={
"task_data/contexts": "references",
"question": "prediction",
},
not_exist_do_nothing=True,
),
Copy(field_to_field={"contexts": "references"}, not_exist_do_nothing=True),
],
metric=base_catalog_name,
)
10 changes: 8 additions & 2 deletions prepare/metrics/rag_faithfulness.py
@@ -22,8 +22,14 @@
metric = MetricPipeline(
main_score=main_score,
preprocess_steps=[
Copy(field="contexts", to_field="references"),
Copy(field="answer", to_field="prediction"),
Copy(
field_to_field={
"task_data/contexts": "references",
"answer": "prediction",
},
not_exist_do_nothing=True,
),
Copy(field_to_field={"contexts": "references"}, not_exist_do_nothing=True),
],
metric=base_catalog_name,
)
14 changes: 11 additions & 3 deletions prepare/metrics/rag_granite_guardian.py
@@ -12,10 +12,14 @@
}
]

risk_names = ["groundedness", "context_relevance", "answer_relevance"]
risk_names_to_pred_field = {
"groundedness": "answer",
"context_relevance": "question",
"answer_relevance": "answer",
}


for granite_risk_name in risk_names:
for granite_risk_name, pred_field in risk_names_to_pred_field.items():
metric_name = f"""granite_guardian_{granite_risk_name}"""
metric = GraniteGuardianWMLMetric(
main_score=metric_name,
@@ -28,7 +32,11 @@
preprocess_steps=[
Copy(
field_to_field={field: f"task_data/{field}" for field in rag_fields},
not_exist_ok=True,
not_exist_do_nothing=True,
),
Copy(
field_to_field={"prediction": f"task_data/{pred_field}"},
not_exist_do_nothing=True,
),
Set(
fields={
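A short illustration of the routing established above: which task_data field the model prediction is copied into for each Granite Guardian risk.

```python
risk_names_to_pred_field = {
    "groundedness": "answer",
    "context_relevance": "question",
    "answer_relevance": "answer",
}
for risk, pred_field in risk_names_to_pred_field.items():
    # e.g. granite_guardian_context_relevance inspects task_data/question
    print(f"granite_guardian_{risk}: prediction -> task_data/{pred_field}")
```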
24 changes: 24 additions & 0 deletions prepare/processors/processors.py
@@ -9,6 +9,8 @@
ExtractArenaHardNumericalJudgment,
ExtractMtBenchLabelJudgment,
ExtractMtBenchRatingJudgment,
ExtractVerbalJudgementBadGood,
ExtractVerbalJudgment,
ExtractWithRegex,
FirstCharacter,
FixWhiteSpace,
@@ -22,6 +24,7 @@
RegexParser,
RemoveArticles,
RemovePunctuations,
ScaleNumberToZeroOneReturnZeroIfFails,
StanceToProCon,
StringEquals,
StrToFloatFormat,
@@ -271,3 +274,24 @@ def add_processor_and_operator_to_catalog(
add_processor_and_operator_to_catalog(
    artifact_name="fix_whitespace", operator=FixWhiteSpace(), overwrite=True
)

add_processor_and_operator_to_catalog(
    artifact_name="scale_0_10_to_0_1",
    operator=ScaleNumberToZeroOneReturnZeroIfFails(),
    overwrite=True,
    process_references=False,
)

add_processor_and_operator_to_catalog(
    artifact_name="extract_verbal_judgement",
    operator=ExtractVerbalJudgment(),
    overwrite=True,
    process_references=False,
)

add_processor_and_operator_to_catalog(
    artifact_name="extract_verbal_judgement_bad_good",
    operator=ExtractVerbalJudgementBadGood(),
    overwrite=True,
    process_references=False,
)
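The new postprocessor implementations are not part of this diff, so the following is only a hedged sketch of the behavior their names and catalog entries suggest (a 0-10 judge rating scaled into the 0-1 range, with a fallback of 0 when parsing fails), not the library code:

```python
def scale_0_10_to_0_1(text: str) -> float:
    # Assumed behavior of ScaleNumberToZeroOneReturnZeroIfFails, based on the
    # "scale_0_10_to_0_1" artifact name: divide by 10, clamp to [0, 1],
    # and return 0 when the judge output cannot be parsed as a number.
    try:
        return min(max(float(text) / 10.0, 0.0), 1.0)
    except (TypeError, ValueError):
        return 0.0


assert scale_0_10_to_0_1("7") == 0.7
assert scale_0_10_to_0_1("not a number") == 0.0
```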
File renamed without changes.