From f50c3b3dfae80b6c26777528ae03e09592871ac3 Mon Sep 17 00:00:00 2001 From: alanakbik Date: Thu, 2 Jan 2025 15:41:43 +0100 Subject: [PATCH] Fix slicing such that left and right context are of equal length --- flair/models/relation_classifier_model.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index 7faca14783..dadc17c053 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -256,7 +256,7 @@ def __init__( encoding_strategy: EncodingStrategy = TypedEntityMarker(), zero_tag_value: str = "O", allow_unk_tag: bool = True, - max_allowed_tokens_between_entities: int = 50, + max_allowed_tokens_between_entities: int = 20, max_surrounding_context_length: int = 10, **classifierargs, ) -> None: @@ -456,6 +456,8 @@ def _encode_sentence( if abs(head_idx - tail_idx) > self._max_allowed_tokens_between_entities: return None + print(head_idx, tail_idx) + # remove excess tokens left and right of entity pair to make encoded sentence shorter encoded_sentence_tokens = self._slice_encoded_sentence_to_max_allowed_length( encoded_sentence_tokens, head_idx, tail_idx @@ -480,8 +482,8 @@ def _slice_encoded_sentence_to_max_allowed_length(self, encoded_sentence_tokens, padding_amount = self._max_surrounding_context_length begin_slice = begin_slice - padding_amount if begin_slice - padding_amount > 0 else 0 end_slice = ( - end_slice + padding_amount - if end_slice + padding_amount < len(encoded_sentence_tokens) + end_slice + padding_amount + 1 + if end_slice + padding_amount + 1 < len(encoded_sentence_tokens) else len(encoded_sentence_tokens) ) @@ -689,7 +691,9 @@ def predict( ) ) - sentences_with_relation_reference = [item for item in sentences_with_relation_reference if item[0] is not None] + sentences_with_relation_reference = [ + item for item in sentences_with_relation_reference if item[0] is not None + ] encoded_sentences = [x[0] for x in sentences_with_relation_reference] loss = super().predict(