🥅 Handle empty text for filtered span classification #304

Merged · 3 commits · Jan 17, 2024
Changes from all commits
(First changed file: the FilteredSpanClassification module implementation)
@@ -136,8 +136,16 @@ def run(
         Returns:
             TokenClassificationResults
         """
+        error.type_check("<NLP82129006E>", str, text=text)
         error.type_check("<NLP01414077E>", float, allow_none=True, threshold=threshold)
+
         if threshold is None:
             threshold = self.default_threshold
+        if not text:
+            # Allow empty text case to fall through - some tokenizers or
+            # classifiers may error on this
+            return TokenClassificationResults(results=[])
+
         token_classification_results = []
         if self.classification_task == TextClassificationTask:
             # Split document into spans
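
For reference, the caller-visible effect of the new guard in run() (a minimal sketch; `model` stands in for any bootstrapped FilteredSpanClassification instance, mirroring the new test_bootstrap_run_empty below):

    # Empty input now short-circuits before any tokenizer or classifier call
    result = model.run("")
    assert isinstance(result, TokenClassificationResults)
    assert result.results == []  # empty result set instead of a downstream error
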
@@ -196,10 +204,17 @@ def run_bidi_stream(
         Returns:
             Iterable[TokenClassificationStreamResult]
         """
         error.type_check("<NLP96166348E>", float, allow_none=True, threshold=threshold)
         # TODO: For optimization implement window based approach.
         if threshold is None:
             threshold = self.default_threshold
+
+        # Types on the stream are checked later on iteration
+        if len(text_stream) == 0:
+            # Allow empty text case to fall through - some tokenizers or
+            # classifiers may error on this
+            yield TokenClassificationStreamResult(results=[], processed_index=0)
+
         for span_output in self._stream_span_output(text_stream):
             classification_result = self.classifier.run(span_output.text)
             results_to_end_of_span = False
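
The streaming path gets the analogous guard: an empty stream yields a single placeholder result rather than erroring downstream. A sketch of the caller-side behavior, mirroring the new test_run_bidi_stream_empty (`model` is again a stand-in):

    stream_input = data_model.DataStream.from_iterable("")
    result_list = list(model.run_bidi_stream(stream_input))
    assert len(result_list) == 1
    assert result_list[0].results == []          # no classifications
    assert result_list[0].processed_index == 0   # nothing was consumed

Since the guard yields instead of returning, the generator falls through to the span loop afterwards; for an empty stream that loop produces nothing more, so exactly one result is emitted.
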
@@ -344,6 +359,7 @@ def __update_spans(token):
             return token

         for text in text_stream:
+            error.type_check("<NLP38357927E>", str, text=text)
             stream_accumulator += text
             # In order to avoid processing all of the spans again, we only
             # send out the spans that are not yet finalized in detected_spans
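
With stream-wide validation dropped in favor of this per-chunk check, a bad chunk should surface only when iteration reaches it (a sketch under that assumption; `bad_stream` is illustrative, and the exact exception comes from caikit's error.type_check):

    bad_stream = data_model.DataStream.from_iterable(["one sentence. ", 42])
    results = model.run_bidi_stream(bad_stream)  # no error yet - generators are lazy
    list(results)  # raises from error.type_check once the non-string chunk is consumed
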

(Second changed file: the unit tests for FilteredSpanClassification)

@@ -51,6 +51,14 @@
 )
 TOK_CLASSIFICATION_RESULT = TokenClassificationResults(results=[FOX_CLASS, DOG_CLASS])

+# NOTE: First test will test this separately
+BOOTSTRAPPED_MODEL = FilteredSpanClassification.bootstrap(
+    lang="en",
+    tokenizer=SENTENCE_TOKENIZER,
+    classifier=BOOTSTRAPPED_SEQ_CLASS_MODEL,
+    default_threshold=0.5,
+)
+
 # Modules that already returns token classification for tests
 @module(
     "44d61711-c64b-4774-a39f-a9f40f1fcff0",
@@ -120,13 +128,7 @@ def test_bootstrap_run():

 def test_bootstrap_run_with_threshold():
     """Check if we can bootstrap span classification models with overriden threshold"""
-    model = FilteredSpanClassification.bootstrap(
-        lang="en",
-        tokenizer=SENTENCE_TOKENIZER,
-        classifier=BOOTSTRAPPED_SEQ_CLASS_MODEL,
-        default_threshold=0.5,
-    )
-    token_classification_result = model.run(DOCUMENT, threshold=0.0)
+    token_classification_result = BOOTSTRAPPED_MODEL.run(DOCUMENT, threshold=0.0)
     assert isinstance(token_classification_result, TokenClassificationResults)
     assert (
         len(token_classification_result.results) == 4
@@ -187,16 +189,17 @@ def test_bootstrap_run_with_token_classification_no_results():
     assert len(token_classification_result.results) == 0


+def test_bootstrap_run_empty():
+    """Check if span classification model can run with empty string"""
+    token_classification_result = BOOTSTRAPPED_MODEL.run("")
+    assert isinstance(token_classification_result, TokenClassificationResults)
+    assert len(token_classification_result.results) == 0
+
+
 def test_save_load_and_run_model():
     """Check if we can run a saved model successfully"""
-    model = FilteredSpanClassification.bootstrap(
-        lang="en",
-        tokenizer=SENTENCE_TOKENIZER,
-        classifier=BOOTSTRAPPED_SEQ_CLASS_MODEL,
-        default_threshold=0.5,
-    )
     with tempfile.TemporaryDirectory() as model_dir:
-        model.save(model_dir)
+        BOOTSTRAPPED_MODEL.save(model_dir)
         assert os.path.exists(os.path.join(model_dir, "config.yml"))
         assert os.path.exists(os.path.join(model_dir, "tokenizer"))
         assert os.path.exists(os.path.join(model_dir, "classification"))
@@ -216,14 +219,9 @@ def test_run_bidi_stream_model():
     """Check if model prediction works as expected for bi-directional stream"""

     stream_input = data_model.DataStream.from_iterable(DOCUMENT)
-    model = FilteredSpanClassification.bootstrap(
-        lang="en",
-        tokenizer=SENTENCE_TOKENIZER,
-        classifier=BOOTSTRAPPED_SEQ_CLASS_MODEL,
-        default_threshold=0.5,
-    )
-
-    streaming_token_classification_result = model.run_bidi_stream(stream_input)
+    streaming_token_classification_result = BOOTSTRAPPED_MODEL.run_bidi_stream(
+        stream_input
+    )
     assert isinstance(streaming_token_classification_result, Iterable)
     # Convert to list to more easily check outputs
     result_list = list(streaming_token_classification_result)
@@ -351,14 +349,10 @@ def test_run_bidi_stream_with_multiple_spans_in_chunk():
     works as expected for bi-directional stream"""
     doc_stream = (DOCUMENT, " I am another sentence.")
     stream_input = data_model.DataStream.from_iterable(doc_stream)
-    model = FilteredSpanClassification.bootstrap(
-        lang="en",
-        tokenizer=SENTENCE_TOKENIZER,
-        classifier=BOOTSTRAPPED_SEQ_CLASS_MODEL,
-        default_threshold=0.5,
-    )

-    streaming_token_classification_result = model.run_bidi_stream(stream_input)
+    streaming_token_classification_result = BOOTSTRAPPED_MODEL.run_bidi_stream(
+        stream_input
+    )
     assert isinstance(streaming_token_classification_result, Iterable)
     # Convert to list to more easily check outputs
     result_list = list(streaming_token_classification_result)
@@ -385,22 +379,30 @@ def test_run_bidi_stream_with_multiple_spans_in_chunk():
     assert count == expected_number_of_sentences


+def test_run_bidi_stream_empty():
+    """Check if span classification model can run with empty string for streaming"""
+    stream_input = data_model.DataStream.from_iterable("")
+    streaming_token_classification_result = BOOTSTRAPPED_MODEL.run_bidi_stream(
+        stream_input
+    )
+    assert isinstance(streaming_token_classification_result, Iterable)
+    # Convert to list to more easily check outputs
+    result_list = list(streaming_token_classification_result)
+    assert len(result_list) == 1
+    assert result_list[0].results == []
+    assert result_list[0].processed_index == 0
+
+
 def test_run_stream_vs_no_stream():
     """Check if model prediction on stream with multiple sentences/spans
     works as expected for bi-directional stream and gives expected span results
     as non-stream"""
     multiple_sentences = (
         "The dragon hoarded gold. The cow ate grass. What is happening? What a day!"
     )
-    model = FilteredSpanClassification.bootstrap(
-        lang="en",
-        tokenizer=SENTENCE_TOKENIZER,
-        classifier=BOOTSTRAPPED_SEQ_CLASS_MODEL,
-        default_threshold=0.5,
-    )
-
     # Non-stream run
-    nonstream_classification_result = model.run(multiple_sentences)
+    nonstream_classification_result = BOOTSTRAPPED_MODEL.run(multiple_sentences)
     assert len(nonstream_classification_result.results) == 4
     assert nonstream_classification_result.results[0].word == "The dragon hoarded gold."
     assert nonstream_classification_result.results[0].start == 0
@@ -411,7 +413,7 @@ def test_run_stream_vs_no_stream():

     # Char-based stream
     stream_input = data_model.DataStream.from_iterable(multiple_sentences)
-    stream_classification_result = model.run_bidi_stream(stream_input)
+    stream_classification_result = BOOTSTRAPPED_MODEL.run_bidi_stream(stream_input)
     # Convert to list to more easily check outputs
     result_list = list(stream_classification_result)
     assert len(result_list) == 4  # one per sentence
@@ -422,7 +424,9 @@ def test_run_stream_vs_no_stream():

     # Chunk-based stream
     chunk_stream_input = data_model.DataStream.from_iterable((multiple_sentences,))
-    chunk_stream_classification_result = model.run_bidi_stream(chunk_stream_input)
+    chunk_stream_classification_result = BOOTSTRAPPED_MODEL.run_bidi_stream(
+        chunk_stream_input
+    )
     result_list = list(chunk_stream_classification_result)
     assert len(result_list) == 4  # one per sentence
     assert result_list[0].processed_index == 24