[Evals API][8/n] AnswerParsingScoringFn for MMLU #352

Closed · wants to merge 3 commits
60 changes: 46 additions & 14 deletions llama_stack/apis/scoring_functions/scoring_functions.py
@@ -3,27 +3,40 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
from enum import Enum
from typing import (
Any,
Dict,
List,
Literal,
Optional,
Protocol,
runtime_checkable,
Union,
)

from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated

from llama_stack.apis.common.type_system import ParamType


@json_schema_type
class Parameter(BaseModel):
name: str
type: ParamType
description: Optional[str] = None


# Perhaps more structure can be imposed on these functions. Maybe they could be associated
# with standard metrics so they can be rolled up?


@json_schema_type
class ScoringContextType(Enum):
llm_as_judge = "llm_as_judge"
answer_parsing = "answer_parsing"


@json_schema_type
class LLMAsJudgeContext(BaseModel):
type: Literal[ScoringContextType.llm_as_judge.value] = (
ScoringContextType.llm_as_judge.value
)
judge_model: str
prompt_template: Optional[str] = None
judge_score_regex: Optional[List[str]] = Field(
@@ -32,6 +45,26 @@ class LLMAsJudgeContext(BaseModel):
)


@json_schema_type
class AnswerParsingContext(BaseModel):
type: Literal[ScoringContextType.answer_parsing.value] = (
ScoringContextType.answer_parsing.value
)
parsing_regex: Optional[List[str]] = Field(
description="Regex to extract the answer from generated response",
default_factory=list,
)


ScoringContext = Annotated[
Union[
LLMAsJudgeContext,
AnswerParsingContext,
],
Field(discriminator="type"),
]


@json_schema_type
class ScoringFnDef(BaseModel):
identifier: str
@@ -40,14 +73,13 @@ class ScoringFnDef(BaseModel):
default_factory=dict,
description="Any additional metadata for this definition",
)
parameters: List[Parameter] = Field(
Contributor Author commented: Removing the parameters field, as we can just use context for defining parameters associated with the scoring function.

description="List of parameters for the deterministic function",
default_factory=list,
)
return_type: ParamType = Field(
description="The return type of the deterministic function",
)
context: Optional[LLMAsJudgeContext] = None
context: Optional[ScoringContext] = Field(
description="Scoring function context used different answer extraction",
default=None,
)
# We can optionally add information here to support packaging of code, etc.


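Before moving to the provider changes, here is a minimal usage sketch of the new discriminated union. It assumes AnswerParsingContext, LLMAsJudgeContext, ScoringContextType, ScoringFnDef, and NumberType are re-exported from the llama_stack.apis packages, as the wildcard imports elsewhere in this diff suggest; the identifiers and judge model name are made up for illustration.

from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
    AnswerParsingContext,
    LLMAsJudgeContext,
    ScoringContextType,
    ScoringFnDef,
)

# Hypothetical definition that extracts a letter answer via regex.
mmlu_style_fn = ScoringFnDef(
    identifier="example::mmlu_answer_parsing",
    description="Parses 'Answer: X' and compares it with expected_answer.",
    return_type=NumberType(),
    context=AnswerParsingContext(parsing_regex=[r"Answer\s*:\s*([A-D])"]),
)

# Hypothetical definition that scores via a judge model instead.
judge_fn = ScoringFnDef(
    identifier="example::judge_correctness",
    description="Grades responses with a judge model.",
    return_type=NumberType(),
    context=LLMAsJudgeContext(judge_model="Llama3.1-8B-Instruct"),
)

# The `type` literal acts as the pydantic discriminator, so each definition
# carries the concrete context subclass it was built with.
assert mmlu_style_fn.context.type == ScoringContextType.answer_parsing.value
assert judge_fn.context.type == ScoringContextType.llm_as_judge.value
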
41 changes: 24 additions & 17 deletions llama_stack/providers/impls/meta_reference/scoring/scoring.py
@@ -13,23 +13,20 @@
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.inference.inference import Inference
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.equality_scoring_fn import (
EqualityScoringFn,
)

from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.llm_as_judge_scoring_fn import (
LlmAsJudgeScoringFn,
)

from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.subset_of_scoring_fn import (
SubsetOfScoringFn,
)

from .config import MetaReferenceScoringConfig
from .scoring_fn.answer_parsing_scoring_fn import AnswerParsingScoringFn
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn

FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn]

LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]
# Scoring functions with context that can be registered
REGISTERABLE_SCORING_FNS = {
@yanxi0830 (Contributor Author) commented on Nov 1, 2024: Each registerable ScoringFn is mapped to a ScoringContextType, so that we are able to register a scoring function with a custom judge_prompt/answer_extraction regex.

ScoringContextType.llm_as_judge.value: LlmAsJudgeScoringFn,
ScoringContextType.answer_parsing.value: AnswerParsingScoringFn,
}


class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
@@ -44,18 +41,24 @@ def __init__(
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
self.inference_api = inference_api
# keep track of scoring function id to impls
self.scoring_fn_id_impls = {}
# registerable scoring fn context to impls
self.registerable_scoring_fn_impls = {}

async def initialize(self) -> None:
for x in FIXED_FNS:
impl = x()
for fn_defs in impl.get_supported_scoring_fn_defs():
self.scoring_fn_id_impls[fn_defs.identifier] = impl
for x in LLM_JUDGE_FNS:
impl = x(inference_api=self.inference_api)
for context_type, impl_cls in REGISTERABLE_SCORING_FNS.items():
if context_type == ScoringContextType.llm_as_judge.value:
impl = impl_cls(inference_api=self.inference_api)
else:
impl = impl_cls()
for fn_defs in impl.get_supported_scoring_fn_defs():
self.scoring_fn_id_impls[fn_defs.identifier] = impl
self.llm_as_judge_fn = impl
self.registerable_scoring_fn_impls[context_type] = impl

async def shutdown(self) -> None: ...

@@ -74,8 +77,12 @@ async def list_scoring_functions(self) -> List[ScoringFnDef]:
return scoring_fn_defs_list

async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
self.llm_as_judge_fn.register_scoring_fn_def(function_def)
self.scoring_fn_id_impls[function_def.identifier] = self.llm_as_judge_fn
assert (
function_def.context is not None
), "Only ScoringFnDef with context set can be registered"
fn_impl = self.registerable_scoring_fn_impls[function_def.context.type]
fn_impl.register_scoring_fn_def(function_def)
self.scoring_fn_id_impls[function_def.identifier] = fn_impl

async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
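As a usage sketch of the new registration path: the helper name, identifier, and regex below are made up, and `impl` is assumed to be an already-initialized MetaReferenceScoringImpl.

from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import AnswerParsingContext, ScoringFnDef


async def register_custom_parser(impl) -> None:
    """Register a custom answer-parsing function against an initialized
    MetaReferenceScoringImpl passed in as `impl` (hypothetical helper)."""
    custom_fn = ScoringFnDef(
        identifier="example::extract_final_number",  # made-up identifier
        description="Extracts the trailing number and compares it with expected_answer.",
        return_type=NumberType(),
        context=AnswerParsingContext(parsing_regex=[r"(\d+)\s*$"]),
    )
    await impl.register_scoring_function(custom_fn)

    # Because context.type is "answer_parsing", the definition is served by the
    # AnswerParsingScoringFn instance created during initialize().
    assert (
        impl.scoring_fn_id_impls["example::extract_final_number"]
        is impl.registerable_scoring_fn_impls["answer_parsing"]
    )
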
llama_stack/providers/impls/meta_reference/scoring/scoring_fn/answer_parsing_scoring_fn.py (new file)
@@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re

from .base_scoring_fn import BaseScoringFn
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403
from .common import aggregate_accuracy

from .fn_defs.answer_parsing_multiple_choice import answer_parsing_multiple_choice


class AnswerParsingScoringFn(BaseScoringFn):
"""
A scoring_fn that parses the answer from the generated response according to its context and checks whether it matches the expected_answer.
"""

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = {
answer_parsing_multiple_choice.identifier: answer_parsing_multiple_choice,
}

async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = None,
) -> ScoringResultRow:
assert (
scoring_fn_identifier is not None
), "Scoring function identifier not found."
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
assert (
fn_def.context is not None
and fn_def.context.type == ScoringContextType.answer_parsing.value
), f"AnswerParsingContext not found for {fn_def}."

expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]

# parse answer according to regex
parsed_answer = None
for regex in fn_def.context.parsing_regex:
match = re.search(regex, generated_answer)
if match:
parsed_answer = match.group(1)
break

score = 1.0 if parsed_answer and parsed_answer == expected_answer else 0.0
return {
"score": score,
}

async def aggregate(
self, scoring_results: List[ScoringResultRow]
) -> Dict[str, Any]:
return aggregate_accuracy(scoring_results)
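The parsing loop above can be sanity-checked with plain re outside the provider; the regex and row below are illustrative only.

import re

# Illustrative regex and row, mirroring the score_row loop above.
parsing_regex = [r"Answer\s*:\s*([A-D])"]
row = {
    "expected_answer": "C",
    "generated_answer": "Let's eliminate the wrong options first. Answer: C",
}

parsed_answer = None
for regex in parsing_regex:
    match = re.search(regex, row["generated_answer"])
    if match:
        parsed_answer = match.group(1)
        break

score = 1.0 if parsed_answer and parsed_answer == row["expected_answer"] else 0.0
print(score)  # 1.0: the parsed "C" matches the expected "C"
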
llama_stack/providers/impls/meta_reference/scoring/scoring_fn/equality_scoring_fn.py
@@ -15,9 +15,7 @@
aggregate_accuracy,
)

from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.equality import (
equality,
)
from .fn_defs.equality import equality


class EqualityScoringFn(BaseScoringFn):
llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/answer_parsing_multiple_choice.py (new file)
@@ -0,0 +1,69 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import NumberType

MULTILINGUAL_ANSWER_REGEXES = [
r"Answer\s*:",
r"Answer\s*:​​​​​​", # Korean invisible character
r"উত্তর\s*:",
r"उत्तर\s*:",
r"উত্তরঃ",
r"উত্তর\s*:",
r"Antwort\s*:",
r"답변\s*:",
r"정답\s*:",
r"답\s*:",
r"答案\s*:",
r"答案\s*:",
r"答\s*:",
r"答\s*:",
r"答复\s*:",
r"答曰\s*:",
r"الإجابة:",
r"الجواب:",
r"إجابة:",
r"الإجابة النهائية:",
r"الإجابة الصحيحة:",
r"الإجابة الصحيحة هي:",
r"الإجابة هي:",
r"Respuesta\s*:",
r"Risposta\s*:",
r"答え\s*:",
r"答え\s*:",
r"回答\s*:",
r"回答\s*:",
r"解答\s*:",
r"Jawaban\s*:",
r"Réponse\s*:",
r"Resposta\s*:",
r"Jibu\s*:",
r"Idahun\s*:",
r"Ìdáhùn\s*:",
r"Idáhùn\s*:",
r"Àmọ̀nà\s*:",
r"Àdáhùn\s*:",
r"Ànúgọ\s*:",
r"Àṣàyàn\s*:",
]

MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = (
r"(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[A]|[B]|[C]|[D])"
)

answer_parsing_multiple_choice = ScoringFnDef(
identifier="meta-reference::answer_parsing_multiple_choice",
description="Extract answer from response matching Answer: [the_answer_letter], and compare with expected result",
return_type=NumberType(),
context=AnswerParsingContext(
parsing_regex=[
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x)
for x in MULTILINGUAL_ANSWER_REGEXES
],
),
)
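For reference, a small standalone sketch of how the template expands and matches; the sample strings are made up, and the printed groups are what re returns for them.

import re

# Copied from the definition above so the snippet is self-contained.
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = (
    r"(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[A]|[B]|[C]|[D])"
)

pattern = MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(r"Answer\s*:")
print(re.search(pattern, "Let me reason it out. Answer: B").group(1))  # prints B
print(re.search(pattern, "answer : d").group(1))  # prints d; (?i) makes the match case-insensitive
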
llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/equality.py
@@ -11,6 +11,5 @@
equality = ScoringFnDef(
identifier="meta-reference::equality",
description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
parameters=[],
return_type=NumberType(),
)
llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py
@@ -26,7 +26,6 @@
llm_as_judge_8b_correctness = ScoringFnDef(
identifier="meta-reference::llm_as_judge_8b_correctness",
description="Llm As Judge Scoring Function",
parameters=[],
return_type=NumberType(),
context=LLMAsJudgeContext(
prompt_template=JUDGE_PROMPT,
llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/subset_of.py
@@ -11,6 +11,5 @@
subset_of = ScoringFnDef(
identifier="meta-reference::subset_of",
description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
parameters=[],
return_type=NumberType(),
)
llama_stack/providers/impls/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py
@@ -4,25 +4,20 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference.inference import Inference
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
BaseScoringFn,
)

from .base_scoring_fn import BaseScoringFn
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403
import re

from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
aggregate_average,
)
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.llm_as_judge_8b_correctness import (
llm_as_judge_8b_correctness,
)
from .common import aggregate_average
from .fn_defs.llm_as_judge_8b_correctness import llm_as_judge_8b_correctness


class LlmAsJudgeScoringFn(BaseScoringFn):
"""
A scoring_fn that assigns
A scoring_fn that uses an LLM as judge to produce a score.
"""

def __init__(self, inference_api: Inference, *arg, **kwargs) -> None:
llama_stack/providers/impls/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py
@@ -4,19 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
BaseScoringFn,
)
from .base_scoring_fn import BaseScoringFn
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
aggregate_accuracy,
)
from .common import aggregate_accuracy

from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.subset_of import (
subset_of,
)
from .fn_defs.subset_of import subset_of


class SubsetOfScoringFn(BaseScoringFn):
1 change: 1 addition & 0 deletions llama_stack/providers/tests/scoring/test_scoring.py
@@ -50,6 +50,7 @@ async def provider_scoring_functions():
"meta-reference::equality",
"meta-reference::subset_of",
"meta-reference::llm_as_judge_8b_correctness",
"meta-reference::answer_parsing_multiple_choice",
},
"braintrust": {
"braintrust::factuality",
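Finally, a hedged end-to-end sketch of exercising the new scoring function directly, matching the identifier added to the test expectations above. It assumes BaseScoringFn needs no constructor arguments and that the module path follows the other scoring_fn modules in this diff; the sample rows are made up.

import asyncio

from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.answer_parsing_scoring_fn import (
    AnswerParsingScoringFn,
)


async def main() -> None:
    fn = AnswerParsingScoringFn()
    rows = [
        {"expected_answer": "B", "generated_answer": "Let's work through the options. Answer: B"},
        {"expected_answer": "C", "generated_answer": "Respuesta: A"},
    ]
    results = [
        await fn.score_row(
            row, scoring_fn_identifier="meta-reference::answer_parsing_multiple_choice"
        )
        for row in rows
    ]
    print(results)                      # [{'score': 1.0}, {'score': 0.0}]
    print(await fn.aggregate(results))  # accuracy over the two rows, i.e. 0.5


asyncio.run(main())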