Skip to content

Commit

Permalink
Update error checking and documentation of processors (#1325)
Browse files Browse the repository at this point in the history
* Improved error messages and doc of post processors

Signed-off-by: Yoav Katz <[email protected]>

* improved doc

Signed-off-by: Yoav Katz <[email protected]>

---------

Signed-off-by: Yoav Katz <[email protected]>
Co-authored-by: Elron Bandel <[email protected]>
  • Loading branch information
yoavkatz and elronbandel authored Nov 3, 2024
1 parent d24dd4c commit 582d96f
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 6 deletions.
30 changes: 27 additions & 3 deletions docs/docs/adding_template.rst
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,12 @@ Post Processors
---------------

The template also defines the post processing steps applied to the output predictions of the model before they are passed to the :ref:`Metrics <metric>`.
The post processors applied both to the model prediction and to the references.
Typically, the post processors applied both to the model prediction and to the references.
For example, we could use the ``processors.lower_case`` processor to lowercase both the model predictions and references,
so the metric computation will ignore case. When needed, It is possible to add post processors that applied only to the output of the model and not the references or vice versa.

so the metric computation will ignore case.
.. code-block:: python
from unitxt.templates import InputOutputTemplate
template = InputOutputTemplate(
instruction="In the following task, you translate a {text_type}.",
input_format="Translate this {text_type} from {source_language} to {target_language}: {text}.",
Expand All @@ -99,6 +99,30 @@ The reason the post processors are set in the template, is because different tem
For example, one template may prompt the model to answer ``Yes`` or ``No`` while another
template may prompt the model to answer ``True`` or ``False``. Both can use different post processors to convert them to standard model prediction of `0` or `1`.

Post processors implemented as operators. Usually they are implemented as fields operators that are applied to the ``prediction``
and ``references``` fields. When needed, It is possible to add post processors that are applied only to the prediction of the model and not the references or vice versa.
Here we see how we can lowercase only the model prediction.

.. code-block:: python
from unitxt.processors import PostProcess
from unitxt.operators import FieldOperator
class Lower(FieldOperator):
def process_value(self, text: Any) -> Any:
return text.lower()
from unitxt.templates import InputOutputTemplate
template = InputOutputTemplate(
instruction="In the following task, you translate a {text_type}.",
input_format="Translate this {text_type} from {source_language} to {target_language}: {text}.",
target_prefix="Translation: ",
output_format='{translation}',
postprocessors= [
PostProcess(Lower(),process_references=False)
]
)
You can see all the available predefined post processors in the catalog (:ref:`Processor <processors>`.)

Templates for Special Cases
Expand Down
1 change: 1 addition & 0 deletions src/unitxt/error_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class Documentation:
HUGGINGFACE_METRICS = "docs/adding_metric.html#adding-a-hugginface-metric"
ADDING_TASK = "docs/adding_task.html"
ADDING_TEMPLATE = "docs/adding_template.html"
POST_PROCESSORS = "docs/adding_template.html#post-processors"
MULTIPLE_METRICS_OUTPUTS = (
"docs/adding_metric.html#metric-outputs-with-multiple-metrics"
)
Expand Down
7 changes: 7 additions & 0 deletions src/unitxt/processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
import numpy as np

from .deprecation_utils import deprecation
from .error_utils import Documentation, UnitxtError
from .operator import MultiStreamOperator
from .operators import FieldOperator, InstanceFieldOperator
from .settings_utils import get_constants
from .type_utils import isoftype

constants = get_constants()

Expand All @@ -23,6 +25,11 @@ class PostProcess(MultiStreamOperator):

def prepare(self):
super().prepare()
if not isoftype(self.operator, InstanceFieldOperator):
raise UnitxtError(
f"PostProcess requires operator field to be of type InstanceFieldOperator. Got object of type <{type(self.operator).__name__}>.",
Documentation.POST_PROCESSORS,
)
self.prediction_operator = copy.copy(self.operator)
self.prediction_operator.field = "prediction"
self.references_operator = copy.copy(self.operator)
Expand Down
4 changes: 3 additions & 1 deletion src/unitxt/split_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,9 @@ def rename_split(input_streams: Dict[str, Stream], mapping: Dict[str, str]):
new_streams = {}
for key, val in mapping.items():
if key not in input_streams:
raise ValueError("Wrong stream name")
raise ValueError(
f"Stream '{key}' is not in input_streams '{input_streams.keys()}'"
)
new_streams[val] = input_streams.pop(key)
return {**input_streams, **new_streams}

Expand Down
10 changes: 8 additions & 2 deletions src/unitxt/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .dataclass import NonPositionalField
from .dict_utils import dict_set
from .error_utils import Documentation, UnitxtError
from .operator import InstanceOperator
from .operator import InstanceOperator, Operator
from .random_utils import new_random_generator
from .serializers import (
DialogSerializer,
Expand All @@ -21,7 +21,7 @@
VideoSerializer,
)
from .settings_utils import get_constants
from .type_utils import isoftype
from .type_utils import isoftype, to_type_string

constants = get_constants()

Expand Down Expand Up @@ -68,6 +68,12 @@ class Template(InstanceOperator):
)
)

def verify(self):
super().verify()
assert isoftype(
self.postprocessors, List[Union[Operator, str]]
), f"The template post processors field '{self.postprocessors}' is not a list of processors. Instead it is of type '{to_type_string(type(self.postprocessors))}'."

def input_fields_to_instruction_and_target_prefix(self, input_fields):
instruction = self.apply_formatting(
input_fields, "input field", self.instruction, "instruction"
Expand Down

0 comments on commit 582d96f

Please sign in to comment.