diff --git a/genparse/canonical_tokenization/berglund.py b/genparse/canonical_tokenization/berglund.py
new file mode 100644
index 00000000..cb1adf90
--- /dev/null
+++ b/genparse/canonical_tokenization/berglund.py
@@ -0,0 +1,98 @@
+import collections
+import dataclasses
+import itertools
+from collections.abc import Iterable
+
+State = int
+Symbol = int
+MergeRule = tuple[Symbol, Symbol]
+
+
+@dataclasses.dataclass
+class DFA:
+    num_states: int
+    alphabet_size: int
+    transitions: dict[State, collections.OrderedDict[Symbol, State]]
+    # Note that in this construction, all states are accept states, so there
+    # is no explicit set of accept states.
+
+    def new_states(self, n: int) -> range:
+        lo = self.num_states
+        self.num_states += n
+        return range(lo, self.num_states)
+
+    def new_symbol(self) -> Symbol:
+        a = self.alphabet_size
+        self.alphabet_size += 1
+        return a
+
+    def get_state_to(self, state_from: State, symbol: Symbol) -> State | None:
+        d = self.transitions.get(state_from)
+        if d is not None:
+            return d.get(symbol)
+
+    def get_transitions_from_state(self, state_from: State) -> Iterable[tuple[Symbol, State]]:
+        d = self.transitions.get(state_from)
+        if d is not None:
+            return d.items()
+        else:
+            return ()
+
+    def set_transition(self, state_from: State, symbol: Symbol, state_to: State) -> None:
+        if state_from not in self.transitions:
+            self.transitions[state_from] = collections.OrderedDict()
+        self.transitions[state_from][symbol] = state_to
+
+
+def construct_base_token_dfa(alphabet_size: int) -> DFA:
+    return DFA(
+        num_states=1,
+        alphabet_size=alphabet_size,
+        transitions={0: collections.OrderedDict.fromkeys(range(alphabet_size), 0)},
+    )
+
+
+def merge_rule_into_token_dfa(dfa: DFA, rule: MergeRule) -> None:
+    # This implements Algorithm 2 from https://arxiv.org/pdf/2405.07671
+    u, v = rule
+    uv = dfa.new_symbol()
+    # Use OrderedDict to ensure deterministic iteration order.
+    S2 = collections.OrderedDict()
+    for s1 in range(dfa.num_states):
+        s2 = dfa.get_state_to(s1, u)
+        if s2 is not None:
+            s3 = dfa.get_state_to(s2, v)
+            if s3 is not None:
+                dfa.set_transition(s1, uv, s3)
+                S2[s2] = True
+    fresh = dfa.new_states(len(S2))  # one fresh copy of each state reached on u
+    excluded = [v]
+    if u == v:
+        excluded.append(uv)
+    for s2, fresh_s2 in zip(S2, fresh):
+        for alpha, state_to in dfa.get_transitions_from_state(s2):
+            if alpha not in excluded:
+                dfa.set_transition(fresh_s2, alpha, state_to)
+    state_to_fresh = dict(zip(S2, fresh))
+    for q in range(dfa.num_states):
+        state_to = dfa.get_state_to(q, u)
+        if state_to is not None:
+            fresh_state_to = state_to_fresh.get(state_to)
+            if fresh_state_to is not None:
+                dfa.set_transition(q, u, fresh_state_to)
+
+
+def construct_token_dfa(alphabet_size: int, dictionary: Iterable[MergeRule]) -> DFA:
+    dfa = construct_base_token_dfa(alphabet_size)
+    for rule in dictionary:
+        merge_rule_into_token_dfa(dfa, rule)
+    # TODO Add trimming? Find all states not reachable from start state.
+    # This can probably be done during construction without needing to rescan
+    # the automaton from scratch every time.
+    return dfa
+
+
+def get_int_mapping(
+    alphabet: Iterable[str], dictionary: Iterable[tuple[str, str]]
+) -> Iterable[str]:
+    return itertools.chain(alphabet, (u + v for u, v in dictionary))
diff --git a/tests/test_berglund_bpe_dfa.py b/tests/test_berglund_bpe_dfa.py
new file mode 100644
index 00000000..f4224ffd
--- /dev/null
+++ b/tests/test_berglund_bpe_dfa.py
@@ -0,0 +1,89 @@
+from genparse.canonical_tokenization.berglund import (
+    DFA,
+    construct_base_token_dfa,
+    merge_rule_into_token_dfa,
+    get_int_mapping,
+)
+
+
+def are_isomorphic(dfa1, dfa2):
+    if not (
+        dfa1.num_states == dfa2.num_states and dfa1.alphabet_size == dfa2.alphabet_size
+    ):
+        return False
+    agenda = [0]
+    state_mapping = {0: 0}
+    while agenda:
+        q1 = agenda.pop()
+        for a in range(dfa1.alphabet_size):
+            r1 = dfa1.get_state_to(q1, a)
+            r2 = dfa2.get_state_to(state_mapping[q1], a)
+            if r1 is None:
+                if r2 is not None:
+                    return False
+            else:
+                if r2 is None:
+                    return False
+                else:
+                    if r1 in state_mapping:
+                        if r2 != state_mapping[r1]:
+                            return False
+                    else:
+                        state_mapping[r1] = r2
+                        agenda.append(r1)
+    return True
+
+
+def test_example_2():
+    # This tests Example 2 of https://arxiv.org/abs/2405.07671
+
+    alphabet = ['a', 'b']
+    dictionary = [('a', 'a'), ('b', 'a')]
+
+    int_to_str = get_int_mapping(alphabet, dictionary)
+    str_to_int = {s: i for i, s in enumerate(int_to_str)}
+    dictionary_as_int = [(str_to_int[u], str_to_int[v]) for u, v in dictionary]
+
+    def construct_dfa(num_states, alphabet_size, transitions):
+        M = DFA(num_states=num_states, alphabet_size=alphabet_size, transitions={})
+        for q, a, r in transitions:
+            M.set_transition(q, str_to_int[a], r)
+        return M
+
+    M = construct_base_token_dfa(len(alphabet))
+    N = construct_dfa(
+        num_states=1, alphabet_size=2, transitions=[(0, a, 0) for a in alphabet]
+    )
+    assert are_isomorphic(M, N)
+
+    merge_rule_into_token_dfa(M, dictionary_as_int[0])
+    N = construct_dfa(
+        num_states=2,
+        alphabet_size=3,
+        transitions=[(0, 'aa', 0), (0, 'b', 0), (0, 'a', 1), (1, 'b', 0)],
+    )
+    assert are_isomorphic(M, N)
+
+    merge_rule_into_token_dfa(M, dictionary_as_int[1])
+    N = construct_dfa(
+        num_states=3,
+        alphabet_size=4,
+        transitions=[
+            (0, 'aa', 0),
+            (0, 'b', 2),
+            (0, 'a', 1),
+            (0, 'ba', 1),
+            (1, 'ba', 1),
+            (1, 'b', 2),
+            (2, 'b', 2),
+            (2, 'aa', 0),
+            (2, 'ba', 1),
+        ],
+    )
+    assert are_isomorphic(M, N)
+
+
+if __name__ == '__main__':
+    from arsenal import testing_framework
+
+    testing_framework(globals())