Fixes and adjustments in RAG metrics and related inference engines (#1466)

* add new classification engines and remove deleted llama-3-70b-instruct in bam
Signed-off-by: lilacheden <[email protected]>

* allow proprietary data on RITSInferenceEngine
Signed-off-by: lilacheden <[email protected]>

* propagate score_prefix from metricPipeline to its metric
Signed-off-by: lilacheden <[email protected]>

* Adjust autorag metrics to unitxt flow with rag.response_generation task
Signed-off-by: lilacheden <[email protected]>

* Adjust granite_guardian to unitxt flow
Signed-off-by: lilacheden <[email protected]>

* Add AzureOpenAIInferenceEngine
Signed-off-by: lilacheden <[email protected]>

* lowercase engine label
Signed-off-by: lilacheden <[email protected]>

* move to azure openai classification engines
Signed-off-by: lilacheden <[email protected]>

* set chat_api format for openai
Signed-off-by: lilacheden <[email protected]>

* update secrets
Signed-off-by: lilacheden <[email protected]>

* delete old inference engine
Signed-off-by: lilacheden <[email protected]>

* update secrets
Signed-off-by: lilacheden <[email protected]>

* add numeric and verbal postprocessors
Signed-off-by: lilacheden <[email protected]>

* fix rag judge example
Signed-off-by: lilacheden <[email protected]>

* add numeric and verbal rag judge templates
Signed-off-by: lilacheden <[email protected]>

* add rag judges that use the new templates
Signed-off-by: lilacheden <[email protected]>

* fix import
Signed-off-by: lilacheden <[email protected]>

* rename metrics with correct template
Signed-off-by: lilacheden <[email protected]>

* avoid import from prepare
Signed-off-by: lilacheden <[email protected]>

* remove old metrics
Signed-off-by: lilacheden <[email protected]>

* add token overlap based context relevance and answer relevance metrics
Signed-off-by: lilacheden <[email protected]>

* add postprocessors tests
Signed-off-by: lilacheden <[email protected]>

* keep only recommended rag llmaj, deprecate old path to metrics
Signed-off-by: lilacheden <[email protected]>

* update secret
Signed-off-by: lilacheden <[email protected]>

* update secret again
Signed-off-by: lilacheden <[email protected]>

* fix typo
Signed-off-by: lilacheden <[email protected]>

* remove gen_ai from inference test
Signed-off-by: lilacheden <[email protected]>

* comment out input_tokens test
Signed-off-by: lilacheden <[email protected]>
lilacheden authored Jan 13, 2025
1 parent 6e7284b commit 1350d56
Showing 101 changed files with 1,274 additions and 226 deletions.
2 changes: 1 addition & 1 deletion examples/evaluate_rag_using_binary_llm_as_judge.py
@@ -64,7 +64,7 @@
# all available models for this judge are under "catalog.engines.classification"
mixtral_engine = "engines.classification.mixtral_8x7b_instruct_v01_wml"
correctness_judge_metric_mixtral = (
f"{metric_name}[{mapping_override}, model={mixtral_engine}]"
f"{metric_name}[{mapping_override}, inference_model={mixtral_engine}]"
)

metrics = [correctness_judge_metric_llama, correctness_judge_metric_mixtral]
17 changes: 11 additions & 6 deletions prepare/engines/classification/classification_engines.py
@@ -1,7 +1,7 @@
from unitxt import add_to_catalog
from unitxt.inference import (
AzureOpenAIInferenceEngine,
IbmGenAiInferenceEngine,
OpenAiInferenceEngine,
RITSInferenceEngine,
WMLInferenceEngineGeneration,
)
@@ -23,8 +23,12 @@ def get_inference_engine(model_name, framework_name):
decoding_method="greedy",
)
if framework_name == "openai":
return OpenAiInferenceEngine(
model_name=model_name, logprobs=True, max_tokens=5, temperature=0.0
return AzureOpenAIInferenceEngine(
model_name=model_name,
logprobs=True,
max_tokens=5,
temperature=0.0,
top_logprobs=5,
)
if framework_name == "rits":
return RITSInferenceEngine(
@@ -34,9 +38,10 @@ def get_inference_engine(model_name, framework_name):


model_names_to_infer_framework = {
"meta-llama/llama-3-1-70b-instruct": ["ibm_wml", "rits"],
"meta-llama/llama-3-70b-instruct": ["ibm_gen_ai"],
"gpt-4-turbo": ["openai"],
"meta-llama/llama-3-1-70b-instruct": ["ibm_wml", "rits", "ibm_gen_ai"],
"meta-llama/llama-3-3-70b-instruct": ["ibm_wml", "rits"],
"gpt-4-turbo-2024-04-09": ["openai"],
"gpt-4o-2024-08-06": ["openai"],
"mistralai/mixtral-8x7b-instruct-v01": ["ibm_wml", "ibm_gen_ai", "rits"],
"meta-llama/llama-3-1-405b-instruct-fp8": ["ibm_gen_ai", "rits"],
"meta-llama/llama-3-405b-instruct": ["ibm_wml"],
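
For orientation, a minimal sketch of creating the new Azure OpenAI classification engine directly, mirroring the parameters used by get_inference_engine above; the model name is one of the "openai" entries in the mapping, and Azure credentials/endpoint are assumed to be supplied via the environment (not shown in this diff):

# Minimal sketch, mirroring the "openai" branch of get_inference_engine above;
# assumes Azure OpenAI credentials and endpoint come from the environment.
from unitxt.inference import AzureOpenAIInferenceEngine

engine = AzureOpenAIInferenceEngine(
    model_name="gpt-4o-2024-08-06",  # one of the "openai" entries in the mapping above
    logprobs=True,                   # return log-probabilities with the completion
    max_tokens=5,                    # the judgment output is only a few tokens
    temperature=0.0,                 # deterministic decoding
    top_logprobs=5,                  # number of top alternatives per token
)
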
65 changes: 0 additions & 65 deletions prepare/metrics/llm_as_judge/binary_judge.py

This file was deleted.

94 changes: 94 additions & 0 deletions prepare/metrics/llm_as_judge/rag_judge.py
@@ -0,0 +1,94 @@
from unitxt import add_to_catalog
from unitxt.artifact import UnitxtArtifactNotFoundError, fetch_artifact
from unitxt.inference import GenericInferenceEngine
from unitxt.llm_as_judge import (
TaskBasedLLMasJudge,
)

metric_type_to_template_dict = {
"faithfulness": {
"q_c_a": "judge_with_question_simplified",
"c_a": "judge_no_question_simplified",
},
"context_relevance": {"q_c_ares": "judge_context_relevance_ares"},
"correctness_holistic": {"q_c_a": "judge_correctness_simple"},
"answer_correctness": {"q_a_gt_loose": "judge_loose_match_no_context"},
"answer_relevance": {"q_a": "judge_answer_relevance"},
}
metric_type_to_realization = {
"faithfulness": "_verbal",
"context_relevance": "_numeric",
"correctness_holistic": "_numeric",
"answer_correctness": "_numeric",
"answer_relevance": "_numeric",
}

generic_engine_label = "generic_inference_engine"
inference_models = {
"llama_3_1_70b_instruct_wml": "engines.classification.llama_3_1_70b_instruct_wml",
generic_engine_label: GenericInferenceEngine(),
}


def get_prediction_field(metric_type):
return None if metric_type == "context_relevance" else "answer"


for metric_type, template_dict in metric_type_to_template_dict.items():
for template_short_name, template_name in template_dict.items():
task_name = f"tasks.rag_eval.{metric_type}.binary"
for logprobs_label in [
"",
"_logprobs",
metric_type_to_realization[metric_type],
]:
use_logprobs = logprobs_label == "_logprobs"
template = (
f"templates.rag_eval.{metric_type}.{template_name}{logprobs_label}"
)
try:
t = fetch_artifact(template)[0]
except UnitxtArtifactNotFoundError:
continue
for inf_label, inference_model in inference_models.items():
if (
use_logprobs and inf_label == generic_engine_label
): # engine GenericInferenceEngine does not support logprobs
continue

metric_label = f"{metric_type}_{template_short_name}{logprobs_label}"
metric = TaskBasedLLMasJudge(
inference_model=inference_model,
template=template,
task=task_name,
format=None,
main_score=metric_label,
prediction_field=get_prediction_field(metric_type),
infer_log_probs=use_logprobs,
)

new_catalog_name = f"metrics.rag.{metric_type}.{inf_label}_{template_short_name}{logprobs_label}"

add_to_catalog(
metric,
new_catalog_name,
overwrite=True,
)

if logprobs_label in ["_logprobs", ""]:
metric = TaskBasedLLMasJudge(
inference_model=inference_model,
template=template,
task=task_name,
format=None,
main_score=metric_label,
prediction_field=get_prediction_field(metric_type),
infer_log_probs=use_logprobs,
__deprecated_msg__=f"This metric should be replaced with {new_catalog_name}",
)
# for backwards compatibility: keep also legacy path to metrics
add_to_catalog(
metric,
f"metrics.llm_as_judge.binary.{inf_label}_{metric_label}",
overwrite=True,
)
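
For readers navigating the catalog, a sketch of the naming scheme the loop above produces and how such a metric can be fetched; the concrete names below are illustrative and are only created when the matching template artifact exists:

# Names follow f"metrics.rag.{metric_type}.{inf_label}_{template_short_name}{logprobs_label}",
# for example (illustrative; created only if the corresponding template is found):
#   metrics.rag.faithfulness.llama_3_1_70b_instruct_wml_q_c_a
#   metrics.rag.faithfulness.generic_inference_engine_q_c_a_verbal
#   metrics.rag.answer_correctness.llama_3_1_70b_instruct_wml_q_a_gt_loose_logprobs
from unitxt.artifact import fetch_artifact

metric = fetch_artifact("metrics.rag.faithfulness.llama_3_1_70b_instruct_wml_q_c_a")[0]
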
13 changes: 11 additions & 2 deletions prepare/metrics/rag_answer_correctness.py
@@ -52,8 +52,17 @@ def test_answer_correctness(
metric = MetricPipeline(
main_score=main_score,
preprocess_steps=[
Copy(field="ground_truths", to_field="references"),
Copy(field="answer", to_field="prediction"),
Copy(
field_to_field={
"task_data/reference_answers": "references",
"answer": "prediction",
},
not_exist_do_nothing=True,
),
Copy(
field_to_field={"ground_truths": "references"},
not_exist_do_nothing=True,
),
],
metric=base_catalog_name,
)
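
The pair of Copy steps above is what lets the pipeline accept either the task-data layout or the legacy flat fields; a small sketch of the two input shapes being covered (made-up instances, field names taken from the diff):

# Made-up instances illustrating the two layouts the Copy steps above handle.
new_style = {"answer": "Paris", "task_data": {"reference_answers": ["Paris"]}}
legacy_style = {"answer": "Paris", "ground_truths": ["Paris"]}
# With not_exist_do_nothing=True, a Copy whose source field is missing is simply
# skipped, so either layout ends up with "prediction" and "references" populated
# before the base metric runs.
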
34 changes: 30 additions & 4 deletions prepare/metrics/rag_answer_relevance.py
@@ -7,8 +7,11 @@
answer_reward = MetricPipeline(
main_score="score",
preprocess_steps=[
Copy(field="question", to_field="references"),
Copy(field="answer", to_field="prediction"),
Copy(
field_to_field={"task_data/question": "references", "answer": "prediction"},
not_exist_do_nothing=True,
),
Copy(field_to_field={"question": "references"}, not_exist_do_nothing=True),
# This metric compares the answer (as the prediction) to the question (as the reference).
# We have to wrap the question by a list (otherwise it will be a string),
# because references are expected to be lists
@@ -17,11 +20,34 @@
metric="metrics.reward.deberta_v3_large_v2",
)
add_to_catalog(answer_reward, "metrics.rag.answer_reward", overwrite=True)

answer_token_overlap = MetricPipeline(
main_score="recall",
preprocess_steps=[
Copy(
field_to_field={"task_data/question": "references", "answer": "prediction"},
not_exist_do_nothing=True,
),
Copy(field_to_field={"question": "references"}, not_exist_do_nothing=True),
# This metric compares the answer (as the prediction) to the question (as the reference).
# We have to wrap the question by a list (otherwise it will be a string),
# because references are expected to be lists
ListFieldValues(fields=["references"], to_field="references"),
],
metric="metrics.token_overlap",
)
add_to_catalog(
answer_token_overlap, "metrics.rag.answer_relevance.token_recall", overwrite=True
)

answer_inference = MetricPipeline(
main_score="perplexity",
preprocess_steps=[
Copy(field="contexts", to_field="references"),
Copy(field="answer", to_field="prediction"),
Copy(
field_to_field={"task_data/contexts": "references", "answer": "prediction"},
not_exist_do_nothing=True,
),
Copy(field_to_field={"contexts": "references"}, not_exist_do_nothing=True),
],
metric="metrics.perplexity_nli.t5_nli_mixture",
)
11 changes: 9 additions & 2 deletions prepare/metrics/rag_context_relevance.py
@@ -11,12 +11,19 @@
("perplexity_flan_t5_small", "metrics.perplexity_q.flan_t5_small", "perplexity"),
("sentence_bert_bge", "metrics.sentence_bert.bge_large_en_1_5", "sbert_score"),
("sentence_bert_mini_lm", "metrics.sentence_bert.minilm_l12_v2", "sbert_score"),
("token_precision", "metrics.token_overlap", "precision"),
]:
metric = MetricPipeline(
main_score=main_score,
preprocess_steps=[
Copy(field="contexts", to_field="references"),
Copy(field="question", to_field="prediction"),
Copy(
field_to_field={
"task_data/contexts": "references",
"question": "prediction",
},
not_exist_do_nothing=True,
),
Copy(field_to_field={"contexts": "references"}, not_exist_do_nothing=True),
],
metric=base_catalog_name,
)
10 changes: 8 additions & 2 deletions prepare/metrics/rag_faithfulness.py
@@ -22,8 +22,14 @@
metric = MetricPipeline(
main_score=main_score,
preprocess_steps=[
Copy(field="contexts", to_field="references"),
Copy(field="answer", to_field="prediction"),
Copy(
field_to_field={
"task_data/contexts": "references",
"answer": "prediction",
},
not_exist_do_nothing=True,
),
Copy(field_to_field={"contexts": "references"}, not_exist_do_nothing=True),
],
metric=base_catalog_name,
)
14 changes: 11 additions & 3 deletions prepare/metrics/rag_granite_guardian.py
@@ -12,10 +12,14 @@
}
]

risk_names = ["groundedness", "context_relevance", "answer_relevance"]
risk_names_to_pred_field = {
"groundedness": "answer",
"context_relevance": "question",
"answer_relevance": "answer",
}


for granite_risk_name in risk_names:
for granite_risk_name, pred_field in risk_names_to_pred_field.items():
metric_name = f"""granite_guardian_{granite_risk_name}"""
metric = GraniteGuardianWMLMetric(
main_score=metric_name,
@@ -28,7 +32,11 @@
preprocess_steps=[
Copy(
field_to_field={field: f"task_data/{field}" for field in rag_fields},
not_exist_ok=True,
not_exist_do_nothing=True,
),
Copy(
field_to_field={"prediction": f"task_data/{pred_field}"},
not_exist_do_nothing=True,
),
Set(
fields={
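
As a rough sketch of the per-risk wiring above (field names from the diff, instance made up): the second Copy step exposes the prediction to Granite Guardian under the field each risk expects.

# Made-up instance; for the "answer_relevance" risk the Copy above maps
# prediction -> task_data/answer (for "context_relevance" it would go to task_data/question).
instance = {
    "question": "What is the capital of France?",
    "contexts": ["Paris is the capital of France."],
    "prediction": "Paris.",
}
# After preprocessing, instance["task_data"]["answer"] == "Paris." and the
# granite_guardian_answer_relevance metric judges that field.
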
24 changes: 24 additions & 0 deletions prepare/processors/processors.py
@@ -9,6 +9,8 @@
ExtractArenaHardNumericalJudgment,
ExtractMtBenchLabelJudgment,
ExtractMtBenchRatingJudgment,
ExtractVerbalJudgementBadGood,
ExtractVerbalJudgment,
ExtractWithRegex,
FirstCharacter,
FixWhiteSpace,
@@ -22,6 +24,7 @@
RegexParser,
RemoveArticles,
RemovePunctuations,
ScaleNumberToZeroOneReturnZeroIfFails,
StanceToProCon,
StringEquals,
StrToFloatFormat,
@@ -271,3 +274,24 @@ def add_processor_and_operator_to_catalog(
add_processor_and_operator_to_catalog(
artifact_name="fix_whitespace", operator=FixWhiteSpace(), overwrite=True
)

add_processor_and_operator_to_catalog(
artifact_name="scale_0_10_to_0_1",
operator=ScaleNumberToZeroOneReturnZeroIfFails(),
overwrite=True,
process_references=False,
)

add_processor_and_operator_to_catalog(
artifact_name="extract_verbal_judgement",
operator=ExtractVerbalJudgment(),
overwrite=True,
process_references=False,
)

add_processor_and_operator_to_catalog(
artifact_name="extract_verbal_judgement_bad_good",
operator=ExtractVerbalJudgementBadGood(),
overwrite=True,
process_references=False,
)
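
The new postprocessor classes are only imported here, not shown; as a rough, assumed sketch of what the 0-10 scaler presumably does, based only on its class and catalog names (not the actual library implementation):

# Assumed behavior, inferred from "ScaleNumberToZeroOneReturnZeroIfFails" and
# the catalog name "scale_0_10_to_0_1"; not the library code.
def scale_0_10_to_0_1(text: str) -> float:
    try:
        return float(text) / 10.0
    except (TypeError, ValueError):
        return 0.0  # "return zero if fails"

assert scale_0_10_to_0_1("7") == 0.7
assert scale_0_10_to_0_1("not a number") == 0.0
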
File renamed without changes.