Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Add logprobs functionality #1111

Closed
wants to merge 11 commits into from
34 changes: 34 additions & 0 deletions prepare/processors/processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
ExtractWithRegex,
FirstCharacter,
GetStringAfter,
InferDictsToBinaryLogprobs,
LiteralEval,
LowerCase,
LowerCaseTillPunc,
Expand Down Expand Up @@ -369,3 +370,36 @@
"processors.extract_arena_hard_numerical_judgment",
overwrite=True,
)

# Register a post-processor that turns per-token logprob dicts into a
# P("Yes") score, scanning the first tokens of the prediction.
yes_no_logprob_processor = SequentialOperator(
    steps=[
        InferDictsToBinaryLogprobs(
            field="prediction",
            pos_class_name="Yes",
            neg_class_name="No",
            num_logprobs_to_take=3,
            process_every_value=False,
        ),
    ]
)
add_to_catalog(
    yes_no_logprob_processor,
    "processors.infer_logprobs_to_yes_no_probs",
    overwrite=True,
)

# Same as the plain yes/no-logprob processor, but scanning from the END of
# the generated tokens (take_logprobs_from_end=True).
last_token_yes_no_processor = SequentialOperator(
    steps=[
        InferDictsToBinaryLogprobs(
            field="prediction",
            pos_class_name="Yes",
            neg_class_name="No",
            take_logprobs_from_end=True,
            num_logprobs_to_take=3,
            process_every_value=False,
        ),
    ]
)
add_to_catalog(
    last_token_yes_no_processor,
    "processors.infer_last_token_logprobs_to_yes_no_probs",
    overwrite=True,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{
"__type__": "sequential_operator",
"steps": [
{
"__type__": "infer_dicts_to_binary_logprobs",
"neg_class_name": "No",
"pos_class_name": "Yes",
"take_logprobs_from_end": true,
"num_logprobs_to_take": 3,
"field": "prediction",
"process_every_value": false
}
]
}
13 changes: 13 additions & 0 deletions src/unitxt/catalog/processors/infer_logprobs_to_yes_no_probs.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"__type__": "sequential_operator",
"steps": [
{
"__type__": "infer_dicts_to_binary_logprobs",
"neg_class_name": "No",
"pos_class_name": "Yes",
"num_logprobs_to_take": 3,
"field": "prediction",
"process_every_value": false
}
]
}
38 changes: 35 additions & 3 deletions src/unitxt/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,10 @@ class WMLInferenceEngineParams(Artifact):


class WMLInferenceEngine(
InferenceEngine, WMLInferenceEngineParamsMixin, PackageRequirementsMixin
InferenceEngine,
LogProbInferenceEngine,
WMLInferenceEngineParamsMixin,
PackageRequirementsMixin,
):
"""Runs inference using ibm-watsonx-ai.

Expand Down Expand Up @@ -466,19 +469,48 @@ def verify(self):
), "Either 'model_name' or 'deployment_id' must be specified, but not both at the same time."
super().verify()

def _infer(self, dataset):
def _load_model_and_params(self):
    """Construct the watsonx.ai inference client and the generation params.

    Returns:
        A ``(model, params)`` pair: ``model`` is a ``ModelInference`` built
        from either ``model_name`` or ``deployment_id`` (``verify`` enforces
        that exactly one is set), and ``params`` is this engine's non-empty
        ``WMLInferenceEngineParamsMixin`` fields as a plain dict.
    """
    # Imported lazily so the engine can be declared without ibm-watsonx-ai
    # installed; it is only required when inference actually runs.
    from ibm_watsonx_ai.foundation_models import ModelInference

    inference_model = ModelInference(
        model_id=self.model_name,
        deployment_id=self.deployment_id,
        api_client=self.client,
    )
    generation_params = self.to_dict(
        [WMLInferenceEngineParamsMixin], keep_empty=False
    )

    return inference_model, generation_params

def _infer(self, dataset):
    """Generate one text completion per instance, prompting with its "source" field."""
    model, params = self._load_model_and_params()

    predictions = []
    for instance in dataset:
        prediction = model.generate_text(
            prompt=instance["source"],
            params=params,
        )
        predictions.append(prediction)
    return predictions

def _infer_log_probs(self, dataset):
    """Generate completions and return token-level logprob data per instance.

    The configured generation params are extended with ``return_options`` so
    the service reports per-token logprobs with the top 5 alternatives for
    each position; the "generated_tokens" list of the first result is
    returned for every instance.
    """
    model, params = self._load_model_and_params()

    logprob_params = dict(params)
    logprob_params["return_options"] = {
        "input_tokens": True,
        "generated_tokens": True,
        "token_logprobs": True,
        "top_n_tokens": 5,
    }

    raw_results = [
        model.generate(
            prompt=instance["source"],
            params=logprob_params,
        )["results"]
        for instance in dataset
    ]

    return [instance_results[0]["generated_tokens"] for instance_results in raw_results]
15 changes: 9 additions & 6 deletions src/unitxt/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ def metric(sample_refs, sample_preds, sample_task_data):
references=sample_refs,
predictions=sample_preds,
task_data=sample_task_data,
)["score"]
)[score_name]
except Exception as e:
# this happens in edge cases, for example, when the sampling creates a
# sample where all strings are empty and this fails bleu.
Expand Down Expand Up @@ -567,11 +567,12 @@ def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generato

result = self._compute(references, predictions, task_data)
global_score.update(self._add_score_prefixes_to_score_dict(result))
score_name = global_score["score_name"]
confidence_interval = self.compute_global_confidence_intervals(
references, predictions, task_data, score_name
)
global_score.update(confidence_interval)
score_names = self.ci_scores if self.ci_scores else [global_score["score_name"]]
yoavkatz marked this conversation as resolved.
Show resolved Hide resolved
for score_name in score_names:
confidence_interval = self.compute_global_confidence_intervals(
references, predictions, task_data, score_name
)
global_score.update(confidence_interval)

for instance in instances:
self.update_and_adjust_global_score(instance, global_score)
Expand Down Expand Up @@ -1734,6 +1735,7 @@ class F1Binary(GlobalMetric):
metric = "f1"
single_reference_per_prediction = True
_requirements_list: List[str] = ["sklearn"]
ci_scores = [main_score, "f1_binary_neg"]

def prepare(self):
super().prepare()
Expand Down Expand Up @@ -4261,6 +4263,7 @@ class BinaryMaxF1(F1Binary):
main_score = "max_f1_binary"
single_reference_per_prediction = True
average = None
ci_scores = [main_score, "max_f1_binary_neg"]

def compute(
self,
Expand Down
58 changes: 58 additions & 0 deletions src/unitxt/processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from difflib import get_close_matches
from typing import Any, Dict

import numpy as np

from .operators import FieldOperator, InstanceFieldOperator


Expand Down Expand Up @@ -277,3 +279,59 @@ def process_value(self, text: Any) -> Any:

except:
return 0


class InferDictsToBinaryLogprobs(FieldOperator):
    """Convert a list of per-token top-logprob dicts into P(positive class).

    Scans up to ``num_logprobs_to_take`` token positions of the prediction —
    from the start by default, or from the end when ``take_logprobs_from_end``
    is set — sums the probability mass of tokens matching the positive and
    negative class names, and returns ``pos / (pos + neg)`` for the first
    position with enough total mass. Returns ``np.nan`` when no position
    yields a usable probability mass.
    """

    neg_class_name: str
    pos_class_name: str

    # Scan direction and width over the token list.
    take_logprobs_from_end: bool = False
    num_logprobs_to_take: int = 3
    # Below this total mass the pos/(pos+neg) ratio is considered meaningless.
    min_probability_mass: float = 0.0001

    def verify(self):
        super().verify()
        # Overlapping class names (e.g. "yes" vs "yes/no") would make the
        # token matching in get_pos_neg_probs ambiguous; reject them up front.
        if (
            self.neg_class_name.lower() in self.pos_class_name.lower()
            or self.pos_class_name.lower() in self.neg_class_name.lower()
        ):
            raise ValueError(
                f"Class names in {self.__class__.__name__} should not overlap, "
                f'got "{self.pos_class_name}" and "{self.neg_class_name}"'
            )

    def process_value(self, obj: Any) -> Any:
        """Return P(pos) for the first scanned token position with enough
        class probability mass, or ``np.nan`` if none qualifies."""
        for i in self.get_token_range(obj):
            # Best-effort: a malformed token entry just moves us to the next
            # position. Exception (not bare except) so Ctrl-C still works.
            try:
                pos_probs, neg_probs = self.get_pos_neg_probs(pred_dict=obj[i])
                if pos_probs or neg_probs:
                    sum_probs = sum(pos_probs) + sum(neg_probs)
                    if sum_probs > self.min_probability_mass:
                        return sum(pos_probs) / sum_probs
            except Exception:
                continue
        return np.nan

    def get_pos_neg_probs(self, pred_dict):
        """Return ``([pos probs], [neg probs])`` over one token's
        "top_tokens" alternatives, converting logprobs to probabilities."""
        token_logprobs = pred_dict["top_tokens"]

        pos_and_neg_probs = []
        for class_name in [self.pos_class_name, self.neg_class_name]:
            # We need to capture different variants of model behavior and tokenizers, for example with opening space,
            # punctuation etc. but avoid longer words that contain the class name.
            # For example, for class "yes" we would capture "YES," and " Yes" but not "yesterday".
            name_regex = re.compile(
                rf"(\W|Ġ|_)*{class_name}(\W|Ġ|_)*", flags=re.IGNORECASE
            )
            class_probs = [
                np.exp(d["logprob"])
                for d in token_logprobs
                if name_regex.fullmatch(d["text"])
            ]
            pos_and_neg_probs.append(class_probs)
        return pos_and_neg_probs

    def get_token_range(self, obj: Any) -> range:
        """Indices of token positions to scan: the first n positions, or
        negative indices over the last n when take_logprobs_from_end is set."""
        n_tokens = min(self.num_logprobs_to_take, len(obj))
        if self.take_logprobs_from_end:
            return range(-1, -(n_tokens + 1), -1)
        return range(n_tokens)
Loading