From 97e5021ed635d48c41e45b6c5fc2d516366270c7 Mon Sep 17 00:00:00 2001
From: Gabriele Sarti
Date: Sun, 10 Nov 2024 08:11:49 +0100
Subject: [PATCH] Add `StringSplitAggregator` (#290)

---
 inseq/data/aggregator.py | 144 ++++++++++++++++++++++++++++++++++++++-
 pyproject.toml           |   3 +-
 requirements-dev.txt     |   6 +-
 3 files changed, 148 insertions(+), 5 deletions(-)

diff --git a/inseq/data/aggregator.py b/inseq/data/aggregator.py
index 95b1030..2e0301b 100644
--- a/inseq/data/aggregator.py
+++ b/inseq/data/aggregator.py
@@ -1,7 +1,9 @@
 import logging
+import re
 from abc import ABC, abstractmethod
 from collections.abc import Callable, Sequence
-from typing import TYPE_CHECKING, TypeVar
+from enum import Enum
+from typing import TYPE_CHECKING, Literal, TypeVar
 
 import torch
 
@@ -541,7 +543,7 @@ def validate_spans(cls, span_sequence: list[TokenWithId], spans: IndexSpan | Non
         prev_span_max = -1
         for span in spans:
             assert len(span) == 2, f"Spans must contain two indexes, got {spans}"
-            assert span[1] > span[0] + 1, f"Spans must be non-empty, got {spans}"
+            assert span[1] >= span[0] + 1, f"Spans must be non-empty, got {spans}"
             assert (
                 span[0] >= prev_span_max
             ), f"Spans must be postive-valued, non-overlapping and in ascending order, got {spans}"
@@ -722,6 +724,144 @@ def get_spans(tokens: list[TokenWithId], special_chars: str | tuple[str, ...], i
         return spans
 
 
+class StringSplitAggregator(ContiguousSpanAggregator):
+    """Aggregates contiguous tokens using specified strings as separators.
+
+    Args:
+        attr (:class:`~inseq.data.FeatureAttributionSequenceOutput`): The attribution object to aggregate.
+        aggregate_fn (:obj:`Callable`, optional): Function to aggregate over the subwords.
+            Defaults to the highest absolute value score across the aggregated span, with original sign
+            preserved (e.g. [0.3, -0.7, 0.1] -> -0.7).
+        aggregate_source (bool, optional): Whether to aggregate over the source sequence. Defaults to True.
+        aggregate_target (bool, optional): Whether to aggregate over the target sequence. Defaults to True.
+        split_pattern (str): Regular expression pattern used to split the sequences.
+        split_mode (str, optional): Treatment for split tokens. If "single", these are kept separate from previous and
+            following tokens. If "start", they are concatenated with following tokens. If "end", they are concatenated
+            to previous tokens. Defaults to "single".
+    """
+
+    aggregator_name = "split"
+
+    class SplitStrategy(Enum):
+        SINGLE = "single"
+        START = "start"
+        END = "end"
+
+    @classmethod
+    def aggregate(
+        cls,
+        attr: "FeatureAttributionSequenceOutput",
+        aggregate_source: bool = True,
+        aggregate_target: bool = True,
+        split_pattern: str = None,
+        split_mode: Literal["single", "start", "end"] = SplitStrategy.SINGLE.value,
+        **kwargs,
+    ):
+        source_spans = []
+        target_spans = []
+        if split_pattern is None:
+            raise ValueError("split_pattern is None. Provide a valid regular expression pattern to split the string.")
+        if aggregate_source:
+            source_spans = cls.get_spans(attr.source, split_pattern, split_mode)
+        if aggregate_target:
+            target_spans = cls.get_spans(attr.target, split_pattern, split_mode)
+        return super().aggregate(attr, source_spans=source_spans, target_spans=target_spans, **kwargs)
+
+    @classmethod
+    def get_spans(
+        cls,
+        tokens: list[TokenWithId],
+        split_pattern: str,
+        split_mode: Literal["single", "start", "end"] = SplitStrategy.SINGLE.value,
+    ) -> list[tuple[int, int]]:
+        full_text = "".join(t.token for t in tokens)
+        curr_idx = 0
+        token_spans = []
+
+        # Generate token spans
+        for tok in tokens:
+            token_spans.append((curr_idx, curr_idx + len(tok.token)))
+            curr_idx += len(tok.token)
+
+        # Find all matches for the given pattern
+        matches = list(re.finditer(split_pattern, full_text))
+        if not matches:
+            return []
+        matches_spans = [(m.start(), m.end()) for m in matches]
+
+        # Create matches_tokens list
+        matches_tokens = []
+        for start, end in matches_spans:
+            token_start = next((i for i, (ts, te) in enumerate(token_spans) if ts <= start < te), None)
+            token_end = next((i for i, (ts, te) in enumerate(token_spans) if ts < end <= te), None) + 1
+            if token_start is not None and token_end is not None:
+                matches_tokens.append((token_start, token_end))
+
+        # Remove duplicate spans
+        seen_tokens = set()
+        matches_tokens = [m for m in matches_tokens if not (m in seen_tokens or seen_tokens.add(m))]
+
+        # If overlapping token spans are found, split them
+        non_overlapping_matches = []
+        for curr_idx, (start, end) in enumerate(matches_tokens):
+            curr_start, curr_end = start, end
+            if len(matches_tokens) > curr_idx + 1 and end > matches_tokens[curr_idx + 1][0]:
+                curr_end = matches_tokens[curr_idx + 1][0]
+            if curr_idx > 0 and start < non_overlapping_matches[-1][1]:
+                curr_start = non_overlapping_matches[-1][1]
+            non_overlapping_matches.append((curr_start, curr_end))
+            if curr_end != end and end < matches_tokens[curr_idx + 1][1]:
+                non_overlapping_matches.append((curr_end, end))
+        matches_tokens = non_overlapping_matches
+
+        # Fill missing spans
+        aggregate_spans = []
+        matched_span = []
+        if matches_tokens[0][0] != 0:
+            aggregate_spans.append((0, matches_tokens[0][0]))
+            matched_span.append(False)
+        for i in range(len(matches_tokens) - 1):
+            aggregate_spans.append(matches_tokens[i])
+            matched_span.append(True)
+            if matches_tokens[i][1] != matches_tokens[i + 1][0]:
+                aggregate_spans.append((matches_tokens[i][1], matches_tokens[i + 1][0]))
+                matched_span.append(False)
+        aggregate_spans.append(matches_tokens[-1])
+        matched_span.append(True)
+        if matches_tokens[-1][1] != len(tokens):
+            aggregate_spans.append((matches_tokens[-1][1], len(tokens)))
+            matched_span.append(False)
+
+        # Create aggregate spans based on the split strategy
+        if split_mode == cls.SplitStrategy.SINGLE.value:
+            return aggregate_spans
+        elif split_mode in (cls.SplitStrategy.START.value, cls.SplitStrategy.END.value):
+            merge_aggregate_spans = []
+            curr_span_start = 0
+
+            # If the strategy is "start", all match spans are concatenated to their following non-match spans
+            # If the strategy is "end", all match spans are concatenated to their preceding non-match spans
+            # Example:
+            #   aggregate_spans = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 9)]
+            #   matched_span = [False, True, True, False, True, False]
+            #   Start strategy: [(0, 1), (1, 2), (2, 4), (4, 9)]
+            #   End strategy: [(0, 2), (2, 3), (3, 5), (5, 9)]
+            for (start, end), is_match in zip(aggregate_spans, matched_span, strict=False):
+                if is_match:
+                    if split_mode == cls.SplitStrategy.START.value and start != curr_span_start:
+                        merge_aggregate_spans.append((curr_span_start, start))
+                        curr_span_start = start
+                    elif split_mode == cls.SplitStrategy.END.value:
+                        merge_aggregate_spans.append((curr_span_start, end))
+                        curr_span_start = end
+            if curr_span_start != aggregate_spans[-1][1]:
+                merge_aggregate_spans.append((curr_span_start, aggregate_spans[-1][1]))
+
+            return merge_aggregate_spans
+        else:
+            raise ValueError("Invalid split strategy: must be one of 'single', 'start', 'end'")
+
+
 class PairAggregator(SequenceAttributionAggregator):
     """Aggregates two FeatureAttributionSequenceOutput object into a single one containing the diff.
 
diff --git a/pyproject.toml b/pyproject.toml
index 363c3c0..9279789 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -81,7 +81,8 @@ lint = [
     "pytest>=7.2.0",
     "pytest-cov>=4.0.0",
     "pytest-xdist>=3.5.0",
-    "ruff>=0.2.0"
+    "ruff>=0.2.0",
+    "virtualenv>=20.26.6"
 ]
 sklearn = [
     "scikit-learn>=1.5.1",
diff --git a/requirements-dev.txt b/requirements-dev.txt
index e826e32..5770ce7 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -426,8 +426,10 @@ urllib3==2.2.2
     # via
     #   requests
     #   safety
-virtualenv==20.25.0
-    # via pre-commit
+virtualenv==20.27.1
+    # via
+    #   inseq (pyproject.toml)
+    #   pre-commit
 wcwidth==0.2.13
     # via prompt-toolkit
 widgetsnbextension==4.0.10
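
A minimal usage sketch for the aggregator added above (illustrative, not taken from this patch): it assumes the standard `inseq.load_model` entry point and that the aggregator can be selected by its registered `aggregator_name` ("split"); the model, attribution method, input text, and pattern are placeholder choices.

    import inseq

    # Assumed setup: any inseq-supported model with a gradient-based method
    model = inseq.load_model("gpt2", "saliency")
    out = model.attribute("Hello world. How are you?")

    # Collapse token-level attributions into sentence-level spans: the sequence
    # is split on each "." match, and split_mode="end" attaches the separator
    # to the span that precedes it.
    agg = out.sequence_attributions[0].aggregate("split", split_pattern=r"\.", split_mode="end")
    agg.show()

With split_mode="start" each separator would instead be merged into the following span, while the default "single" keeps every match as its own span.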