
Commit

Merge branch 'main' into accelerate-rag-metrics
elronbandel authored Jan 12, 2025
2 parents 59a816f + 0ed5ff6 commit 6af1893
Showing 27 changed files with 692 additions and 27 deletions.
21 changes: 11 additions & 10 deletions examples/evaluate_image_text_to_text.py
@@ -1,24 +1,25 @@
from unitxt import settings
from unitxt.api import evaluate, load_dataset
from unitxt.inference import HFLlavaInferenceEngine
from unitxt.inference import (
    LMMSEvalInferenceEngine,
)

with settings.context(
    disable_hf_datasets_cache=False,
):
    inference_model = LMMSEvalInferenceEngine(
        model_type="llava",
        model_args={"pretrained": "liuhaotian/llava-v1.5-7b"},
        max_new_tokens=128,
    )
    dataset = load_dataset(
        card="cards.doc_vqa.lmms_eval",
        template="templates.qa.with_context.title",
        card="cards.websrc",
        format="formats.chat_api",
        loader_limit=10,
        augmentor="augmentors.image.grey_scale",
        # max_test_instances=20,
        split="test",
    )

    model = HFLlavaInferenceEngine(
        model_name="llava-hf/llava-interleave-qwen-0.5b-hf", max_new_tokens=32
    )

    predictions = model(dataset)
    predictions = inference_model.infer(dataset)
    results = evaluate(predictions=predictions, data=dataset)

    print("Global Results:")
9 changes: 8 additions & 1 deletion prepare/cards/ai2d.py
@@ -2,6 +2,7 @@
from unitxt.catalog import add_to_catalog
from unitxt.image_operators import ToImage
from unitxt.operators import Cast, Rename
from unitxt.templates import MultipleChoiceTemplate
from unitxt.test_utils.card import test_card

card = TaskCard(
@@ -12,8 +13,14 @@
        Set(fields={"context_type": "image"}),
        Cast(field="answer", to="int"),
    ],
    task="tasks.qa.multiple_choice.with_context",
    task="tasks.qa.multiple_choice.with_context[metrics=[metrics.exact_match_mm]]",
    templates="templates.qa.multiple_choice.with_context.no_intro.all",
    default_template=MultipleChoiceTemplate(
        input_format="{context}\n{question}\n{choices}\nAnswer with the option's letter from the given choices directly.",
        choices_separator="\n",
        target_field="answer",
        enumerator="capitals",
    ),
    __tags__={},
    __description__=(
        "AI2 Diagrams (AI2D) is a dataset of over 5000 grade school science diagrams with over 150000 rich annotations, their ground truth syntactic parses, and more than 15000 corresponding multiple choice questions."
35 changes: 34 additions & 1 deletion prepare/cards/chart_qa.py
@@ -1,8 +1,10 @@
from unitxt.blocks import LoadHF, Set, TaskCard
from unitxt.catalog import add_to_catalog
from unitxt.collections_operators import Wrap
from unitxt.image_operators import ToImage
from unitxt.operators import Rename
from unitxt.splitters import RenameSplits
from unitxt.templates import MultiReferenceTemplate
from unitxt.test_utils.card import test_card

card = TaskCard(
@@ -14,7 +16,7 @@
        ToImage(field="image", to_field="context"),
        Set(fields={"context_type": "image"}),
    ],
    task="tasks.qa.with_context.abstractive",
    task="tasks.qa.with_context",
    templates="templates.qa.with_context.all",
    __tags__={
        "license": "GPL-3.0",
@@ -31,3 +33,34 @@

test_card(card)
add_to_catalog(card, "cards.chart_qa", overwrite=True)


card = TaskCard(
    loader=LoadHF(path="lmms-lab/ChartQA"),
    preprocess_steps=[
        Wrap(field="answer", inside="list", to_field="answers"),
        ToImage(field="image", to_field="context"),
        Set(fields={"context_type": "image"}),
    ],
    task="tasks.qa.with_context.with_type[metrics=[metrics.relaxed_correctness]]",
    templates="templates.qa.with_context.all",
    default_template=MultiReferenceTemplate(
        input_format="{context}\n{question}\nAnswer the question using a single word.",
        references_field="answers",
        __description__="lmms-evals default template for chartqa.",
    ),
    __tags__={
        "license": "GPL-3.0",
        "multilinguality": "monolingual",
        "modalities": ["image", "text"],
        "size_categories": "10K<n<100K",
        "task_categories": "question-answering",
        "task_ids": "extractive-qa",
    },
    __description__=(
        "ChartQA: A Benchmark for Question Answering about Charts with Visual and Logical Reasoning."
    ),
)

test_card(card)
add_to_catalog(card, "cards.chart_qa_lmms_eval", overwrite=True)
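
Once registered, the new card can be loaded like any other unitxt card. A minimal usage sketch using the same API as the example script above (the loader_limit value here is illustrative):

from unitxt.api import load_dataset

# Loads ChartQA with the card's default template; the task reference above
# attaches metrics.relaxed_correctness instead of the task's default Rouge.
dataset = load_dataset(
    card="cards.chart_qa_lmms_eval",
    format="formats.chat_api",
    loader_limit=10,
    split="test",
)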
6 changes: 6 additions & 0 deletions prepare/cards/doc_vqa.py
@@ -4,6 +4,7 @@
from unitxt.image_operators import ToImage
from unitxt.operators import Copy
from unitxt.splitters import RenameSplits
from unitxt.templates import MultiReferenceTemplate
from unitxt.test_utils.card import test_card

for language in ["en", "fr"]:
@@ -48,6 +49,11 @@
    ],
    task="tasks.qa.with_context.abstractive[metrics=[metrics.anls]]",
    templates="templates.qa.with_context.all",
    default_template=MultiReferenceTemplate(
        input_format="{context}\n{question}\nAnswer the question using a single word or phrase.",
        references_field="answers",
        __description__="lmms-evals default template for docvqa.",
    ),
    __tags__={
        "license": "apache-2.0",
        "multilinguality": "monolingual",
50 changes: 48 additions & 2 deletions prepare/cards/info_vqa.py
@@ -3,7 +3,8 @@
from unitxt.collections_operators import Wrap
from unitxt.image_operators import ToImage
from unitxt.operators import Rename
from unitxt.splitters import SplitRandomMix
from unitxt.splitters import RenameSplits, SplitRandomMix
from unitxt.templates import MultiReferenceTemplate
from unitxt.test_utils.card import test_card

card = TaskCard(
@@ -18,7 +19,17 @@
        Set(fields={"context_type": "image"}),
    ],
    task="tasks.qa.with_context.abstractive[metrics=[metrics.anls]]",
    templates="templates.qa.with_context.all",
    templates=[
        MultiReferenceTemplate(
            input_format="{context}\n{question}\nAnswer the question using a single word or phrase.",
            references_field="answers",
        )
    ],
    default_template=MultiReferenceTemplate(
        input_format="{context}\n{question}\nAnswer the question using a single word or phrase.",
        references_field="answers",
        __description__="lmms-evals default template for info_vqa.",
    ),
    __tags__={
        "license": "Unknown",
        "multilinguality": "monolingual",
@@ -34,3 +45,38 @@

test_card(card)
add_to_catalog(card, "cards.info_vqa", overwrite=True)


card = TaskCard(
    loader=LoadHF(
        path="lmms-lab/DocVQA",
        name="InfographicVQA",
        data_classification_policy=["public"],
    ),
    preprocess_steps=[
        RenameSplits(mapper={"validation": "test"}),
        ToImage(field="image", to_field="context"),
        Set(fields={"context_type": "image"}),
    ],
    task="tasks.qa.with_context.abstractive[metrics=[metrics.anls]]",
    templates="templates.qa.with_context.all",
    default_template=MultiReferenceTemplate(
        input_format="{context}\n{question}\nAnswer the question using a single word or phrase.",
        references_field="answers",
        __description__="lmms-evals default template for infovqa.",
    ),
    __tags__={
        "license": "apache-2.0",
        "multilinguality": "monolingual",
        "modalities": ["image", "text"],
        "size_categories": "10K<n<100K",
        "task_categories": "question-answering",
        "task_ids": "extractive-qa",
    },
    __description__=(
        "InfographicVQA is a dataset that comprises a diverse collection of infographics along with natural language questions and answers annotations. The collected questions require methods to jointly reason over the document layout, textual content, graphical elements, and data visualizations. We curate the dataset with emphasis on questions that require elementary reasoning and basic arithmetic skills."
    ),
)

test_card(card)
add_to_catalog(card, "cards.info_vqa_lmms_eval", overwrite=True)
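
Both lmms-eval-style cards above score with metrics.anls. For orientation, ANLS (Average Normalized Levenshtein Similarity) credits a prediction with one minus its normalized edit distance to the closest reference and zeroes out anything below a 0.5 similarity threshold. The sketch below assumes that standard definition; it is not the unitxt implementation, and the helper names are illustrative.

def levenshtein(a: str, b: str) -> int:
    """Classic dynamic-programming edit distance."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        curr = [i]
        for j, cb in enumerate(b, 1):
            curr.append(min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + (ca != cb)))
        prev = curr
    return prev[-1]


def anls(prediction: str, references: list[str], threshold: float = 0.5) -> float:
    """Best normalized similarity over references; values under the threshold count as 0."""
    best = 0.0
    for ref in references:
        pred, gold = prediction.strip().lower(), ref.strip().lower()
        dist = levenshtein(pred, gold)
        similarity = 1.0 - dist / max(len(pred), len(gold), 1)
        best = max(best, similarity if similarity >= threshold else 0.0)
    return best


# A one-character slip in a long answer keeps most of the credit:
print(anls("infographic", ["infographics"]))  # 1 - 1/12 ≈ 0.917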
8 changes: 7 additions & 1 deletion prepare/cards/websrc.py
@@ -3,6 +3,7 @@
from unitxt.collections_operators import Wrap
from unitxt.image_operators import DecodeImage, ToImage
from unitxt.splitters import RenameSplits
from unitxt.templates import MultiReferenceTemplate
from unitxt.test_utils.card import test_card

card = TaskCard(
@@ -15,8 +16,13 @@
        ToImage(field="context"),
        Set(fields={"context_type": "image"}),
    ],
    task="tasks.qa.with_context.abstractive",
    task="tasks.qa.with_context.with_domain[metrics=[metrics.websrc_squad_f1]]",
    templates="templates.qa.with_context.all",
    default_template=MultiReferenceTemplate(
        input_format="{context}\nAnswer the question using a single word or phrase.\n{question}",
        references_field="answers",
        __description__="lmms-evals default template for websrc.",
    ),
    __tags__={
        "license": "Unknown",
        "multilinguality": "monolingual",
31 changes: 31 additions & 0 deletions prepare/metrics/exact_match_mm.py
@@ -0,0 +1,31 @@
from unitxt import add_to_catalog
from unitxt.metrics import ExactMatchMM
from unitxt.test_utils.metrics import test_metric

metric = ExactMatchMM(n_resamples=None)

predictions = ["A", "B", "C"]
references = [["B"], ["A"], ["C"]]

instance_targets = [
    {"exact_match_mm": 0.0, "score": 0.0, "score_name": "exact_match_mm"},
    {"exact_match_mm": 0.0, "score": 0.0, "score_name": "exact_match_mm"},
    {"exact_match_mm": 1.0, "score": 1.0, "score_name": "exact_match_mm"},
]

global_target = {
    "exact_match_mm": 0.33,
    "score": 0.33,
    "score_name": "exact_match_mm",
    "num_of_instances": 3,
}

outputs = test_metric(
    metric=metric,
    predictions=predictions,
    references=references,
    instance_targets=instance_targets,
    global_target=global_target,
)

add_to_catalog(metric, "metrics.exact_match_mm", overwrite=True)
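
As a sanity check on the targets above: assuming ExactMatchMM reduces to string equality against any reference (which is what this test data implies), the instance and global scores follow directly. A minimal sketch:

predictions = ["A", "B", "C"]
references = [["B"], ["A"], ["C"]]

# Per-instance score: 1.0 when the prediction equals any reference, else 0.0.
instance_scores = [float(pred in refs) for pred, refs in zip(predictions, references)]
# -> [0.0, 0.0, 1.0], matching instance_targets above

# The global score is the mean of the instance scores.
global_score = sum(instance_scores) / len(instance_scores)
# -> 0.333..., reported as 0.33 in global_target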
44 changes: 44 additions & 0 deletions prepare/metrics/relaxed_correctness.py
@@ -0,0 +1,44 @@
from unitxt import add_to_catalog
from unitxt.metrics import RelaxedCorrectness
from unitxt.test_utils.metrics import test_metric

# from cvar_pyutils.debugging_tools import set_remote_debugger
# set_remote_debugger('9.61.188.58', 55557)
metric = RelaxedCorrectness(n_resamples=None)

predictions = ["10", "30"]
references = [["14"], ["30"]]

# how to create a metric which isn't updated in every sample when using UNITXT?
instance_targets = [
    {
        "relaxed_overall": 0.0,
        "relaxed_human_split": 0.0,
        "score": 0.0,
        "score_name": "relaxed_overall",
    },
    {
        "relaxed_overall": 1.0,
        "relaxed_augmented_split": 1.0,
        "score": 1.0,
        "score_name": "relaxed_overall",
    },
]

global_target = {
    "relaxed_overall": 0.5,
    "relaxed_human_split": 0.0,
    "relaxed_augmented_split": 1.0,
    "score": 0.5,
    "score_name": "relaxed_overall",
    "num_of_instances": 2,
}
outputs = test_metric(
    metric=metric,
    predictions=predictions,
    references=references,
    instance_targets=instance_targets,
    global_target=global_target,
    task_data=[{"type": "human_test"}, {"type": "augmented_test"}],
)
add_to_catalog(metric, "metrics.relaxed_correctness", overwrite=True)
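
RelaxedCorrectness follows the ChartQA "relaxed accuracy" convention: numeric answers may be off by up to 5% relative error, anything else must match exactly, and scores are also averaged per split (relaxed_human_split / relaxed_augmented_split) based on the type field in task_data. A minimal sketch of that criterion, assuming the standard definition and using an illustrative helper name rather than the unitxt API:

def relaxed_match(prediction: str, target: str, max_relative_change: float = 0.05) -> float:
    """Return 1.0 if prediction matches target under the relaxed criterion, else 0.0."""

    def to_float(text: str):
        try:
            return float(text)
        except ValueError:
            return None

    pred_num, target_num = to_float(prediction), to_float(target)
    if pred_num is not None and target_num is not None and target_num != 0:
        # Numeric answers: accept up to 5% relative error.
        return float(abs(pred_num - target_num) / abs(target_num) <= max_relative_change)
    # Non-numeric answers: exact match after light normalization.
    return float(prediction.strip().lower() == target.strip().lower())


# Mirrors the test above: "10" vs "14" is 28.6% off and fails; "30" vs "30" passes.
assert relaxed_match("10", "14") == 0.0
assert relaxed_match("30", "30") == 1.0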
39 changes: 39 additions & 0 deletions prepare/metrics/websrc_squad_f1.py
@@ -0,0 +1,39 @@
from unitxt import add_to_catalog
from unitxt.metrics import WebsrcSquadF1
from unitxt.test_utils.metrics import test_metric

# from cvar_pyutils.debugging_tools import set_remote_debugger
# set_remote_debugger('9.148.189.104', 55557)
metric = WebsrcSquadF1(n_resamples=None)

predictions = ["The 2nd", "The 1st"]
references = [["The 2nd"], ["The 2nd"]]

# how to create a metric which isn't updated in every sample when using UNITXT?
instance_targets = [
    {
        "websrc_squad_f1": 1.0,
        "score": 1.0,
        "score_name": "websrc_squad_f1",
    },
    {
        "websrc_squad_f1": 0.5,
        "score": 0.5,
        "score_name": "websrc_squad_f1",
    },
]
global_target = {
    "num_of_instances": 2,
    "websrc_squad_f1": 0.75,
    "score": 0.75,
    "score_name": "websrc_squad_f1",
}
outputs = test_metric(
    metric=metric,
    predictions=predictions,
    references=references,
    instance_targets=instance_targets,
    global_target=global_target,
    task_data=[{"domain": "movie"}, {"domain": "movie"}],
)
add_to_catalog(metric, "metrics.websrc_squad_f1", overwrite=True)
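
WebsrcSquadF1 is built on SQuAD-style token-overlap F1, aggregated overall and per domain (via the domain field in task_data). The sketch below reproduces the instance targets above; the full implementation typically also normalizes punctuation and articles, and the helper name is illustrative, not the unitxt API.

from collections import Counter


def token_f1(prediction: str, reference: str) -> float:
    """SQuAD-style token-overlap F1 between a prediction and one reference."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    common = Counter(pred_tokens) & Counter(ref_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


# Mirrors the test above: the exact match scores 1.0; "The 1st" vs "The 2nd"
# shares only "the", so precision = recall = 0.5 and F1 = 0.5.
# Averaging the two instance scores gives the 0.75 global target.
assert token_f1("The 2nd", "The 2nd") == 1.0
assert token_f1("The 1st", "The 2nd") == 0.5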
46 changes: 46 additions & 0 deletions prepare/tasks/qa/tasks.py
@@ -76,3 +76,49 @@
    "tasks.qa.open",
    overwrite=True,
)


add_to_catalog(
    Task(
        __description__="""This is the Question Answering Task with a provided context (which is either a text, image, audio, table, or dialog) and an additional field called 'type'.
The 'tasks.qa.open' task should be used if there is no context. One or more ground-truth answers can be provided in the 'answers' field.
By default, the classical Rouge metric is used, but a list of additional applicable metrics can be found under 'metrics.qa' in the Unitxt catalog.
""",
        input_fields={
            "context": Union[Text, Image, Audio, Table, Dialog],
            "context_type": str,
            "question": str,
        },
        reference_fields={"answers": List[str], "type": str},
        prediction_type=str,
        metrics=["metrics.rouge"],
        augmentable_inputs=["context", "question"],
        defaults={"answers": []},
        default_template="templates.qa.with_context",
    ),
    "tasks.qa.with_context.with_type",
    overwrite=True,
)


add_to_catalog(
    Task(
        __description__="""This is the Question Answering Task with a provided context (which is either a text, image, audio, table, or dialog) and an additional field called 'domain'.
The 'tasks.qa.open' task should be used if there is no context. One or more ground-truth answers can be provided in the 'answers' field.
By default, the classical Rouge metric is used, but a list of additional applicable metrics can be found under 'metrics.qa' in the Unitxt catalog.
""",
        input_fields={
            "context": Union[Text, Image, Audio, Table, Dialog],
            "context_type": str,
            "question": str,
        },
        reference_fields={"answers": List[str], "domain": str},
        prediction_type=str,
        metrics=["metrics.rouge"],
        augmentable_inputs=["context", "question"],
        defaults={"answers": []},
        default_template="templates.qa.with_context",
    ),
    "tasks.qa.with_context.with_domain",
    overwrite=True,
)
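
The cards in this commit attach their metrics through unitxt's bracket override syntax when referencing these tasks; a small illustration (the overridden form is taken verbatim from prepare/cards/websrc.py above):

# Referencing the task by its catalog name uses its default metric (metrics.rouge):
task_reference = "tasks.qa.with_context.with_domain"

# The bracket syntax overrides the metrics field at reference time, as in
# prepare/cards/websrc.py in this commit:
task_reference = "tasks.qa.with_context.with_domain[metrics=[metrics.websrc_squad_f1]]"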