Merge remote-tracking branch 'refs/remotes/origin/main'
Showing 5 changed files with 343 additions and 92 deletions.
@@ -1,98 +1,91 @@
 import collections
 import dataclasses
 import itertools
 from collections.abc import Iterable, Sequence


 State = int
 Symbol = int
-MergeRule = tuple[Symbol, Symbol]
+# A merge rule consists of (u, v, uv), where uv is the index of the symbol for
+# the concatenation of u and v.
+MergeRule = tuple[Symbol, Symbol, Symbol]


 @dataclasses.dataclass
-class DFA:
-    num_states: int
-    alphabet_size: int
-    transitions: dict[State, collections.OrderedDict[Symbol, State]]
+class TokenDFA:
+    # The start state is always implicitly state 0.
+    # Note that in this construction, all states are accept states, so there
+    # is no explicit set of accept states.
+    num_states: int
+    transitions: dict[State, dict[Symbol, State]]

+    @staticmethod
+    def from_dictionary(
+        base_alphabet: Iterable[Symbol], dictionary: Iterable[MergeRule]
+    ) -> 'TokenDFA':
+        dfa = TokenDFA.from_base_alphabet(base_alphabet)
+        for rule in dictionary:
+            dfa.merge_rule(rule)
+        # TODO Add trimming? Find all states not reachable from start state.
+        # This can probably be done during construction without needing to rescan
+        # the automaton from scratch every time.
+        return dfa
+
+    @staticmethod
+    def from_base_alphabet(base_alphabet: Iterable[Symbol]) -> 'TokenDFA':
+        return TokenDFA(num_states=1, transitions={0: {a: 0 for a in base_alphabet}})
+
     def new_states(self, n: int) -> range:
         lo = self.num_states
         self.num_states += n
         return range(lo, self.num_states)

-    def new_symbol(self) -> Symbol:
-        a = self.alphabet_size
-        self.alphabet_size += 1
-        return a
-
     def get_state_to(self, state_from: State, symbol: Symbol) -> State | None:
         d = self.transitions.get(state_from)
         if d is not None:
             return d.get(symbol)

-    def get_transitions_from_state(self, state_from: State) -> Iterable[Symbol, State]:
+    def get_transitions_from_state(
+        self, state_from: State
+    ) -> Iterable[tuple[Symbol, State]]:
         d = self.transitions.get(state_from)
         if d is not None:
             return d.items()
         else:
             return ()

     def get_transitions(self) -> Iterable[tuple[State, Symbol, State]]:
         for state_from, transitions_from_state in self.transitions.items():
             for symbol, state_to in transitions_from_state.items():
                 yield state_from, symbol, state_to

     def set_transition(self, state_from: State, symbol: Symbol, state_to: State) -> None:
         if state_from not in self.transitions:
-            self.transitions[state_from] = collections.OrderedDict()
+            self.transitions[state_from] = {}
         self.transitions[state_from][symbol] = state_to

-
-def construct_base_token_dfa(alphabet_size: int) -> DFA:
-    return DFA(
-        num_states=1,
-        alphabet_size=alphabet_size,
-        transitions={0: collections.OrderedDict.fromkeys(range(alphabet_size), 0)},
-    )
-
-
-def merge_rule_into_token_dfa(dfa: DFA, rule: MergeRule) -> None:
-    # This implements Algorithm 2 from https://arxiv.org/pdf/2405.07671
-    u, v = rule
-    uv = dfa.new_symbol()
-    # Use OrderedDict to ensure deterministic iteration order.
-    S2 = collections.OrderedDict()
-    for s1 in range(dfa.num_states):
-        s2 = dfa.get_state_to(s1, u)
-        if s2 is not None:
-            s3 = dfa.get_state_to(s2, v)
-            if s3 is not None:
-                dfa.set_transition(s1, uv, s3)
-                S2[s2] = True
-    fresh = dfa.new_states(len(S2))
-    excluded = [v]
-    if u == v:
-        excluded.append(uv)
-    for s2, fresh_s2 in zip(S2, fresh):
-        for alpha, state_to in dfa.get_transitions_from_state(s2):
-            if alpha not in excluded:
-                dfa.set_transition(fresh_s2, alpha, state_to)
-    state_to_fresh = dict(zip(S2, fresh))
-    for q in range(dfa.num_states):
-        state_to = dfa.get_state_to(q, u)
-        if state_to is not None:
-            fresh_state_to = state_to_fresh.get(state_to)
-            if fresh_state_to is not None:
-                dfa.set_transition(q, u, fresh_state_to)
-
-
-def construct_token_dfa(alphabet_size: int, dictionary: Iterable[MergeRule]) -> DFA:
-    dfa = construct_base_token_dfa(alphabet_size)
-    for rule in dictionary:
-        merge_rule_into_token_dfa(dfa, rule)
-    # TODO Add trimming? Find all states not reachable from start state.
-    # This can probably be done during construction without needing to rescan
-    # the automaton from scratch every time.
-    return dfa
-
-
-def get_int_mapping(
-    alphabet: Iterable[str], dictionary: Iterable[tuple[str, str]]
-) -> Iterable[str]:
-    return itertools.chain(alphabet, (u + v for u, v in dictionary))
+    def merge_rule(self, rule: MergeRule) -> None:
+        # This implements Algorithm 2 of https://arxiv.org/pdf/2405.07671
+        u, v, uv = rule
+        # Use dict to ensure deterministic iteration order.
+        S2 = {}
+        for s1 in range(self.num_states):
+            s2 = self.get_state_to(s1, u)
+            if s2 is not None:
+                s3 = self.get_state_to(s2, v)
+                if s3 is not None:
+                    self.set_transition(s1, uv, s3)
+                    S2[s2] = True
+        fresh = self.new_states(len(S2))
+        excluded = [v]
+        if u == v:
+            excluded.append(uv)
+        for s2, fresh_s2 in zip(S2, fresh):
+            for alpha, state_to in self.get_transitions_from_state(s2):
+                if alpha not in excluded:
+                    self.set_transition(fresh_s2, alpha, state_to)
+        state_to_fresh = dict(zip(S2, fresh))
+        for q in range(self.num_states):
+            state_to = self.get_state_to(q, u)
+            if state_to is not None:
+                fresh_state_to = state_to_fresh.get(state_to)
+                if fresh_state_to is not None:
+                    self.set_transition(q, u, fresh_state_to)
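
To make the intent of TokenDFA concrete, here is a minimal usage sketch (not part of this commit; the accepts helper is hypothetical). It builds the token DFA for a two-symbol base alphabet {0, 1} with a single merge rule (0, 1) -> 2 and checks which token sequences remain canonical:

dfa = TokenDFA.from_dictionary(base_alphabet=range(2), dictionary=[(0, 1, 2)])

def accepts(dfa: TokenDFA, tokens: list[int]) -> bool:
    # Hypothetical helper: follow transitions from the implicit start state 0.
    # Every state is an accept state, so only a missing transition rejects.
    state = 0
    for token in tokens:
        state = dfa.get_state_to(state, token)
        if state is None:
            return False
    return True

assert accepts(dfa, [2])         # the merged token 2 is canonical
assert not accepts(dfa, [0, 1])  # 0 followed by 1 must be written as the merged token 2
assert accepts(dfa, [1, 0])      # (1, 0) is not a merge rule, so it stays canonical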
@@ -0,0 +1,67 @@
import dataclasses
import json
import math
import pathlib

import torch

from genparse.canonical_tokenization.canonicalizer import (
    NonCanonicalTokenizationError,
    CanonicalizerIterator,
    Canonicalizer,
)


@dataclasses.dataclass
class BPECanonicalizerIterator(CanonicalizerIterator):
    parent: 'BPECanonicalizer'
    state: int

    def next(self, token):
        new_state = self.parent.transitions.get((self.state, token))
        if new_state is None:
            raise NonCanonicalTokenizationError
        self.state = new_state

    def mask(self):
        return self.parent.mask_tensor[self.state]


@dataclasses.dataclass
class BPECanonicalizer(Canonicalizer):
    transitions: dict[tuple[int, int], int]
    mask_tensor: torch.Tensor
    _eos_token_id: int

    def iterator(self):
        return BPECanonicalizerIterator(self, 0)

    def eos_token_id(self):
        return self._eos_token_id

    @staticmethod
    def from_file(
        path: pathlib.Path, dtype: torch.dtype, device: torch.device
    ) -> 'BPECanonicalizer':
        data = torch.load(path)
        eos_token_id = data['eos_token_id']
        transition_list = data['transitions'].tolist()
        # Index the transitions by source state and symbol.
        transitions = {
            (state_from, symbol): state_to
            for state_from, symbol, state_to in transition_list
        }
        # Precompute the mask tensors.
        mask_tensor = torch.full(
            (data['num_states'], data['vocabulary_size']),
            -math.inf,
            dtype=dtype,
            device=device,
        )
        # Conveniently, the keys of the transitions dict are the coordinates
        # of the entries that should be 0.
        index_tensor = torch.tensor(list(transitions.keys()), device=device)
        mask_tensor[torch.unbind(index_tensor, dim=1)] = 0
        # All states are accept states, so EOS is allowed at every state.
        mask_tensor[:, eos_token_id] = 0
        return BPECanonicalizer(transitions, mask_tensor, eos_token_id)
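
For illustration only (not part of this commit), here is a minimal sketch that builds a tiny BPECanonicalizer by hand instead of loading it with from_file; the toy transition table and the choice of token id 3 as EOS are assumptions made up for this example:

import math

import torch

# Two states, a vocabulary of 4 token ids, and token id 3 playing the role of EOS.
transitions = {(0, 0): 1, (1, 1): 0}  # state 0 --0--> state 1, state 1 --1--> state 0
mask_tensor = torch.full((2, 4), -math.inf)
for state_from, symbol in transitions:
    mask_tensor[state_from, symbol] = 0
mask_tensor[:, 3] = 0  # all states are accept states, so EOS is always allowed

canonicalizer = BPECanonicalizer(transitions, mask_tensor, 3)
for mask in canonicalizer.masks([0, 1]):
    print(mask)  # 0 where the next token is allowed, -inf where it is not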
@@ -0,0 +1,54 @@
import math
from collections.abc import Iterable

import torch


class NonCanonicalTokenizationError(ValueError):
    pass


class CanonicalizerIterator:
    def next(self, token: int) -> None:
        """Read a token and update the state of this iterator accordingly.
        Raise NonCanonicalTokenizationError if this token would result in a
        non-canonical token sequence."""
        # Raises if the token sequence is invalid
        raise NotImplementedError

    def mask(self) -> torch.Tensor:
        """Return a mask representing which tokens may validly be generated
        next given the current state of this iterator.
        The result is a tensor with 0 at indexes for valid tokens, and -inf at
        indexes for invalid tokens.
        """
        raise NotImplementedError


class Canonicalizer:
    def masks(self, tokens: Iterable[int]) -> Iterable[torch.Tensor]:
        """Read n tokens and generate a sequence of n+1 masks.
        The eos_token index is needed to check that the token sequence is
        allowed to end when it does.
        """
        it = self.iterator()
        mask = it.mask()
        yield mask
        for token in tokens:
            it.next(token)
            mask = it.mask()
            yield mask
        if mask[self.eos_token_id()].item() == -math.inf:
            raise NonCanonicalTokenizationError('this token sequence cannot end with EOS')

    def iterator(self) -> CanonicalizerIterator:
        """Return a CanonicalizerIterator that can be used to compute masks
        iteratively, e.g., during decoding."""
        raise NotImplementedError

    def eos_token_id(self) -> int:
        """Get the id of the EOS token."""
        raise NotImplementedError
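
As context for the 0 / -inf convention documented above, one typical way (a sketch, not part of this commit) to apply such a mask to next-token logits during decoding is simply to add it before the softmax:

import math

import torch

vocab_size = 4
mask = torch.tensor([0.0, -math.inf, 0.0, 0.0])  # e.g. token 1 is not allowed here
logits = torch.randn(vocab_size)                 # next-token logits from the model
probs = torch.softmax(logits + mask, dim=-1)     # -inf entries get probability 0
next_token = torch.multinomial(probs, 1).item()  # samples only among allowed tokens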