Skip to content

Commit

Permalink
Merge pull request #121 from sbintuitions/direct_template
Browse files Browse the repository at this point in the history
Simplify `prompt_template` in configs
  • Loading branch information
ryokan0123 authored Jan 15, 2025
2 parents c81dd1e + 27799f5 commit 8ead327
Show file tree
Hide file tree
Showing 34 changed files with 261 additions and 374 deletions.
14 changes: 11 additions & 3 deletions flexeval/core/eval_setups.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .metric import Metric
from .metric.tokenizer import Tokenizer
from .multiple_choice_dataset import MultipleChoiceDataset
from .prompt_template import PromptTemplate
from .prompt_template import PromptTemplate, instantiate_prompt_template_from_string
from .text_dataset import TextDataset


Expand Down Expand Up @@ -78,13 +78,17 @@ class Generation(EvalSetup):
"""

eval_dataset: GenerationDataset
prompt_template: PromptTemplate
prompt_template: PromptTemplate | str
gen_kwargs: dict[str, Any]
few_shot_generator: FewShotGenerator | None = None
metrics: list[Metric] | Metric | None = None
batch_size: int = 4
max_instances: int | None = None

def __post_init__(self) -> None:
    """Normalize ``prompt_template`` when it was supplied as a plain string.

    A ``str`` value is replaced by the object returned from
    ``instantiate_prompt_template_from_string``; any other value is left
    untouched.
    """
    raw_template = self.prompt_template
    if not isinstance(raw_template, str):
        return
    self.prompt_template = instantiate_prompt_template_from_string(raw_template)

def evaluate_lm(
self,
language_model: LanguageModel,
Expand Down Expand Up @@ -113,11 +117,15 @@ class MultipleChoice(EvalSetup):
"""

eval_dataset: MultipleChoiceDataset
prompt_template: PromptTemplate
prompt_template: PromptTemplate | str
few_shot_generator: FewShotGenerator | None = None
batch_size: int = 4
max_instances: int | None = None

def __post_init__(self) -> None:
    """Convert a string ``prompt_template`` into a prompt template object.

    Only ``str`` values are converted (via
    ``instantiate_prompt_template_from_string``); a value of any other type
    is kept as it is.
    """
    given = self.prompt_template
    self.prompt_template = (
        instantiate_prompt_template_from_string(given) if isinstance(given, str) else given
    )

def evaluate_lm(
self,
language_model: LanguageModel,
Expand Down
2 changes: 1 addition & 1 deletion flexeval/core/prompt_template/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .base import PromptTemplate
from .jinja2 import Jinja2PromptTemplate
from .jinja2 import Jinja2PromptTemplate, instantiate_prompt_template_from_string
8 changes: 8 additions & 0 deletions flexeval/core/prompt_template/jinja2.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
from __future__ import annotations

import os
from typing import Any

from flexeval.core.utils.jinja2_utils import JINJA2_ENV

from .base import PromptTemplate


def instantiate_prompt_template_from_string(template_or_path: str) -> Jinja2PromptTemplate:
    """Build a ``Jinja2PromptTemplate`` from a string holding a file path or template text.

    Args:
        template_or_path: Either the path of an existing file or the template
            source itself.

    Returns:
        A ``Jinja2PromptTemplate`` created with ``template_path=`` when the
        string names an existing file, and with ``template=`` otherwise.
    """
    # Use `os.path.isfile` instead of `Path.is_file()` to avoid "OSError: [Errno 36] File name too long"
    points_to_file = os.path.isfile(template_or_path)  # noqa: PTH113
    if not points_to_file:
        return Jinja2PromptTemplate(template=template_or_path)
    return Jinja2PromptTemplate(template_path=template_or_path)


class Jinja2PromptTemplate(PromptTemplate):
"""
Embed task inputs using Jinja2 template engine.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,7 @@ References:
reference_template: '{{ test }}\n\ncheck({{ entry_point }})\n',
},
},
prompt_template: {
class_path: 'Jinja2PromptTemplate',
init_args: {
template: '{{ prompt }}',
},
},
prompt_template: '{{ prompt }}',
metrics: [
{ class_path: 'CodeEval', init_args: { code_template: '{{ prompt }}{{ lm_output }}' } },
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,7 @@ original_config {
reference_template: '{{ test | replace(" ", "\t") }}\n\ncheck({{ entry_point }})\n',
},
},
prompt_template+: {
init_args+: {
template: "{{ prompt | replace(' ', '\t') }}",
},
},
prompt_template: "{{ prompt | replace(' ', '\t') }}",
metrics: [
{ class_path: 'CodeEval', init_args: { code_template: '{{ prompt | replace(" ", "\t") }}{{ lm_output }}' } },
],
Expand Down
49 changes: 22 additions & 27 deletions flexeval/preset_configs/EvalSetup/code_generation/mbpp.jsonnet
Original file line number Diff line number Diff line change
Expand Up @@ -26,33 +26,28 @@ local dataset_base_args = {
num_shots: 3,
},
},
prompt_template: {
class_path: 'Jinja2PromptTemplate',
init_args: {
template: |||
{% for item in few_shot_data %}
## Question
{{ item.prompt }}
## Test cases
```python
{{ item.test_list | join('\n') }}
```
## Code
```python
{{ item.code }}
```
{% endfor %}
## Question
{{ prompt }}
## Test cases
```python
{{ test_list | join('\n') }}
```
## Code
```python
|||,
},
},
prompt_template: |||
{% for item in few_shot_data %}
## Question
{{ item.prompt }}
## Test cases
```python
{{ item.test_list | join('\n') }}
```
## Code
```python
{{ item.code }}
```
{% endfor %}
## Question
{{ prompt }}
## Test cases
```python
{{ test_list | join('\n') }}
```
## Code
```python
|||,
metrics: [
{ class_path: 'CodeEval' },
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,31 +8,27 @@ local original_config = import './mbpp.jsonnet';

original_config {
init_args+: {
prompt_template+: {
init_args+: {
template: |||
{% for item in few_shot_data %}
## Question
{{ item.prompt }}
## Test cases
```python
{{ item.test_list | join('\n') }}
```
## Code
```python
{{ item.code | replace(' ', '\t') }}
```
{% endfor %}
## Question
{{ prompt }}
## Test cases
```python
{{ test_list | join('\n') }}
```
## Code
```python
|||,
},
},
prompt_template: |||
{% for item in few_shot_data %}
## Question
{{ item.prompt }}
## Test cases
```python
{{ item.test_list | join('\n') }}
```
## Code
```python
{{ item.code | replace(' ', '\t') }}
```
{% endfor %}
## Question
{{ prompt }}
## Test cases
```python
{{ test_list | join('\n') }}
```
## Code
```python
|||,
},
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,7 @@ References:
reference_template: '{{ test }}\n\ncheck({{ entry_point }})\n',
},
},
prompt_template: {
class_path: 'Jinja2PromptTemplate',
init_args: {
template: '{{ prompt }}',
},
},
prompt_template: '{{ prompt }}',
metrics: [
{ class_path: 'CodeEval', init_args: { code_template: '{{ prompt }}{{ lm_output }}' } },
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,7 @@ original_config {
reference_template: '{{ test | replace(" ", "\t") }}\n\ncheck({{ entry_point }})\n',
},
},
prompt_template+: {
init_args+: {
template: '{{ prompt | replace(" ", "\t") }}',
},
},
prompt_template: '{{ prompt | replace(" ", "\t") }}',
metrics: [
{ class_path: 'CodeEval', init_args: { code_template: '{{ prompt | replace(" ", "\t") }}{{ lm_output }}' } },
],
Expand Down
23 changes: 9 additions & 14 deletions flexeval/preset_configs/EvalSetup/en_generation/babi.jsonnet
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,15 @@ local dataset_base_args = {
num_shots: 3,
},
},
prompt_template: {
class_path: 'Jinja2PromptTemplate',
init_args: {
template: |||
{% for item in few_shot_data %}
Passage: {{ item.passage | trim }}
Question: {{ item.question }}
Answer: "{{ item.references[0] }}"
{% endfor %}
Passage: {{ passage | trim }}
Question: {{ question }}
||| + 'Answer: "',
},
},
prompt_template: |||
{% for item in few_shot_data %}
Passage: {{ item.passage | trim }}
Question: {{ item.question }}
Answer: "{{ item.references[0] }}"
{% endfor %}
Passage: {{ passage | trim }}
Question: {{ question }}
||| + 'Answer: "',
metrics: [
{ class_path: 'CharF1' },
{ class_path: 'ExactMatch' },
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,31 +26,26 @@ local dataset_base_args = {
num_shots: 2,
},
},
prompt_template: {
class_path: 'Jinja2PromptTemplate',
init_args: {
template: |||
Choose the correct answer from the choices.
{% for item in few_shot_data %}
Choices:
0. "{{ item.choices.text[0] }}"
1. "{{ item.choices.text[1] }}"
2. "{{ item.choices.text[2] }}"
3. "{{ item.choices.text[3] }}"
4. "{{ item.choices.text[4] }}"
Question: {{ item.question }}
Answer: "{{ item.references[0] }}"
{% endfor %}
Choices:
0. "{{ choices.text[0] }}"
1. "{{ choices.text[1] }}"
2. "{{ choices.text[2] }}"
3. "{{ choices.text[3] }}"
4. "{{ choices.text[4] }}"
Question: {{question}}
||| + 'Answer: "',
},
},
prompt_template: |||
Choose the correct answer from the choices.
{% for item in few_shot_data %}
Choices:
0. "{{ item.choices.text[0] }}"
1. "{{ item.choices.text[1] }}"
2. "{{ item.choices.text[2] }}"
3. "{{ item.choices.text[3] }}"
4. "{{ item.choices.text[4] }}"
Question: {{ item.question }}
Answer: "{{ item.references[0] }}"
{% endfor %}
Choices:
0. "{{ choices.text[0] }}"
1. "{{ choices.text[1] }}"
2. "{{ choices.text[2] }}"
3. "{{ choices.text[3] }}"
4. "{{ choices.text[4] }}"
Question: {{question}}
||| + 'Answer: "',
metrics: [
{ class_path: 'ExactMatch' },
],
Expand Down
19 changes: 7 additions & 12 deletions flexeval/preset_configs/EvalSetup/en_generation/gsm8k.jsonnet
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,13 @@ local dataset_base_args = {
num_shots: 4,
},
},
prompt_template: {
class_path: 'Jinja2PromptTemplate',
init_args: {
template: |||
{% for item in few_shot_data %}
Q: {{ item.question }}
A: {{ item.references[0] }}
{% endfor %}
Q: {{ question }}
||| + 'A:',
},
},
prompt_template: |||
{% for item in few_shot_data %}
Q: {{ item.question }}
A: {{ item.references[0] }}
{% endfor %}
Q: {{ question }}
||| + 'A:',
metrics: [
{
class_path: 'ExactMatch',
Expand Down
24 changes: 9 additions & 15 deletions flexeval/preset_configs/EvalSetup/en_generation/squad_v1.jsonnet
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,15 @@ local dataset_base_args = {
num_shots: 2,
},
},
prompt_template: {
class_path: 'Jinja2PromptTemplate',
init_args: {

template: |||
{% for item in few_shot_data %}
Context: {{ item.context | trim }}
Question: {{ item.question }}
Answer: "{{ item.references[0] }}"
{% endfor %}
Context: {{ context | trim }}
Question: {{ question }}
||| + 'Answer: "',
},
},
prompt_template: |||
{% for item in few_shot_data %}
Context: {{ item.context | trim }}
Question: {{ item.question }}
Answer: "{{ item.references[0] }}"
{% endfor %}
Context: {{ context | trim }}
Question: {{ question }}
||| + 'Answer: "',
metrics: [
{ class_path: 'CharF1' },
{ class_path: 'ExactMatch' },
Expand Down
20 changes: 7 additions & 13 deletions flexeval/preset_configs/EvalSetup/en_generation/trivia_qa.jsonnet
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,13 @@ local dataset_base_args = {
num_shots: 0,
},
},
prompt_template: {
class_path: 'Jinja2PromptTemplate',
init_args: {

template: |||
{% for item in few_shot_data %}
Question: {{ item.question }}
Answer: "{{ item.references[0] }}"
{% endfor %}
Question: {{ question }}
||| + 'Answer: "',
},
},
prompt_template: |||
{% for item in few_shot_data %}
Question: {{ item.question }}
Answer: "{{ item.references[0] }}"
{% endfor %}
Question: {{ question }}
||| + 'Answer: "',
metrics: [
{ class_path: 'CharF1' },
{ class_path: 'ExactMatch' },
Expand Down
Loading

0 comments on commit 8ead327

Please sign in to comment.