-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
77d6a07
commit 9284cc4
Showing
3 changed files
with
42 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from .balanced import BalancedFewShotGenerator | ||
from .base import FewShotGenerator | ||
from .fixed import FixedFewShotGenerator | ||
from .rand import RandomFewShotGenerator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any | ||
|
||
from .base import ChatInstance, FewShotGenerator, GenerationInstance, Instance, MultipleChoiceInstance | ||
|
||
|
||
class FixedFewShotGenerator(FewShotGenerator): | ||
def __init__(self, instance_class: str, instance_params: list[dict[str, Any]]) -> None: | ||
super().__init__(num_trials_to_avoid_leak=0) | ||
|
||
if instance_class == "GenerationInstance": | ||
instance_init = GenerationInstance | ||
elif instance_class == "MultipleChoiceInstance": | ||
instance_init = MultipleChoiceInstance | ||
elif instance_class == "ChatInstance": | ||
instance_init = ChatInstance | ||
else: | ||
msg = f"Unknown instance class: {instance_class}" | ||
raise ValueError(msg) | ||
|
||
self.instances = [instance_init(**params) for params in instance_params] | ||
|
||
def _sample_instances(self, eval_inputs: list[dict[str, Any]] | dict[str, Any] | None = None) -> list[Instance]: | ||
return self.instances | ||
|
||
def __repr__(self) -> str: | ||
return f"{self.__class__.__name__}(instances={self.instances})" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from flexeval import ChatInstance | ||
from flexeval.core.few_shot_generator.fixed import FixedFewShotGenerator | ||
|
||
|
||
def test_fixed_fewshot_generator() -> None: | ||
instance = ChatInstance( | ||
messages=[{"role": "user", "content": "Hello!"}, {"role": "assistant", "content": "Hi there!"}] | ||
) | ||
generator = FixedFewShotGenerator( | ||
instance_class="ChatInstance", | ||
instance_params=[{"messages": instance.messages} for _ in range(5)], | ||
) | ||
assert generator() == [instance for _ in range(5)] |