Skip to content

Commit

Permalink
Merge pull request #120 from sbintuitions/impl_fixed_fewshot
Browse files Browse the repository at this point in the history
Implement `FixedFewShotGenerator`
  • Loading branch information
ryokan0123 authored Jan 14, 2025
2 parents 77d6a07 + 9284cc4 commit c81dd1e
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 0 deletions.
1 change: 1 addition & 0 deletions flexeval/core/few_shot_generator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .balanced import BalancedFewShotGenerator
from .base import FewShotGenerator
from .fixed import FixedFewShotGenerator
from .rand import RandomFewShotGenerator
28 changes: 28 additions & 0 deletions flexeval/core/few_shot_generator/fixed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from __future__ import annotations

from typing import Any

from .base import ChatInstance, FewShotGenerator, GenerationInstance, Instance, MultipleChoiceInstance


class FixedFewShotGenerator(FewShotGenerator):
def __init__(self, instance_class: str, instance_params: list[dict[str, Any]]) -> None:
super().__init__(num_trials_to_avoid_leak=0)

if instance_class == "GenerationInstance":
instance_init = GenerationInstance
elif instance_class == "MultipleChoiceInstance":
instance_init = MultipleChoiceInstance
elif instance_class == "ChatInstance":
instance_init = ChatInstance
else:
msg = f"Unknown instance class: {instance_class}"
raise ValueError(msg)

self.instances = [instance_init(**params) for params in instance_params]

def _sample_instances(self, eval_inputs: list[dict[str, Any]] | dict[str, Any] | None = None) -> list[Instance]:
return self.instances

def __repr__(self) -> str:
return f"{self.__class__.__name__}(instances={self.instances})"
13 changes: 13 additions & 0 deletions tests/core/few_show_generator/test_fixed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from flexeval import ChatInstance
from flexeval.core.few_shot_generator.fixed import FixedFewShotGenerator


def test_fixed_fewshot_generator() -> None:
instance = ChatInstance(
messages=[{"role": "user", "content": "Hello!"}, {"role": "assistant", "content": "Hi there!"}]
)
generator = FixedFewShotGenerator(
instance_class="ChatInstance",
instance_params=[{"messages": instance.messages} for _ in range(5)],
)
assert generator() == [instance for _ in range(5)]

0 comments on commit c81dd1e

Please sign in to comment.