diff --git a/caikit_nlp/modules/token_classification/filtered_span_classification.py b/caikit_nlp/modules/token_classification/filtered_span_classification.py
index 15ea432a..55733df1 100644
--- a/caikit_nlp/modules/token_classification/filtered_span_classification.py
+++ b/caikit_nlp/modules/token_classification/filtered_span_classification.py
@@ -136,8 +136,16 @@ def run(
         Returns:
             TokenClassificationResults
         """
+        error.type_check("", str, text=text)
+        error.type_check("", float, allow_none=True, threshold=threshold)
+
         if threshold is None:
             threshold = self.default_threshold
+        if not text:
+            # Allow empty text case to fall through - some tokenizers or
+            # classifiers may error on this
+            return TokenClassificationResults(results=[])
+
         token_classification_results = []
         if self.classification_task == TextClassificationTask:
             # Split document into spans
@@ -196,10 +204,17 @@ def run_bidi_stream(
         Returns:
             Iterable[TokenClassificationStreamResult]
         """
+        error.type_check("", float, allow_none=True, threshold=threshold)
         # TODO: For optimization implement window based approach.
         if threshold is None:
             threshold = self.default_threshold
 
+        # Types on the stream are checked later on iteration
+        if len(text_stream) == 0:
+            # Allow empty text case to fall through - some tokenizers or
+            # classifiers may error on this
+            yield TokenClassificationStreamResult(results=[], processed_index=0)
+
         for span_output in self._stream_span_output(text_stream):
             classification_result = self.classifier.run(span_output.text)
             results_to_end_of_span = False
@@ -344,6 +359,7 @@ def __update_spans(token):
             return token
 
         for text in text_stream:
+            error.type_check("", str, text=text)
             stream_accumulator += text
             # In order to avoid processing all of the spans again, we only
             # send out the spans that are not yet finalized in detected_spans
diff --git a/tests/modules/token_classification/test_filtered_span_classification.py b/tests/modules/token_classification/test_filtered_span_classification.py
index ce20c5cb..c8f14036 100644
--- a/tests/modules/token_classification/test_filtered_span_classification.py
+++ b/tests/modules/token_classification/test_filtered_span_classification.py
@@ -51,6 +51,14 @@
 )
 TOK_CLASSIFICATION_RESULT = TokenClassificationResults(results=[FOX_CLASS, DOG_CLASS])
 
+# NOTE: First test will test this separately
+BOOTSTRAPPED_MODEL = FilteredSpanClassification.bootstrap(
+    lang="en",
+    tokenizer=SENTENCE_TOKENIZER,
+    classifier=BOOTSTRAPPED_SEQ_CLASS_MODEL,
+    default_threshold=0.5,
+)
+
 # Modules that already returns token classification for tests
 @module(
     "44d61711-c64b-4774-a39f-a9f40f1fcff0",
@@ -120,13 +128,7 @@ def test_bootstrap_run():
 
 def test_bootstrap_run_with_threshold():
     """Check if we can bootstrap span classification models with overriden threshold"""
-    model = FilteredSpanClassification.bootstrap(
-        lang="en",
-        tokenizer=SENTENCE_TOKENIZER,
-        classifier=BOOTSTRAPPED_SEQ_CLASS_MODEL,
-        default_threshold=0.5,
-    )
-    token_classification_result = model.run(DOCUMENT, threshold=0.0)
+    token_classification_result = BOOTSTRAPPED_MODEL.run(DOCUMENT, threshold=0.0)
     assert isinstance(token_classification_result, TokenClassificationResults)
     assert (
         len(token_classification_result.results) == 4
@@ -187,16 +189,17 @@ def test_bootstrap_run_with_token_classification_no_results():
     assert len(token_classification_result.results) == 0
 
 
+def test_bootstrap_run_empty():
+    """Check if span classification model can run with empty string"""
+    token_classification_result = BOOTSTRAPPED_MODEL.run("")
+    assert isinstance(token_classification_result, TokenClassificationResults)
+    assert len(token_classification_result.results) == 0
+
+
 def test_save_load_and_run_model():
     """Check if we can run a saved model successfully"""
-    model = FilteredSpanClassification.bootstrap(
-        lang="en",
-        tokenizer=SENTENCE_TOKENIZER,
-        classifier=BOOTSTRAPPED_SEQ_CLASS_MODEL,
-        default_threshold=0.5,
-    )
     with tempfile.TemporaryDirectory() as model_dir:
-        model.save(model_dir)
+        BOOTSTRAPPED_MODEL.save(model_dir)
         assert os.path.exists(os.path.join(model_dir, "config.yml"))
         assert os.path.exists(os.path.join(model_dir, "tokenizer"))
         assert os.path.exists(os.path.join(model_dir, "classification"))
@@ -216,14 +219,9 @@
 def test_run_bidi_stream_model():
     """Check if model prediction works as expected for bi-directional stream"""
     stream_input = data_model.DataStream.from_iterable(DOCUMENT)
-    model = FilteredSpanClassification.bootstrap(
-        lang="en",
-        tokenizer=SENTENCE_TOKENIZER,
-        classifier=BOOTSTRAPPED_SEQ_CLASS_MODEL,
-        default_threshold=0.5,
+    streaming_token_classification_result = BOOTSTRAPPED_MODEL.run_bidi_stream(
+        stream_input
     )
-
-    streaming_token_classification_result = model.run_bidi_stream(stream_input)
     assert isinstance(streaming_token_classification_result, Iterable)
     # Convert to list to more easily check outputs
     result_list = list(streaming_token_classification_result)
@@ -351,14 +349,10 @@ def test_run_bidi_stream_with_multiple_spans_in_chunk():
     works as expected for bi-directional stream"""
     doc_stream = (DOCUMENT, " I am another sentence.")
     stream_input = data_model.DataStream.from_iterable(doc_stream)
-    model = FilteredSpanClassification.bootstrap(
-        lang="en",
-        tokenizer=SENTENCE_TOKENIZER,
-        classifier=BOOTSTRAPPED_SEQ_CLASS_MODEL,
-        default_threshold=0.5,
-    )
 
-    streaming_token_classification_result = model.run_bidi_stream(stream_input)
+    streaming_token_classification_result = BOOTSTRAPPED_MODEL.run_bidi_stream(
+        stream_input
+    )
     assert isinstance(streaming_token_classification_result, Iterable)
     # Convert to list to more easily check outputs
     result_list = list(streaming_token_classification_result)
@@ -385,6 +379,20 @@
     assert count == expected_number_of_sentences
 
 
+def test_run_bidi_stream_empty():
+    """Check if span classification model can run with empty string for streaming"""
+    stream_input = data_model.DataStream.from_iterable("")
+    streaming_token_classification_result = BOOTSTRAPPED_MODEL.run_bidi_stream(
+        stream_input
+    )
+    assert isinstance(streaming_token_classification_result, Iterable)
+    # Convert to list to more easily check outputs
+    result_list = list(streaming_token_classification_result)
+    assert len(result_list) == 1
+    assert result_list[0].results == []
+    assert result_list[0].processed_index == 0
+
+
 def test_run_stream_vs_no_stream():
     """Check if model prediction on stream with multiple sentences/spans
     works as expected for bi-directional stream and gives expected span results
@@ -392,15 +400,9 @@
     multiple_sentences = (
         "The dragon hoarded gold. The cow ate grass. What is happening? What a day!"
     )
-    model = FilteredSpanClassification.bootstrap(
-        lang="en",
-        tokenizer=SENTENCE_TOKENIZER,
-        classifier=BOOTSTRAPPED_SEQ_CLASS_MODEL,
-        default_threshold=0.5,
-    )
 
     # Non-stream run
-    nonstream_classification_result = model.run(multiple_sentences)
+    nonstream_classification_result = BOOTSTRAPPED_MODEL.run(multiple_sentences)
     assert len(nonstream_classification_result.results) == 4
     assert nonstream_classification_result.results[0].word == "The dragon hoarded gold."
     assert nonstream_classification_result.results[0].start == 0
@@ -411,7 +413,7 @@
 
     # Char-based stream
     stream_input = data_model.DataStream.from_iterable(multiple_sentences)
-    stream_classification_result = model.run_bidi_stream(stream_input)
+    stream_classification_result = BOOTSTRAPPED_MODEL.run_bidi_stream(stream_input)
     # Convert to list to more easily check outputs
     result_list = list(stream_classification_result)
     assert len(result_list) == 4  # one per sentence
@@ -422,7 +424,9 @@
 
     # Chunk-based stream
     chunk_stream_input = data_model.DataStream.from_iterable((multiple_sentences,))
-    chunk_stream_classification_result = model.run_bidi_stream(chunk_stream_input)
+    chunk_stream_classification_result = BOOTSTRAPPED_MODEL.run_bidi_stream(
+        chunk_stream_input
+    )
     result_list = list(chunk_stream_classification_result)
     assert len(result_list) == 4  # one per sentence
     assert result_list[0].processed_index == 24
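
Taken together, the new guards short-circuit empty input before any tokenizer or classifier call. A minimal caller-side sketch of the intended behavior follows; it is illustrative only, not part of the patch, and reuses BOOTSTRAPPED_MODEL, data_model, and TokenClassificationResults from the test module above:

    # Sketch only: assumes the fixtures and imports of the test module above.

    # Non-streaming: run("") returns an empty result set rather than
    # passing empty text to the tokenizer or classifier.
    result = BOOTSTRAPPED_MODEL.run("")
    assert isinstance(result, TokenClassificationResults)
    assert result.results == []

    # Streaming: an empty stream yields exactly one frame with no results
    # and processed_index == 0, as asserted in test_run_bidi_stream_empty.
    frames = list(
        BOOTSTRAPPED_MODEL.run_bidi_stream(data_model.DataStream.from_iterable(""))
    )
    assert len(frames) == 1
    assert frames[0].results == [] and frames[0].processed_index == 0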