diff --git a/flexeval/core/eval_setups.py b/flexeval/core/eval_setups.py index 9bddf91..6f44e92 100644 --- a/flexeval/core/eval_setups.py +++ b/flexeval/core/eval_setups.py @@ -15,7 +15,7 @@ from .metric import Metric from .metric.tokenizer import Tokenizer from .multiple_choice_dataset import MultipleChoiceDataset -from .prompt_template import PromptTemplate +from .prompt_template import PromptTemplate, instantiate_prompt_template_from_string from .text_dataset import TextDataset @@ -78,13 +78,17 @@ class Generation(EvalSetup): """ eval_dataset: GenerationDataset - prompt_template: PromptTemplate + prompt_template: PromptTemplate | str gen_kwargs: dict[str, Any] few_shot_generator: FewShotGenerator | None = None metrics: list[Metric] | Metric | None = None batch_size: int = 4 max_instances: int | None = None + def __post_init__(self) -> None: + if isinstance(self.prompt_template, str): + self.prompt_template = instantiate_prompt_template_from_string(self.prompt_template) + def evaluate_lm( self, language_model: LanguageModel, @@ -113,11 +117,15 @@ class MultipleChoice(EvalSetup): """ eval_dataset: MultipleChoiceDataset - prompt_template: PromptTemplate + prompt_template: PromptTemplate | str few_shot_generator: FewShotGenerator | None = None batch_size: int = 4 max_instances: int | None = None + def __post_init__(self) -> None: + if isinstance(self.prompt_template, str): + self.prompt_template = instantiate_prompt_template_from_string(self.prompt_template) + def evaluate_lm( self, language_model: LanguageModel, diff --git a/flexeval/core/prompt_template/__init__.py b/flexeval/core/prompt_template/__init__.py index b928877..e8d4b4a 100644 --- a/flexeval/core/prompt_template/__init__.py +++ b/flexeval/core/prompt_template/__init__.py @@ -1,2 +1,2 @@ from .base import PromptTemplate -from .jinja2 import Jinja2PromptTemplate +from .jinja2 import Jinja2PromptTemplate, instantiate_prompt_template_from_string diff --git a/flexeval/core/prompt_template/jinja2.py b/flexeval/core/prompt_template/jinja2.py index 756c92b..748db9b 100644 --- a/flexeval/core/prompt_template/jinja2.py +++ b/flexeval/core/prompt_template/jinja2.py @@ -1,5 +1,6 @@ from __future__ import annotations +import os from typing import Any from flexeval.core.utils.jinja2_utils import JINJA2_ENV @@ -7,6 +8,13 @@ from .base import PromptTemplate +def instantiate_prompt_template_from_string(template_or_path: str) -> Jinja2PromptTemplate: + # Use `os.path.isfile` instead of `Path.is_file()` to avoid "OSError: [Errno 36] File name too long" + if os.path.isfile(template_or_path): # noqa: PTH113 + return Jinja2PromptTemplate(template_path=template_or_path) + return Jinja2PromptTemplate(template=template_or_path) + + class Jinja2PromptTemplate(PromptTemplate): """ Embed task inputs using Jinja2 template engine. diff --git a/flexeval/preset_configs/EvalSetup/code_generation/jhumaneval.jsonnet b/flexeval/preset_configs/EvalSetup/code_generation/jhumaneval.jsonnet index 710f203..b9a2fa9 100644 --- a/flexeval/preset_configs/EvalSetup/code_generation/jhumaneval.jsonnet +++ b/flexeval/preset_configs/EvalSetup/code_generation/jhumaneval.jsonnet @@ -17,12 +17,7 @@ References: reference_template: '{{ test }}\n\ncheck({{ entry_point }})\n', }, }, - prompt_template: { - class_path: 'Jinja2PromptTemplate', - init_args: { - template: '{{ prompt }}', - }, - }, + prompt_template: '{{ prompt }}', metrics: [ { class_path: 'CodeEval', init_args: { code_template: '{{ prompt }}{{ lm_output }}' } }, ], diff --git a/flexeval/preset_configs/EvalSetup/code_generation/jhumaneval_tab_indent.jsonnet b/flexeval/preset_configs/EvalSetup/code_generation/jhumaneval_tab_indent.jsonnet index e55ca84..cb1d3ef 100644 --- a/flexeval/preset_configs/EvalSetup/code_generation/jhumaneval_tab_indent.jsonnet +++ b/flexeval/preset_configs/EvalSetup/code_generation/jhumaneval_tab_indent.jsonnet @@ -13,11 +13,7 @@ original_config { reference_template: '{{ test | replace(" ", "\t") }}\n\ncheck({{ entry_point }})\n', }, }, - prompt_template+: { - init_args+: { - template: "{{ prompt | replace(' ', '\t') }}", - }, - }, + prompt_template: "{{ prompt | replace(' ', '\t') }}", metrics: [ { class_path: 'CodeEval', init_args: { code_template: '{{ prompt | replace(" ", "\t") }}{{ lm_output }}' } }, ], diff --git a/flexeval/preset_configs/EvalSetup/code_generation/mbpp.jsonnet b/flexeval/preset_configs/EvalSetup/code_generation/mbpp.jsonnet index 8621ed2..01fac63 100644 --- a/flexeval/preset_configs/EvalSetup/code_generation/mbpp.jsonnet +++ b/flexeval/preset_configs/EvalSetup/code_generation/mbpp.jsonnet @@ -26,33 +26,28 @@ local dataset_base_args = { num_shots: 3, }, }, - prompt_template: { - class_path: 'Jinja2PromptTemplate', - init_args: { - template: ||| - {% for item in few_shot_data %} - ## Question - {{ item.prompt }} - ## Test cases - ```python - {{ item.test_list | join('\n') }} - ``` - ## Code - ```python - {{ item.code }} - ``` - {% endfor %} - ## Question - {{ prompt }} - ## Test cases - ```python - {{ test_list | join('\n') }} - ``` - ## Code - ```python - |||, - }, - }, + prompt_template: ||| + {% for item in few_shot_data %} + ## Question + {{ item.prompt }} + ## Test cases + ```python + {{ item.test_list | join('\n') }} + ``` + ## Code + ```python + {{ item.code }} + ``` + {% endfor %} + ## Question + {{ prompt }} + ## Test cases + ```python + {{ test_list | join('\n') }} + ``` + ## Code + ```python + |||, metrics: [ { class_path: 'CodeEval' }, ], diff --git a/flexeval/preset_configs/EvalSetup/code_generation/mbpp_tab_indent.jsonnet b/flexeval/preset_configs/EvalSetup/code_generation/mbpp_tab_indent.jsonnet index ff4bc36..9344e9d 100644 --- a/flexeval/preset_configs/EvalSetup/code_generation/mbpp_tab_indent.jsonnet +++ b/flexeval/preset_configs/EvalSetup/code_generation/mbpp_tab_indent.jsonnet @@ -8,31 +8,27 @@ local original_config = import './mbpp.jsonnet'; original_config { init_args+: { - prompt_template+: { - init_args+: { - template: ||| - {% for item in few_shot_data %} - ## Question - {{ item.prompt }} - ## Test cases - ```python - {{ item.test_list | join('\n') }} - ``` - ## Code - ```python - {{ item.code | replace(' ', '\t') }} - ``` - {% endfor %} - ## Question - {{ prompt }} - ## Test cases - ```python - {{ test_list | join('\n') }} - ``` - ## Code - ```python - |||, - }, - }, + prompt_template: ||| + {% for item in few_shot_data %} + ## Question + {{ item.prompt }} + ## Test cases + ```python + {{ item.test_list | join('\n') }} + ``` + ## Code + ```python + {{ item.code | replace(' ', '\t') }} + ``` + {% endfor %} + ## Question + {{ prompt }} + ## Test cases + ```python + {{ test_list | join('\n') }} + ``` + ## Code + ```python + |||, }, } diff --git a/flexeval/preset_configs/EvalSetup/code_generation/openai_humaneval.jsonnet b/flexeval/preset_configs/EvalSetup/code_generation/openai_humaneval.jsonnet index d3fecfc..93070c9 100644 --- a/flexeval/preset_configs/EvalSetup/code_generation/openai_humaneval.jsonnet +++ b/flexeval/preset_configs/EvalSetup/code_generation/openai_humaneval.jsonnet @@ -17,12 +17,7 @@ References: reference_template: '{{ test }}\n\ncheck({{ entry_point }})\n', }, }, - prompt_template: { - class_path: 'Jinja2PromptTemplate', - init_args: { - template: '{{ prompt }}', - }, - }, + prompt_template: '{{ prompt }}', metrics: [ { class_path: 'CodeEval', init_args: { code_template: '{{ prompt }}{{ lm_output }}' } }, ], diff --git a/flexeval/preset_configs/EvalSetup/code_generation/openai_humaneval_tab_indent.jsonnet b/flexeval/preset_configs/EvalSetup/code_generation/openai_humaneval_tab_indent.jsonnet index c6d4cb9..4cd08ba 100644 --- a/flexeval/preset_configs/EvalSetup/code_generation/openai_humaneval_tab_indent.jsonnet +++ b/flexeval/preset_configs/EvalSetup/code_generation/openai_humaneval_tab_indent.jsonnet @@ -13,11 +13,7 @@ original_config { reference_template: '{{ test | replace(" ", "\t") }}\n\ncheck({{ entry_point }})\n', }, }, - prompt_template+: { - init_args+: { - template: '{{ prompt | replace(" ", "\t") }}', - }, - }, + prompt_template: '{{ prompt | replace(" ", "\t") }}', metrics: [ { class_path: 'CodeEval', init_args: { code_template: '{{ prompt | replace(" ", "\t") }}{{ lm_output }}' } }, ], diff --git a/flexeval/preset_configs/EvalSetup/en_generation/babi.jsonnet b/flexeval/preset_configs/EvalSetup/en_generation/babi.jsonnet index fb959f9..138617b 100644 --- a/flexeval/preset_configs/EvalSetup/en_generation/babi.jsonnet +++ b/flexeval/preset_configs/EvalSetup/en_generation/babi.jsonnet @@ -25,20 +25,15 @@ local dataset_base_args = { num_shots: 3, }, }, - prompt_template: { - class_path: 'Jinja2PromptTemplate', - init_args: { - template: ||| - {% for item in few_shot_data %} - Passage: {{ item.passage | trim }} - Question: {{ item.question }} - Answer: "{{ item.references[0] }}" - {% endfor %} - Passage: {{ passage | trim }} - Question: {{ question }} - ||| + 'Answer: "', - }, - }, + prompt_template: ||| + {% for item in few_shot_data %} + Passage: {{ item.passage | trim }} + Question: {{ item.question }} + Answer: "{{ item.references[0] }}" + {% endfor %} + Passage: {{ passage | trim }} + Question: {{ question }} + ||| + 'Answer: "', metrics: [ { class_path: 'CharF1' }, { class_path: 'ExactMatch' }, diff --git a/flexeval/preset_configs/EvalSetup/en_generation/commonsense_qa.jsonnet b/flexeval/preset_configs/EvalSetup/en_generation/commonsense_qa.jsonnet index cb19d1e..41c5e95 100644 --- a/flexeval/preset_configs/EvalSetup/en_generation/commonsense_qa.jsonnet +++ b/flexeval/preset_configs/EvalSetup/en_generation/commonsense_qa.jsonnet @@ -26,31 +26,26 @@ local dataset_base_args = { num_shots: 2, }, }, - prompt_template: { - class_path: 'Jinja2PromptTemplate', - init_args: { - template: ||| - Choose the correct answer from the choices. - {% for item in few_shot_data %} - Choices: - 0. "{{ item.choices.text[0] }}" - 1. "{{ item.choices.text[1] }}" - 2. "{{ item.choices.text[2] }}" - 3. "{{ item.choices.text[3] }}" - 4. "{{ item.choices.text[4] }}" - Question: {{ item.question }} - Answer: "{{ item.references[0] }}" - {% endfor %} - Choices: - 0. "{{ choices.text[0] }}" - 1. "{{ choices.text[1] }}" - 2. "{{ choices.text[2] }}" - 3. "{{ choices.text[3] }}" - 4. "{{ choices.text[4] }}" - Question: {{question}} - ||| + 'Answer: "', - }, - }, + prompt_template: ||| + Choose the correct answer from the choices. + {% for item in few_shot_data %} + Choices: + 0. "{{ item.choices.text[0] }}" + 1. "{{ item.choices.text[1] }}" + 2. "{{ item.choices.text[2] }}" + 3. "{{ item.choices.text[3] }}" + 4. "{{ item.choices.text[4] }}" + Question: {{ item.question }} + Answer: "{{ item.references[0] }}" + {% endfor %} + Choices: + 0. "{{ choices.text[0] }}" + 1. "{{ choices.text[1] }}" + 2. "{{ choices.text[2] }}" + 3. "{{ choices.text[3] }}" + 4. "{{ choices.text[4] }}" + Question: {{question}} + ||| + 'Answer: "', metrics: [ { class_path: 'ExactMatch' }, ], diff --git a/flexeval/preset_configs/EvalSetup/en_generation/gsm8k.jsonnet b/flexeval/preset_configs/EvalSetup/en_generation/gsm8k.jsonnet index 9c1e4f9..e7d71d1 100644 --- a/flexeval/preset_configs/EvalSetup/en_generation/gsm8k.jsonnet +++ b/flexeval/preset_configs/EvalSetup/en_generation/gsm8k.jsonnet @@ -27,18 +27,13 @@ local dataset_base_args = { num_shots: 4, }, }, - prompt_template: { - class_path: 'Jinja2PromptTemplate', - init_args: { - template: ||| - {% for item in few_shot_data %} - Q: {{ item.question }} - A: {{ item.references[0] }} - {% endfor %} - Q: {{ question }} - ||| + 'A:', - }, - }, + prompt_template: ||| + {% for item in few_shot_data %} + Q: {{ item.question }} + A: {{ item.references[0] }} + {% endfor %} + Q: {{ question }} + ||| + 'A:', metrics: [ { class_path: 'ExactMatch', diff --git a/flexeval/preset_configs/EvalSetup/en_generation/squad_v1.jsonnet b/flexeval/preset_configs/EvalSetup/en_generation/squad_v1.jsonnet index 123dfa3..9b59a1c 100644 --- a/flexeval/preset_configs/EvalSetup/en_generation/squad_v1.jsonnet +++ b/flexeval/preset_configs/EvalSetup/en_generation/squad_v1.jsonnet @@ -26,21 +26,15 @@ local dataset_base_args = { num_shots: 2, }, }, - prompt_template: { - class_path: 'Jinja2PromptTemplate', - init_args: { - - template: ||| - {% for item in few_shot_data %} - Context: {{ item.context | trim }} - Question: {{ item.question }} - Answer: "{{ item.references[0] }}" - {% endfor %} - Context: {{ context | trim }} - Question: {{ question }} - ||| + 'Answer: "', - }, - }, + prompt_template: ||| + {% for item in few_shot_data %} + Context: {{ item.context | trim }} + Question: {{ item.question }} + Answer: "{{ item.references[0] }}" + {% endfor %} + Context: {{ context | trim }} + Question: {{ question }} + ||| + 'Answer: "', metrics: [ { class_path: 'CharF1' }, { class_path: 'ExactMatch' }, diff --git a/flexeval/preset_configs/EvalSetup/en_generation/trivia_qa.jsonnet b/flexeval/preset_configs/EvalSetup/en_generation/trivia_qa.jsonnet index 291e351..921b970 100644 --- a/flexeval/preset_configs/EvalSetup/en_generation/trivia_qa.jsonnet +++ b/flexeval/preset_configs/EvalSetup/en_generation/trivia_qa.jsonnet @@ -28,19 +28,13 @@ local dataset_base_args = { num_shots: 0, }, }, - prompt_template: { - class_path: 'Jinja2PromptTemplate', - init_args: { - - template: ||| - {% for item in few_shot_data %} - Question: {{ item.question }} - Answer: "{{ item.references[0] }}" - {% endfor %} - Question: {{ question }} - ||| + 'Answer: "', - }, - }, + prompt_template: ||| + {% for item in few_shot_data %} + Question: {{ item.question }} + Answer: "{{ item.references[0] }}" + {% endfor %} + Question: {{ question }} + ||| + 'Answer: "', metrics: [ { class_path: 'CharF1' }, { class_path: 'ExactMatch' }, diff --git a/flexeval/preset_configs/EvalSetup/en_generation/twitter_sentiment.jsonnet b/flexeval/preset_configs/EvalSetup/en_generation/twitter_sentiment.jsonnet index 8d0bbf1..ea138f6 100644 --- a/flexeval/preset_configs/EvalSetup/en_generation/twitter_sentiment.jsonnet +++ b/flexeval/preset_configs/EvalSetup/en_generation/twitter_sentiment.jsonnet @@ -27,19 +27,14 @@ local dataset_base_args = { num_shots: 4, }, }, - prompt_template: { - class_path: 'Jinja2PromptTemplate', - init_args: { - template: ||| - Classify the sentiment of the following tweet. - {% for item in few_shot_data %} - Tweet: {{ item.text }} - Sentiment: `{{ item.references[0] }}` - {% endfor %} - Tweet: {{ text }} - ||| + 'Sentiment: `', - }, - }, + prompt_template: ||| + Classify the sentiment of the following tweet. + {% for item in few_shot_data %} + Tweet: {{ item.text }} + Sentiment: `{{ item.references[0] }}` + {% endfor %} + Tweet: {{ text }} + ||| + 'Sentiment: `', metrics: [ { class_path: 'ExactMatch' }, ], diff --git a/flexeval/preset_configs/EvalSetup/en_multiple_choice/arc_challenge.jsonnet b/flexeval/preset_configs/EvalSetup/en_multiple_choice/arc_challenge.jsonnet index e83816c..dd81893 100644 --- a/flexeval/preset_configs/EvalSetup/en_multiple_choice/arc_challenge.jsonnet +++ b/flexeval/preset_configs/EvalSetup/en_multiple_choice/arc_challenge.jsonnet @@ -18,7 +18,7 @@ local dataset_base_args = { '{% if choices.text | length > 3 %}{{ choices.text[3] }}{% endif %}', '{% if choices.text | length > 4 %}{{ choices.text[4] }}{% endif %}', ], - # answerKey is one of A, B, C, D, E, 1, 2, 3, 4 + // answerKey is one of A, B, C, D, E, 1, 2, 3, 4 answer_index_template: '{% if answerKey == "A" %}0{% elif answerKey == "B" %}1{% elif answerKey == "C" %}2{% elif answerKey == "D" %}3{% elif answerKey == "E" %}3{% else %}{{ answerKey | int - 1 }}{% endif %}', whitespace_before_choices: true, }; @@ -40,18 +40,12 @@ local dataset_base_args = { num_shots: 4, }, }, - prompt_template: { - class_path: 'Jinja2PromptTemplate', - init_args: { - - template: ||| - {% for item in few_shot_data %} - Question: {{ item.question }} - Answer:{{ item.choices[item.answer_index] }} - {% endfor %} - Question: {{ question }} - ||| + 'Answer:', - }, - }, + prompt_template: ||| + {% for item in few_shot_data %} + Question: {{ item.question }} + Answer:{{ item.choices[item.answer_index] }} + {% endfor %} + Question: {{ question }} + ||| + 'Answer:', }, } diff --git a/flexeval/preset_configs/EvalSetup/en_multiple_choice/arc_easy.jsonnet b/flexeval/preset_configs/EvalSetup/en_multiple_choice/arc_easy.jsonnet index 9879eef..3cbbea2 100644 --- a/flexeval/preset_configs/EvalSetup/en_multiple_choice/arc_easy.jsonnet +++ b/flexeval/preset_configs/EvalSetup/en_multiple_choice/arc_easy.jsonnet @@ -18,7 +18,7 @@ local dataset_base_args = { '{% if choices.text | length > 3 %}{{ choices.text[3] }}{% endif %}', '{% if choices.text | length > 4 %}{{ choices.text[4] }}{% endif %}', ], - # answerKey is one of A, B, C, D, E, 1, 2, 3, 4 + // answerKey is one of A, B, C, D, E, 1, 2, 3, 4 answer_index_template: '{% if answerKey == "A" %}0{% elif answerKey == "B" %}1{% elif answerKey == "C" %}2{% elif answerKey == "D" %}3{% elif answerKey == "E" %}3{% else %}{{ answerKey | int - 1 }}{% endif %}', whitespace_before_choices: true, }; @@ -40,18 +40,12 @@ local dataset_base_args = { num_shots: 4, }, }, - prompt_template: { - class_path: 'Jinja2PromptTemplate', - init_args: { - - template: ||| - {% for item in few_shot_data %} - Question: {{ item.question }} - Answer:{{ item.choices[item.answer_index] }} - {% endfor %} - Question: {{ question }} - ||| + 'Answer:', - }, - }, + prompt_template: ||| + {% for item in few_shot_data %} + Question: {{ item.question }} + Answer:{{ item.choices[item.answer_index] }} + {% endfor %} + Question: {{ question }} + ||| + 'Answer:', }, } diff --git a/flexeval/preset_configs/EvalSetup/en_multiple_choice/commonsense_qa_mc.jsonnet b/flexeval/preset_configs/EvalSetup/en_multiple_choice/commonsense_qa_mc.jsonnet index 1230ca9..c236209 100644 --- a/flexeval/preset_configs/EvalSetup/en_multiple_choice/commonsense_qa_mc.jsonnet +++ b/flexeval/preset_configs/EvalSetup/en_multiple_choice/commonsense_qa_mc.jsonnet @@ -31,18 +31,12 @@ local dataset_base_args = { num_shots: 4, }, }, - prompt_template: { - class_path: 'Jinja2PromptTemplate', - init_args: { - - template: ||| - {% for item in few_shot_data %} - Question: {{ item.question }} - Answer:{{ item.choices[item.answer_index] }} - {% endfor %} - Question: {{ question }} - ||| + 'Answer:', - }, - }, + prompt_template: ||| + {% for item in few_shot_data %} + Question: {{ item.question }} + Answer:{{ item.choices[item.answer_index] }} + {% endfor %} + Question: {{ question }} + ||| + 'Answer:', }, } diff --git a/flexeval/preset_configs/EvalSetup/en_multiple_choice/hellaswag.jsonnet b/flexeval/preset_configs/EvalSetup/en_multiple_choice/hellaswag.jsonnet index 5be2391..b0bfa3b 100644 --- a/flexeval/preset_configs/EvalSetup/en_multiple_choice/hellaswag.jsonnet +++ b/flexeval/preset_configs/EvalSetup/en_multiple_choice/hellaswag.jsonnet @@ -32,15 +32,10 @@ local dataset_base_args = { num_shots: 4, }, }, - prompt_template: { - class_path: 'Jinja2PromptTemplate', - init_args: { - template: ||| - {% for item in few_shot_data %} - {{ item.ctx }}{{ item.choices[item.answer_index] }} - {% endfor %} - ||| + '{{ ctx }}', - }, - }, + prompt_template: ||| + {% for item in few_shot_data %} + {{ item.ctx }}{{ item.choices[item.answer_index] }} + {% endfor %} + ||| + '{{ ctx }}', }, } diff --git a/flexeval/preset_configs/EvalSetup/en_multiple_choice/openbookqa.jsonnet b/flexeval/preset_configs/EvalSetup/en_multiple_choice/openbookqa.jsonnet index 6f03955..a0c8d77 100644 --- a/flexeval/preset_configs/EvalSetup/en_multiple_choice/openbookqa.jsonnet +++ b/flexeval/preset_configs/EvalSetup/en_multiple_choice/openbookqa.jsonnet @@ -31,18 +31,12 @@ local dataset_base_args = { num_shots: 4, }, }, - prompt_template: { - class_path: 'Jinja2PromptTemplate', - init_args: { - - template: ||| - {% for item in few_shot_data %} - Question: {{ item.question_stem }} - Answer:{{ item.choices[item.answer_index] }} - {% endfor %} - Question: {{ question_stem }} - ||| + 'Answer:', - }, - }, + prompt_template: ||| + {% for item in few_shot_data %} + Question: {{ item.question_stem }} + Answer:{{ item.choices[item.answer_index] }} + {% endfor %} + Question: {{ question_stem }} + ||| + 'Answer:', }, } diff --git a/flexeval/preset_configs/EvalSetup/en_multiple_choice/piqa.jsonnet b/flexeval/preset_configs/EvalSetup/en_multiple_choice/piqa.jsonnet index fd0297f..701c292 100644 --- a/flexeval/preset_configs/EvalSetup/en_multiple_choice/piqa.jsonnet +++ b/flexeval/preset_configs/EvalSetup/en_multiple_choice/piqa.jsonnet @@ -32,15 +32,10 @@ local dataset_base_args = { num_shots: 4, }, }, - prompt_template: { - class_path: 'Jinja2PromptTemplate', - init_args: { - template: ||| - {% for item in few_shot_data %} - {{ item.goal }}{{ item.choices[item.answer_index] }} - {% endfor %} - ||| + '{{ goal }}', - }, - }, + prompt_template: ||| + {% for item in few_shot_data %} + {{ item.goal }}{{ item.choices[item.answer_index] }} + {% endfor %} + ||| + '{{ goal }}', }, } diff --git a/flexeval/preset_configs/EvalSetup/en_multiple_choice/xwinograd_en.jsonnet b/flexeval/preset_configs/EvalSetup/en_multiple_choice/xwinograd_en.jsonnet index 593c131..1f8afdf 100644 --- a/flexeval/preset_configs/EvalSetup/en_multiple_choice/xwinograd_en.jsonnet +++ b/flexeval/preset_configs/EvalSetup/en_multiple_choice/xwinograd_en.jsonnet @@ -24,11 +24,6 @@ References: input_templates: { context: '{{ sentence.split("_")[0] }}' }, }, }, - prompt_template: { - class_path: 'Jinja2PromptTemplate', - init_args: { - template: '{{ context }}', - }, - }, + prompt_template: '{{ context }}', }, } diff --git a/flexeval/preset_configs/EvalSetup/ja_generation/aio.jsonnet b/flexeval/preset_configs/EvalSetup/ja_generation/aio.jsonnet index 3f78508..b5c5536 100644 --- a/flexeval/preset_configs/EvalSetup/ja_generation/aio.jsonnet +++ b/flexeval/preset_configs/EvalSetup/ja_generation/aio.jsonnet @@ -17,16 +17,11 @@ local dataset_base_args = { }, }; -local template_ = '{{ question }}答えは「'; - { class_path: 'Generation', init_args: { eval_dataset: dataset_base_args, - prompt_template: { - class_path: 'Jinja2PromptTemplate', - init_args: { template: template_, }, - }, + prompt_template: '{{ question }}答えは「', metrics: [ { class_path: 'CharF1', @@ -43,7 +38,7 @@ local template_ = '{{ question }}答えは「'; }, }, ], - gen_kwargs: { max_new_tokens: 64, stop_sequences: ['」'], }, + gen_kwargs: { max_new_tokens: 64, stop_sequences: ['」'] }, batch_size: 1, }, } diff --git a/flexeval/preset_configs/EvalSetup/ja_generation/jcommonsenseqa.jsonnet b/flexeval/preset_configs/EvalSetup/ja_generation/jcommonsenseqa.jsonnet index 47a4c4b..6e05d9e 100644 --- a/flexeval/preset_configs/EvalSetup/ja_generation/jcommonsenseqa.jsonnet +++ b/flexeval/preset_configs/EvalSetup/ja_generation/jcommonsenseqa.jsonnet @@ -24,31 +24,31 @@ local template_ = ||| 以下はタスクを説明する指示と、追加の背景情報を提供する入力の組み合わせです。要求を適切に満たす回答を書いてください。 ### 指示 質問と回答の選択肢を入力として受け取り、選択肢から回答を選択してください。回答の他には何も含めないことを厳守してください。 - + ### 入力: 質問:主に子ども向けのもので、イラストのついた物語が書かれているものはどれ? 選択肢:世界,写真集,絵本,論文,図鑑 ### 回答: 絵本 - + ### 入力: 質問:未成年者を監護・教育し,彼らを監督し,彼らの財産上の利益を守る法律上の義務をもつ人は? 選択肢:浮浪者,保護者,お坊さん,宗教者,預言者 ### 回答: 保護者 - + ### 入力: 質問:数字の1を表すときに使う体は? 選択肢:胸,肉球,背中,人差し指,親指 ### 回答: 人差し指 - + ### 入力: 質問:火を起こすとあらわれるもくもくするものは? 選択肢:歯の変色,ガス,中毒,爆発,煙 ### 回答: 煙 - + ### 入力: 質問:{{ question }} 選択肢:{{ choice0 }},{{ choice1 }},{{ choice2 }},{{ choice3 }},{{ choice4 }} @@ -59,14 +59,11 @@ local template_ = ||| class_path: 'Generation', init_args: { eval_dataset: dataset_base_args, - prompt_template: { - class_path: 'Jinja2PromptTemplate', - init_args: { template: template_, }, - }, + prompt_template: template_, metrics: [ { class_path: 'ExactMatch' }, ], - gen_kwargs: { max_new_tokens: 64, stop_sequences: ['\n\n'], }, + gen_kwargs: { max_new_tokens: 64, stop_sequences: ['\n\n'] }, batch_size: 1, }, } diff --git a/flexeval/preset_configs/EvalSetup/ja_generation/jnli.jsonnet b/flexeval/preset_configs/EvalSetup/ja_generation/jnli.jsonnet index 6f953f6..6f599e6 100644 --- a/flexeval/preset_configs/EvalSetup/ja_generation/jnli.jsonnet +++ b/flexeval/preset_configs/EvalSetup/ja_generation/jnli.jsonnet @@ -30,21 +30,16 @@ local dataset_base_args = { num_shots: 3, }, }, - prompt_template: { - class_path: 'Jinja2PromptTemplate', - init_args: { - template: ||| - 前提と仮説の関係を「中立」、「含意」、「矛盾」の中から回答してください。 - {% for item in few_shot_data %} - 前提:「{{ item.sentence1 }}」 - 仮説:「{{ item.sentence2 }}」 - 関係:「{{ item.references[0] }}」 - {% endfor %} - 前提:「{{ sentence1 }}」 - 仮説:「{{ sentence2 }}」 - ||| + '関係:「', - }, - }, + prompt_template: ||| + 前提と仮説の関係を「中立」、「含意」、「矛盾」の中から回答してください。 + {% for item in few_shot_data %} + 前提:「{{ item.sentence1 }}」 + 仮説:「{{ item.sentence2 }}」 + 関係:「{{ item.references[0] }}」 + {% endfor %} + 前提:「{{ sentence1 }}」 + 仮説:「{{ sentence2 }}」 + ||| + '関係:「', metrics: [ { class_path: 'ExactMatch' }, ], diff --git a/flexeval/preset_configs/EvalSetup/ja_generation/jsquad.jsonnet b/flexeval/preset_configs/EvalSetup/ja_generation/jsquad.jsonnet index 17a11ba..a7ab1fe 100644 --- a/flexeval/preset_configs/EvalSetup/ja_generation/jsquad.jsonnet +++ b/flexeval/preset_configs/EvalSetup/ja_generation/jsquad.jsonnet @@ -23,31 +23,31 @@ local template_ = ||| 以下はタスクを説明する指示と、追加の背景情報を提供する入力の組み合わせです。要求を適切に満たす回答を書いてください。 ### 指示 質問に対する回答を文章から一言で抽出してください。回答は名詞で答えてください。 それ以外には何も含めないことを厳守してください。 - + ### 入力: 文章:聖武天皇 [SEP] 文武天皇の第一皇子として生まれたが、慶雲4年6月15日(707年7月18日)に7歳で父と死別、母・宮子も心的障害に陥ったため、その後は長らく会うことはなかった。物心がついて以後の天皇が病気の平癒した母との対面を果たしたのは齢37のときであった。このため、同年7月17日(707年8月18日)、父方の祖母・元明天皇(天智天皇皇女)が中継ぎの天皇として即位した。和銅7年6月25日(714年8月9日)には首皇子の元服が行われて同日正式に立太子されるも、病弱であったこと、皇親勢力と外戚である藤原氏との対立もあり、即位は先延ばしにされ、翌霊亀元年9月2日(715年10月3日)に伯母(文武天皇の姉)・元正天皇が「中継ぎの中継ぎ」として皇位を継ぐことになった。24歳のときに元正天皇より皇位を譲られて即位することになる。 質問:文武天皇の第一皇子として生まれたのは? ### 回答: 聖武天皇 - + ### 入力: 文章:通称 [SEP] 人名としての通称は通り名、二つ名、異名、などと呼ばれる事もある。近世までは、本名(実名)は「」と呼ばれ、公言は避ける習慣があった。そのため、人を呼ぶ時は「仮名」「字」などの通称、官職名を用いるのが一般的だった。今日でも「総理」「大臣」「社長」「専務」などと呼びかけに使うのがこれにあたる。 質問:人名としての通称は何と呼ばれているか ### 回答: 通り名、二つ名、異名 - + ### 入力: 文章:坂本龍一 [SEP] 2014年7月10日、所属事務所エイベックス・ミュージック・クリエイティヴから中咽頭癌であること、療養に専念するためにコンサート活動などを中止する旨が発表された。かつてはインタビューなどで度々自身の健康状態や体力に自信を表しており、コンサート等公演スケジュールを自身の健康に起因する理由でキャンセルしたことがなかった。 質問:坂本龍一が療養に専念するためコンサート活動などを中止すると発表したのはいつか。 ### 回答: 2014年7月10日 - + ### 入力: 文章:リリーフ [SEP] プレッシャーの比較的かからない状態で投げることができるので、若手投手のテストの場としたり、故障明けや登板間隔の開いた投手を調整目的で登板させることもある。敗戦処理であっても好投すれば次回から先発や接戦での中継ぎに起用されるようになる場合もあり、幸い打線の援護を受けてチームが逆転すれば勝利投手に輝くこともある。 質問:打線の援護を受けてチームが逆転するとどんな投手になる? ### 回答: 勝利投手 - + ### 入力: 文章:{{ context }} 質問:{{ question }} @@ -58,15 +58,12 @@ local template_ = ||| class_path: 'Generation', init_args: { eval_dataset: dataset_base_args, - prompt_template: { - class_path: 'Jinja2PromptTemplate', - init_args: { template: template_, }, - }, + prompt_template: template_, metrics: [ { class_path: 'CharF1' }, { class_path: 'ExactMatch' }, ], - gen_kwargs: { max_new_tokens: 64, stop_sequences: ['\n\n'], }, + gen_kwargs: { max_new_tokens: 64, stop_sequences: ['\n\n'] }, batch_size: 1, }, } diff --git a/flexeval/preset_configs/EvalSetup/ja_generation/mgsm_ja.jsonnet b/flexeval/preset_configs/EvalSetup/ja_generation/mgsm_ja.jsonnet index c5477aa..c2936c7 100644 --- a/flexeval/preset_configs/EvalSetup/ja_generation/mgsm_ja.jsonnet +++ b/flexeval/preset_configs/EvalSetup/ja_generation/mgsm_ja.jsonnet @@ -27,18 +27,13 @@ local dataset_base_args = { num_shots: 4, }, }, - prompt_template: { - class_path: 'Jinja2PromptTemplate', - init_args: { - template: ||| - {% for item in few_shot_data %} - {{ item.question }} - {{ item.answer }} - {% endfor %} - 問題: {{ question }} - ||| + 'ステップごとの答え:', - }, - }, + prompt_template: ||| + {% for item in few_shot_data %} + {{ item.question }} + {{ item.answer }} + {% endfor %} + 問題: {{ question }} + ||| + 'ステップごとの答え:', metrics: [ { class_path: 'ExactMatch', init_args: { lm_output_processor: { class_path: 'RegexExtractor', init_args: { pattern: '-?[0-9.,]+' } } } }, ], diff --git a/flexeval/preset_configs/EvalSetup/ja_generation/wrime_pos_neg.jsonnet b/flexeval/preset_configs/EvalSetup/ja_generation/wrime_pos_neg.jsonnet index 3deba3f..e3ac9cd 100644 --- a/flexeval/preset_configs/EvalSetup/ja_generation/wrime_pos_neg.jsonnet +++ b/flexeval/preset_configs/EvalSetup/ja_generation/wrime_pos_neg.jsonnet @@ -29,19 +29,14 @@ local dataset_base_args = { num_shots: 4, }, }, - prompt_template: { - class_path: 'Jinja2PromptTemplate', - init_args: { - template: ||| - 文の極性について「ポジティブ」か「ネガティブ」かで答えてください。 - {% for item in few_shot_data %} - 文:{{ item.sentence }} - 極性:「{{ item.references[0] }}」 - {% endfor %} - 文:{{sentence}} - ||| + '極性:「', - }, - }, + prompt_template: ||| + 文の極性について「ポジティブ」か「ネガティブ」かで答えてください。 + {% for item in few_shot_data %} + 文:{{ item.sentence }} + 極性:「{{ item.references[0] }}」 + {% endfor %} + 文:{{sentence}} + ||| + '極性:「', metrics: [ { class_path: 'ExactMatch' }, ], diff --git a/flexeval/preset_configs/EvalSetup/ja_generation/xlsum_ja.jsonnet b/flexeval/preset_configs/EvalSetup/ja_generation/xlsum_ja.jsonnet index 0fcad68..f1ad771 100644 --- a/flexeval/preset_configs/EvalSetup/ja_generation/xlsum_ja.jsonnet +++ b/flexeval/preset_configs/EvalSetup/ja_generation/xlsum_ja.jsonnet @@ -29,19 +29,14 @@ local dataset_base_args = { num_shots: 1, }, }, - prompt_template: { - class_path: 'Jinja2PromptTemplate', - init_args: { - template: ||| - 文章を1〜3文で要約してください。 - {% for item in few_shot_data %} - 文章: {{ item.text }} - 要約: {{ item.references[0] }} - {% endfor %} - 文章: {{ text }} - ||| + '要約:', - }, - }, + prompt_template: ||| + 文章を1〜3文で要約してください。 + {% for item in few_shot_data %} + 文章: {{ item.text }} + 要約: {{ item.references[0] }} + {% endfor %} + 文章: {{ text }} + ||| + '要約:', metrics: [ { class_path: 'ROUGE', diff --git a/flexeval/preset_configs/EvalSetup/ja_multiple_choice/jcommonsenseqa_mc.jsonnet b/flexeval/preset_configs/EvalSetup/ja_multiple_choice/jcommonsenseqa_mc.jsonnet index 7d97fda..79bd148 100644 --- a/flexeval/preset_configs/EvalSetup/ja_multiple_choice/jcommonsenseqa_mc.jsonnet +++ b/flexeval/preset_configs/EvalSetup/ja_multiple_choice/jcommonsenseqa_mc.jsonnet @@ -34,17 +34,12 @@ local dataset_base_args = { num_shots: 0, }, }, - prompt_template: { - class_path: 'Jinja2PromptTemplate', - init_args: { - template: ||| - {% for item in few_shot_data %} - 問題:{{ item.question }} - 回答:「{{ item.choices[item.answer_index] }}」 - {% endfor %} - 問題:{{question}} - ||| + '回答:「', - }, - }, + prompt_template: ||| + {% for item in few_shot_data %} + 問題:{{ item.question }} + 回答:「{{ item.choices[item.answer_index] }}」 + {% endfor %} + 問題:{{question}} + ||| + '回答:「', }, } diff --git a/flexeval/preset_configs/EvalSetup/ja_multiple_choice/xwinograd_ja.jsonnet b/flexeval/preset_configs/EvalSetup/ja_multiple_choice/xwinograd_ja.jsonnet index 1c9d238..b32b1e1 100644 --- a/flexeval/preset_configs/EvalSetup/ja_multiple_choice/xwinograd_ja.jsonnet +++ b/flexeval/preset_configs/EvalSetup/ja_multiple_choice/xwinograd_ja.jsonnet @@ -24,11 +24,6 @@ References: input_templates: { context: '{{ sentence.split("_")[0] }}' }, }, }, - prompt_template: { - class_path: 'Jinja2PromptTemplate', - init_args: { - template: '{{ context }}', - }, - }, + prompt_template: '{{ context }}', }, } diff --git a/flexeval/preset_configs/EvalSetup/translation/wmt20_en_ja.jsonnet b/flexeval/preset_configs/EvalSetup/translation/wmt20_en_ja.jsonnet index 67c0079..13b90a8 100644 --- a/flexeval/preset_configs/EvalSetup/translation/wmt20_en_ja.jsonnet +++ b/flexeval/preset_configs/EvalSetup/translation/wmt20_en_ja.jsonnet @@ -25,18 +25,13 @@ local dataset = { num_shots: 4, }, }, - prompt_template: { - class_path: 'Jinja2PromptTemplate', - init_args: { - template: ||| - {% for item in few_shot_data %} - En: `{{ item.source }}` - Ja: `{{ item.references[0] }}` - {% endfor %} - En: `{{ source }}` - ||| + 'Ja: `', - }, - }, + prompt_template: ||| + {% for item in few_shot_data %} + En: `{{ item.source }}` + Ja: `{{ item.references[0] }}` + {% endfor %} + En: `{{ source }}` + ||| + 'Ja: `', metrics: [ { class_path: 'BLEU', init_args: { tokenize_option: 'ja-mecab' } }, ], diff --git a/flexeval/preset_configs/EvalSetup/translation/wmt20_ja_en.jsonnet b/flexeval/preset_configs/EvalSetup/translation/wmt20_ja_en.jsonnet index b4663c7..a03421a 100644 --- a/flexeval/preset_configs/EvalSetup/translation/wmt20_ja_en.jsonnet +++ b/flexeval/preset_configs/EvalSetup/translation/wmt20_ja_en.jsonnet @@ -25,18 +25,13 @@ local dataset = { num_shots: 4, }, }, - prompt_template: { - class_path: 'Jinja2PromptTemplate', - init_args: { - template: ||| - {% for item in few_shot_data %} - Ja: `{{ item.source }}` - En: `{{ item.references[0] }}` - {% endfor %} - Ja: `{{ source }}` - ||| + 'En: `', - }, - }, + prompt_template: ||| + {% for item in few_shot_data %} + Ja: `{{ item.source }}` + En: `{{ item.references[0] }}` + {% endfor %} + Ja: `{{ source }}` + ||| + 'En: `', metrics: [ { class_path: 'BLEU', init_args: { tokenize_option: 'intl' } }, ], diff --git a/tests/core/prompt_template/test_jinja2.py b/tests/core/prompt_template/test_jinja2.py index e5fd33e..a8f290b 100644 --- a/tests/core/prompt_template/test_jinja2.py +++ b/tests/core/prompt_template/test_jinja2.py @@ -2,7 +2,7 @@ import pytest -from flexeval.core.prompt_template.jinja2 import Jinja2PromptTemplate +from flexeval.core.prompt_template.jinja2 import Jinja2PromptTemplate, instantiate_prompt_template_from_string from flexeval.utils import instantiate_from_config from tests.dummy_modules.generation_dataset import DummyGenerationDataset @@ -47,3 +47,23 @@ def test_if_jinja2_template_keep_trailing_newline() -> None: # note that the text block (||| ... |||) in the jsonnet file adds a newline at the end assert prompt_template.embed_inputs({}) == "Hello World!\n" + + +def test_instantiate_prompt_template_from_string_with_template_path() -> None: + # Create a temporary Jinja2 template file + template_content = "Hello, {{ name }}!" + with tempfile.NamedTemporaryFile(mode="w", suffix=".jinja2") as temp_file: + temp_file.write(template_content) + temp_file.flush() + temp_file_path = temp_file.name + + prompt_template = instantiate_prompt_template_from_string(temp_file_path) + assert isinstance(prompt_template, Jinja2PromptTemplate) + assert prompt_template.template == template_content + + +def test_instantiate_prompt_template_from_string_with_template_string() -> None: + template_string = "Hello, {{ name }}!" + prompt_template = instantiate_prompt_template_from_string(template_string) + assert isinstance(prompt_template, Jinja2PromptTemplate) + assert prompt_template.template == template_string