Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: DIA-1523: test coverage for LabelStudioSkill #247

Merged
merged 13 commits into from
Nov 7, 2024
27 changes: 27 additions & 0 deletions adala/skills/collection/entity_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,33 @@
logger = logging.getLogger(__name__)


def validate_output_format_for_ner_tag(df: InternalDataFrame, input_field_name: str, output_field_name: str):
'''
The output format for Labels is:
{
"start": start_idx,
"end": end_idx,
"text": text,
"labels": [label1, label2, ...]
}
Sometimes the model cannot populate "text" correctly, but this can be fixed deterministically.
'''
for i, row in df.iterrows():
if row.get("_adala_error"):
logger.warning(f"Error in row {i}: {row['_adala_message']}")
continue
text = row[input_field_name]
entities = row[output_field_name]
for entity in entities:
corrected_text = text[entity["start"]:entity["end"]]
if entity.get("text") is None:
entity["text"] = corrected_text
elif entity["text"] != corrected_text:
# this seems to happen rarely if at all in testing, but could lead to invalid predictions
logger.warning(f"text and indices disagree for a predicted entity")
return df


def extract_indices(
df,
input_field_name,
Expand Down
18 changes: 10 additions & 8 deletions adala/skills/collection/label_studio.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import pandas as pd
from typing import Optional, Type
from typing import Type, Iterator
from functools import cached_property
from adala.skills._base import TransformSkill
from pydantic import BaseModel, Field, model_validator
Expand All @@ -12,7 +12,7 @@
from label_studio_sdk.label_interface.control_tags import ControlTag
from label_studio_sdk._extensions.label_studio_tools.core.utils.json_schema import json_schema_to_pydantic

from .entity_extraction import extract_indices
from .entity_extraction import extract_indices, validate_output_format_for_ner_tag

logger = logging.getLogger(__name__)

Expand All @@ -29,13 +29,14 @@ class LabelStudioSkill(TransformSkill):

# TODO: implement postprocessing to verify Taxonomy

def has_ner_tag(self) -> Optional[ControlTag]:
def ner_tags(self) -> Iterator[ControlTag]:
# check if the input config has NER tag (<Labels> + <Text>), and return its `from_name` and `to_name`
interface = LabelInterface(self.label_config)
for tag in interface.controls:
#TODO: don't need to check object tag because at this point, unusable control tags should have been stripped out of the label config, but confirm this - maybe move this logic to LSE
if tag.tag == 'Labels':
return tag

yield tag
@model_validator(mode='after')
def validate_response_model(self):

Expand Down Expand Up @@ -79,10 +80,11 @@ async def aapply(
instructions_template=self.instructions,
response_model=ResponseModel,
)
ner_tag = self.has_ner_tag()
if ner_tag:
for ner_tag in self.ner_tags():
input_field_name = ner_tag.objects[0].value.lstrip('$')
output_field_name = ner_tag.name
quote_string_field_name = 'text'
output = extract_indices(pd.concat([input, output], axis=1), input_field_name, output_field_name, quote_string_field_name)
df = pd.concat([input, output], axis=1)
output = validate_output_format_for_ner_tag(df, input_field_name, output_field_name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice, now we can take out this additional call to validate_output_format_for_ner_tag, but other than that lgtm 👍

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still need it, I added the call to EntityExtraction.extract_indices not the standalone extract_indices

output = extract_indices(output, input_field_name, output_field_name, quote_string_field_name)
return output
Loading
Loading