diff --git a/.secrets.baseline b/.secrets.baseline
index edaea8c254..1e555ca315 100644
--- a/.secrets.baseline
+++ b/.secrets.baseline
@@ -3,7 +3,7 @@
     "files": "^.secrets.baseline$",
     "lines": null
   },
-  "generated_at": "2024-08-04T13:44:11Z",
+  "generated_at": "2024-08-07T21:24:30Z",
   "plugins_used": [
     {
       "name": "AWSKeyDetector"
@@ -82,7 +82,7 @@
         "hashed_secret": "fa172616e9af3d2a24b5597f264eab963fe76889",
         "is_secret": false,
         "is_verified": false,
-        "line_number": 1894,
+        "line_number": 1901,
         "type": "Hex High Entropy String",
         "verified_result": null
       }
diff --git a/prepare/processors/processors.py b/prepare/processors/processors.py
index 390e62baa5..a0f8b328fe 100644
--- a/prepare/processors/processors.py
+++ b/prepare/processors/processors.py
@@ -12,6 +12,7 @@
     ExtractWithRegex,
     FirstCharacter,
     GetStringAfter,
+    InferDictsToBinaryLogprobs,
     LiteralEval,
     LowerCase,
     LowerCaseTillPunc,
@@ -369,3 +370,36 @@
     "processors.extract_arena_hard_numerical_judgment",
     overwrite=True,
 )
+
+add_to_catalog(
+    SequentialOperator(
+        steps=[
+            InferDictsToBinaryLogprobs(
+                neg_class_name="No",
+                pos_class_name="Yes",
+                num_logprobs_to_take=3,
+                field="prediction",
+                process_every_value=False,
+            ),
+        ]
+    ),
+    "processors.infer_logprobs_to_yes_no_probs",
+    overwrite=True,
+)
+
+add_to_catalog(
+    SequentialOperator(
+        steps=[
+            InferDictsToBinaryLogprobs(
+                neg_class_name="No",
+                pos_class_name="Yes",
+                take_logprobs_from_end=True,
+                num_logprobs_to_take=3,
+                field="prediction",
+                process_every_value=False,
+            ),
+        ]
+    ),
+    "processors.infer_last_token_logprobs_to_yes_no_probs",
+    overwrite=True,
+)
diff --git a/src/unitxt/catalog/processors/infer_last_token_logprobs_to_yes_no_probs.json b/src/unitxt/catalog/processors/infer_last_token_logprobs_to_yes_no_probs.json
new file mode 100644
index 0000000000..94712cf7b2
--- /dev/null
+++ b/src/unitxt/catalog/processors/infer_last_token_logprobs_to_yes_no_probs.json
@@ -0,0 +1,14 @@
+{
+    "__type__": "sequential_operator",
+    "steps": [
+        {
+            "__type__": "infer_dicts_to_binary_logprobs",
+            "neg_class_name": "No",
+            "pos_class_name": "Yes",
+            "take_logprobs_from_end": true,
+            "num_logprobs_to_take": 3,
+            "field": "prediction",
+            "process_every_value": false
+        }
+    ]
+}
diff --git a/src/unitxt/catalog/processors/infer_logprobs_to_yes_no_probs.json b/src/unitxt/catalog/processors/infer_logprobs_to_yes_no_probs.json
new file mode 100644
index 0000000000..baf5a81675
--- /dev/null
+++ b/src/unitxt/catalog/processors/infer_logprobs_to_yes_no_probs.json
@@ -0,0 +1,13 @@
+{
+    "__type__": "sequential_operator",
+    "steps": [
+        {
+            "__type__": "infer_dicts_to_binary_logprobs",
+            "neg_class_name": "No",
+            "pos_class_name": "Yes",
+            "num_logprobs_to_take": 3,
+            "field": "prediction",
+            "process_every_value": false
+        }
+    ]
+}
diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py
index ea92d8b9c8..d19b70fc08 100644
--- a/src/unitxt/inference.py
+++ b/src/unitxt/inference.py
@@ -371,7 +371,10 @@ class WMLInferenceEngineParams(Artifact):
 
 
 class WMLInferenceEngine(
-    InferenceEngine, WMLInferenceEngineParamsMixin, PackageRequirementsMixin
+    InferenceEngine,
+    LogProbInferenceEngine,
+    WMLInferenceEngineParamsMixin,
+    PackageRequirementsMixin,
 ):
     """Runs inference using ibm-watsonx-ai.
 
@@ -428,7 +431,12 @@ class WMLInferenceEngine(
     @staticmethod
     def _read_wml_credentials_from_env() -> Dict[str, str]:
         credentials = {}
-        for env_var_name in ["WML_URL", "WML_PROJECT_ID", "WML_APIKEY"]:
+
+        project_or_deployment_var_name = (
+            "WML_SPACE_ID" if "WML_SPACE_ID" in os.environ else "WML_PROJECT_ID"
+        )
+
+        for env_var_name in ["WML_URL", project_or_deployment_var_name, "WML_APIKEY"]:
             env_var = os.environ.get(env_var_name)
             assert env_var, (
                 f"Error while trying to run 'WMLInferenceEngine'. "
@@ -449,7 +457,10 @@ def _initialize_wml_client(self):
             self.credentials = self._read_wml_credentials_from_env()
 
         client = APIClient(credentials=self.credentials)
-        client.set.default_project(self.credentials["project_id"])
+        if "space_id" in self.credentials:
+            client.set.default_space(self.credentials["space_id"])
+        else:
+            client.set.default_project(self.credentials["project_id"])
         return client
 
     def prepare(self):
@@ -466,7 +477,7 @@ def verify(self):
         ), "Either 'model_name' or 'deployment_id' must be specified, but not both at the same time."
         super().verify()
 
-    def _infer(self, dataset):
+    def _load_model_and_params(self):
         from ibm_watsonx_ai.foundation_models import ModelInference
 
         model = ModelInference(
@@ -474,11 +485,51 @@ def _infer(self, dataset):
             deployment_id=self.deployment_id,
             api_client=self.client,
         )
+        params = self.to_dict([WMLInferenceEngineParamsMixin], keep_empty=False)
+
+        return model, params
+
+    def _infer(self, dataset):
+        model, params = self._load_model_and_params()
 
         return [
             model.generate_text(
                 prompt=instance["source"],
-                params=self.to_dict([WMLInferenceEngineParamsMixin], keep_empty=False),
+                params=params,
             )
             for instance in dataset
         ]
+
+    def _infer_log_probs(self, dataset):
+        model, params = self._load_model_and_params()
+
+        user_return_options = params.pop("return_options", {})
+        # currently this is the only configuration that returns generated logprobs and behaves as expected
+        logprobs_return_options = {
+            "input_tokens": True,
+            "generated_tokens": True,
+            "token_logprobs": True,
+            "top_n_tokens": user_return_options.get("top_n_tokens", 5),
+        }
+        for key, value in logprobs_return_options.items():
+            if key in user_return_options and user_return_options[key] != value:
+                raise ValueError(
+                    f"'{key}={user_return_options[key]}' is not supported for the 'infer_log_probs' "
+                    f"method of {self.__class__.__name__}. For obtaining the logprobs of generated tokens "
+                    f"please use '{key}={value}'."
+                )
+
+        params = {
+            **params,
+            "return_options": logprobs_return_options,
+        }
+
+        results = [
+            model.generate(
+                prompt=instance["source"],
+                params=params,
+            )["results"]
+            for instance in dataset
+        ]
+
+        return [result[0]["generated_tokens"] for result in results]
diff --git a/src/unitxt/metrics.py b/src/unitxt/metrics.py
index 3ac41f7cde..1f91ec8179 100644
--- a/src/unitxt/metrics.py
+++ b/src/unitxt/metrics.py
@@ -450,7 +450,7 @@ def metric(sample_refs, sample_preds, sample_task_data):
                     references=sample_refs,
                     predictions=sample_preds,
                     task_data=sample_task_data,
-                )["score"]
+                )[score_name]
             except Exception as e:
                 # this happens in edge cases, for example, when the sampling creates a
                 # sample where all strings are empty and this fails bleu.
@@ -560,18 +560,24 @@ def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generato if isinstance(self.main_score, str): instance_score[self.main_score] = no_score_value - instance["score"]["instance"].update( - self._add_score_prefixes_to_score_dict(instance_score) - ) + instance["score"]["instance"].update(instance_score) self._validate_references_and_prediction(references, predictions) result = self._compute(references, predictions, task_data) - global_score.update(self._add_score_prefixes_to_score_dict(result)) - score_name = global_score["score_name"] - confidence_interval = self.compute_global_confidence_intervals( - references, predictions, task_data, score_name - ) - global_score.update(confidence_interval) + global_score.update(result) + + if self.ci_scores: + score_names = [ + self._add_score_prefix(score_name) for score_name in self.ci_scores + ] + else: + score_names = [global_score["score_name"]] + + for score_name in score_names: + confidence_interval = self.compute_global_confidence_intervals( + references, predictions, task_data, score_name + ) + global_score.update(confidence_interval) for instance in instances: self.update_and_adjust_global_score(instance, global_score) @@ -586,7 +592,7 @@ def _compute( result = self.compute(references, predictions, task_data) result["score"] = result[self.main_score] result["score_name"] = self.main_score - return result + return self._add_score_prefixes_to_score_dict(result) @abstractmethod def compute( @@ -1734,6 +1740,7 @@ class F1Binary(GlobalMetric): metric = "f1" single_reference_per_prediction = True _requirements_list: List[str] = ["sklearn"] + ci_scores = [main_score, "f1_binary_neg"] def prepare(self): super().prepare() @@ -4261,6 +4268,7 @@ class BinaryMaxF1(F1Binary): main_score = "max_f1_binary" single_reference_per_prediction = True average = None + ci_scores = [main_score, "max_f1_binary_neg"] def compute( self, diff --git a/src/unitxt/processors.py b/src/unitxt/processors.py index 45de6bce21..f12b6bc138 100644 --- a/src/unitxt/processors.py +++ b/src/unitxt/processors.py @@ -4,6 +4,8 @@ from difflib import get_close_matches from typing import Any, Dict +import numpy as np + from .operators import FieldOperator, InstanceFieldOperator @@ -277,3 +279,59 @@ def process_value(self, text: Any) -> Any: except: return 0 + + +class InferDictsToBinaryLogprobs(FieldOperator): + neg_class_name: str + pos_class_name: str + + take_logprobs_from_end: bool = False + num_logprobs_to_take: int = 3 + min_probability_mass = 0.0001 + + def verify(self): + super().verify() + if ( + self.neg_class_name.lower() in self.pos_class_name.lower() + or self.pos_class_name.lower() in self.neg_class_name.lower() + ): + raise ValueError( + f"""Class names in {self.__class__.__name__} should not overlap, got "{self.pos_class_name}" and "{self.neg_class_name}""" + ) + + def process_value(self, obj: Any) -> Any: + for i in self.get_token_range(obj): + try: + pos_probs, neg_probs = self.get_pos_neg_probs(pred_dict=obj[i]) + if pos_probs or neg_probs: + sum_probs = sum(pos_probs) + sum(neg_probs) + if sum_probs > self.min_probability_mass: + return sum(pos_probs) / sum_probs + except: + pass + return np.nan + + def get_pos_neg_probs(self, pred_dict): + token_logprobs = pred_dict["top_tokens"] + + pos_and_neg_probs = [] + for class_name in [self.pos_class_name, self.neg_class_name]: + # We need to capture different variants of model behavior and tokenizers, for example with opening space, + # punctuation etc. 
but avoid longer words that contain the class name. + # For example, for class "yes" we would capture "YES," and " Yes" but not "yesterday". + name_regex = re.compile( + rf"(\W|Ġ|_)*{class_name}(\W|Ġ|_)*", flags=re.IGNORECASE + ) + class_probs = [ + np.exp(d["logprob"]) + for d in token_logprobs + if name_regex.fullmatch(d["text"]) + ] + pos_and_neg_probs.append(class_probs) + return pos_and_neg_probs + + def get_token_range(self, obj: Any) -> range: + n_tokens = min([self.num_logprobs_to_take, len(obj)]) + if self.take_logprobs_from_end: + return range(-1, -(n_tokens + 1), -1) + return range(n_tokens)
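Illustrative usage, not part of the patch: a minimal sketch of how the new InferDictsToBinaryLogprobs operator turns generated-token logprobs into a positive-class probability. The import path assumes the merged src/unitxt/processors.py, and the token structure below is a hand-made stand-in for what WMLInferenceEngine._infer_log_probs returns per instance (result[0]["generated_tokens"]); the numeric values are made up for the example.

from unitxt.processors import InferDictsToBinaryLogprobs

# One generated token with its top-n alternatives, mimicking the watsonx
# "generated_tokens" structure consumed by get_pos_neg_probs().
prediction = [
    {
        "text": "Yes",
        "logprob": -0.05,
        "top_tokens": [
            {"text": "Yes", "logprob": -0.05},  # counted toward the positive class
            {"text": " No", "logprob": -3.2},   # counted toward the negative class
            {"text": "yes,", "logprob": -4.0},  # spelling variant, still positive
        ],
    },
]

processor = InferDictsToBinaryLogprobs(
    neg_class_name="No",
    pos_class_name="Yes",
    num_logprobs_to_take=3,
    field="prediction",
    process_every_value=False,
)

# Positive mass exp(-0.05) + exp(-4.0) over the combined Yes/No mass -> ~0.96
print(processor.process_value(prediction))

The two catalog entries added above wrap this operator: processors.infer_logprobs_to_yes_no_probs scans the first generated tokens, while processors.infer_last_token_logprobs_to_yes_no_probs sets take_logprobs_from_end=True and scans from the last token backwards.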