Add LMMSEvalInferenceEngine #1301

Merged
merged 7 commits on Oct 28, 2024
Changes from 1 commit
36 changes: 36 additions & 0 deletions examples/evaluate_image_text_to_text_lmms_eval_inference.py
@@ -0,0 +1,36 @@
from tqdm import tqdm
from unitxt import settings
from unitxt.api import evaluate, load_dataset
from unitxt.inference import LMMSEvalInferenceEngine
from unitxt.text_utils import print_dict

with settings.context(
    disable_hf_datasets_cache=False,
Member
We should document this new flag in a relatively prominent place in the documentation (e.g., in one of the first tutorials that loads data from HF).
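For reference, a minimal sketch of how the flag could appear in such a tutorial, reusing the same settings.context API and card as this example (illustrative only, not part of this PR):

# Hypothetical documentation snippet for the new flag: control whether the
# HF datasets cache is used while the block is active.
from unitxt import settings
from unitxt.api import load_dataset

with settings.context(disable_hf_datasets_cache=False):
    dataset = load_dataset(
        card="cards.doc_vqa.lmms_eval",
        template="templates.qa.with_context.title",
        loader_limit=30,
    )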

Member
Need to add this example to the examples main page, and disable running this example by default in regressions.

):
    inference_model = LMMSEvalInferenceEngine(
        model_type="llava_hf",
        model_args="pretrained=llava-hf/llava-interleave-qwen-0.5b-hf",
    )

    dataset = load_dataset(
        card="cards.doc_vqa.lmms_eval",
        template="templates.qa.with_context.title",
        loader_limit=30,
        streaming=True,
    )

    test_dataset = list(tqdm(dataset["test"], total=30))
Member
What does this do? Why do you need tqdm?
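For context, tqdm here only draws a progress bar while the streamed split is materialized into a list; an equivalent line without it would be:

# Same materialization of the streaming "test" split, just without a progress bar.
test_dataset = list(dataset["test"])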


    predictions = inference_model.infer(test_dataset)
    evaluated_dataset = evaluate(predictions=predictions, data=test_dataset)
Member
Do you get the same results as with the HFLlavaInferenceEngine inference engine?
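A sketch of how that comparison could be run on the same data (HFLlavaInferenceEngine's constructor arguments are assumed here, not taken from this PR; check its definition in unitxt.inference):

# Hypothetical comparison against the HF-based engine; argument names are assumptions.
from unitxt.inference import HFLlavaInferenceEngine

hf_engine = HFLlavaInferenceEngine(
    model_name="llava-hf/llava-interleave-qwen-0.5b-hf",  # same checkpoint as above
    max_new_tokens=32,
)
hf_predictions = hf_engine.infer(test_dataset)
hf_evaluated = evaluate(predictions=hf_predictions, data=test_dataset)
# Compare hf_evaluated[0]["score"] with the lmms-eval engine's scores above.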


    print_dict(
        evaluated_dataset[0],
        keys_to_print=[
            "source",
            "media",
            "references",
            "processed_prediction",
            "score",
        ],
    )
87 changes: 87 additions & 0 deletions src/unitxt/inference.py
@@ -2,6 +2,7 @@
import dataclasses
import os
import re
import uuid
from typing import Any, Dict, List, Literal, Optional, Union

from datasets import DatasetDict
@@ -1074,3 +1075,89 @@ def _infer(
results.append(result)

return results


def get_images(instance):
    return instance


class LMMSEvalInferenceEngine(InferenceEngine, LazyLoadMixin):
    model_type: str
    model_args: Dict[str, str]
    batch_size: int = 1
    max_new_tokens: int = 32
    temperature: float = 0.0
    do_sample: bool = False
    generate_until: List[str] = ["\n\n"]  # stop sequences passed to lmms-eval as "until"

    _requirements_list = {
        "lmms_eval": "Install lmms-eval package using 'pip install lmms-eval==0.1.2'",
    }

    def prepare_engine(self):
        if not self.lazy_load:
            self._prepare_engine()

    def _prepare_engine(self):
        import torch
        from lmms_eval.api.registry import get_model

        self.device = torch.device(
            "mps"
            if torch.backends.mps.is_available()
            else 0
            if torch.cuda.is_available()
            else "cpu"
        )

        self.model = get_model(self.model_type).create_from_arg_string(
            self.model_args,
            {
                "batch_size": self.batch_size,
                "device": self.device,
            },
        )

    def _is_loaded(self):
        return hasattr(self, "model") and self.model is not None

    def _infer(
        self,
        dataset: Union[List[Dict[str, Any]], DatasetDict],
        return_meta_data: bool = False,
    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
        if not self._is_loaded():
            self._prepare_engine()

        from lmms_eval.api.instance import Instance

        temp_task_name = str(uuid.uuid4())

        self.model.task_dict[temp_task_name] = DatasetDict({"test": dataset})

        requests = []
        for i, instance in enumerate(dataset):
            requests.append(
                Instance(
                    request_type="generate_until",
                    arguments=(
                        instance["source"],
                        {
                            "max_new_tokens": self.max_new_tokens,
                            "temperature": self.temperature,
                            "do_sample": self.do_sample,
                            "until": self.generate_until,
                        },
                        lambda x: x["media"]["images"],  # doc_to_visual: extract this instance's images
                        i,  # doc_id: index of this instance in the registered "test" split
                        temp_task_name,
                        "test",
                    ),
                    idx=i,
                    metadata=(temp_task_name, i, 1),
                )
            )

        responses = self.model.generate_until(requests)

        # Remove the temporary task only after generation, since the model
        # resolves documents through task_dict during generate_until.
        self.model.task_dict.pop(temp_task_name)

        return responses