Merge remote-tracking branch 'refs/remotes/origin/main'
timvieira committed Jun 29, 2024
2 parents 8e7b805 + be7dfd2 commit e374ad6
Showing 5 changed files with 343 additions and 92 deletions.
125 changes: 59 additions & 66 deletions genparse/canonical_tokenization/berglund.py
@@ -1,98 +1,91 @@
-import collections
 import dataclasses
-import itertools
 from collections.abc import Iterable, Sequence

 State = int
 Symbol = int
-MergeRule = tuple[Symbol, Symbol]
+# A merge rule consists of (u, v, uv), where uv is the index of the symbol for
+# the concatenation of u and v.
+MergeRule = tuple[Symbol, Symbol, Symbol]


 @dataclasses.dataclass
-class DFA:
-    num_states: int
-    alphabet_size: int
-    transitions: dict[State, collections.OrderedDict[Symbol, State]]
+class TokenDFA:
+    # The start state is always implicitly state 0.
+    # Note that in this construction, all states are accept states, so there
+    # is no explicit set of accept states.
+    num_states: int
+    transitions: dict[State, dict[Symbol, State]]

+    @staticmethod
+    def from_dictionary(
+        base_alphabet: Iterable[Symbol], dictionary: Iterable[MergeRule]
+    ) -> 'TokenDFA':
+        dfa = TokenDFA.from_base_alphabet(base_alphabet)
+        for rule in dictionary:
+            dfa.merge_rule(rule)
+        # TODO Add trimming? Find all states not reachable from start state.
+        # This can probably be done during construction without needing to rescan
+        # the automaton from scratch every time.
+        return dfa
+
+    @staticmethod
+    def from_base_alphabet(base_alphabet: Iterable[Symbol]) -> 'TokenDFA':
+        return TokenDFA(num_states=1, transitions={0: {a: 0 for a in base_alphabet}})
+
     def new_states(self, n: int) -> range:
         lo = self.num_states
         self.num_states += n
         return range(lo, self.num_states)

-    def new_symbol(self) -> Symbol:
-        a = self.alphabet_size
-        self.alphabet_size += 1
-        return a
-
     def get_state_to(self, state_from: State, symbol: Symbol) -> State | None:
         d = self.transitions.get(state_from)
         if d is not None:
             return d.get(symbol)

-    def get_transitions_from_state(self, state_from: State) -> Iterable[Symbol, State]:
+    def get_transitions_from_state(
+        self, state_from: State
+    ) -> Iterable[tuple[Symbol, State]]:
         d = self.transitions.get(state_from)
         if d is not None:
             return d.items()
         else:
             return ()

     def get_transitions(self) -> Iterable[tuple[State, Symbol, State]]:
         for state_from, transitions_from_state in self.transitions.items():
             for symbol, state_to in transitions_from_state.items():
                 yield state_from, symbol, state_to

     def set_transition(self, state_from: State, symbol: Symbol, state_to: State) -> None:
         if state_from not in self.transitions:
-            self.transitions[state_from] = collections.OrderedDict()
+            self.transitions[state_from] = {}
         self.transitions[state_from][symbol] = state_to

-
-def construct_base_token_dfa(alphabet_size: int) -> DFA:
-    return DFA(
-        num_states=1,
-        alphabet_size=alphabet_size,
-        transitions={0: collections.OrderedDict.fromkeys(range(alphabet_size), 0)},
-    )
-
-
-def merge_rule_into_token_dfa(dfa: DFA, rule: MergeRule) -> None:
-    # This implements Algorithm 2 from https://arxiv.org/pdf/2405.07671
-    u, v = rule
-    uv = dfa.new_symbol()
-    # Use OrderedDict to ensure deterministic iteration order.
-    S2 = collections.OrderedDict()
-    for s1 in range(dfa.num_states):
-        s2 = dfa.get_state_to(s1, u)
-        if s2 is not None:
-            s3 = dfa.get_state_to(s2, v)
-            if s3 is not None:
-                dfa.set_transition(s1, uv, s3)
-                S2[s2] = True
-    fresh = dfa.new_states(len(S2))
-    excluded = [v]
-    if u == v:
-        excluded.append(uv)
-    for s2, fresh_s2 in zip(S2, fresh):
-        for alpha, state_to in dfa.get_transitions_from_state(s2):
-            if alpha not in excluded:
-                dfa.set_transition(fresh_s2, alpha, state_to)
-    state_to_fresh = dict(zip(S2, fresh))
-    for q in range(dfa.num_states):
-        state_to = dfa.get_state_to(q, u)
-        if state_to is not None:
-            fresh_state_to = state_to_fresh.get(state_to)
-            if fresh_state_to is not None:
-                dfa.set_transition(q, u, fresh_state_to)
-
-
-def construct_token_dfa(alphabet_size: int, dictionary: Iterable[MergeRule]) -> DFA:
-    dfa = construct_base_token_dfa(alphabet_size)
-    for rule in dictionary:
-        merge_rule_into_token_dfa(dfa, rule)
-    # TODO Add trimming? Find all states not reachable from start state.
-    # This can probably be done during construction without needing to rescan
-    # the automaton from scratch every time.
-    return dfa
-
-
-def get_int_mapping(
-    alphabet: Iterable[str], dictionary: Iterable[tuple[str, str]]
-) -> Iterable[str]:
-    return itertools.chain(alphabet, (u + v for u, v in dictionary))
+    def merge_rule(self, rule: MergeRule) -> None:
+        # This implements Algorithm 2 of https://arxiv.org/pdf/2405.07671
+        u, v, uv = rule
+        # Use dict to ensure deterministic iteration order.
+        S2 = {}
+        for s1 in range(self.num_states):
+            s2 = self.get_state_to(s1, u)
+            if s2 is not None:
+                s3 = self.get_state_to(s2, v)
+                if s3 is not None:
+                    self.set_transition(s1, uv, s3)
+                    S2[s2] = True
+        fresh = self.new_states(len(S2))
+        excluded = [v]
+        if u == v:
+            excluded.append(uv)
+        for s2, fresh_s2 in zip(S2, fresh):
+            for alpha, state_to in self.get_transitions_from_state(s2):
+                if alpha not in excluded:
+                    self.set_transition(fresh_s2, alpha, state_to)
+        state_to_fresh = dict(zip(S2, fresh))
+        for q in range(self.num_states):
+            state_to = self.get_state_to(q, u)
+            if state_to is not None:
+                fresh_state_to = state_to_fresh.get(state_to)
+                if fresh_state_to is not None:
+                    self.set_transition(q, u, fresh_state_to)
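
For orientation, a small usage sketch of the new TokenDFA API (not part of the commit): the three-symbol base alphabet, the two merge rules, and the accepts helper below are invented for illustration.

from genparse.canonical_tokenization.berglund import TokenDFA

# Merge rule (0, 1, 3): symbols 0 and 1 merge into a new symbol with index 3.
# Merge rule (3, 2, 4): symbols 3 and 2 merge into a new symbol with index 4.
dfa = TokenDFA.from_dictionary(base_alphabet=range(3), dictionary=[(0, 1, 3), (3, 2, 4)])


def accepts(dfa: TokenDFA, symbols: list[int]) -> bool:
    # Hypothetical helper: run the DFA from the start state (0); every state is
    # an accept state, so the sequence is canonical iff all transitions exist.
    state = 0
    for symbol in symbols:
        state = dfa.get_state_to(state, symbol)
        if state is None:
            return False
    return True


print(accepts(dfa, [4]))     # True: the fully merged token is canonical
print(accepts(dfa, [0, 2]))  # True: 0 followed by 2 cannot be merged further
print(accepts(dfa, [0, 1]))  # False: 0 followed by 1 should have merged into 3
print(accepts(dfa, [3, 2]))  # False: 3 followed by 2 should have merged into 4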
67 changes: 67 additions & 0 deletions genparse/canonical_tokenization/bpe_canonicalizer.py
@@ -0,0 +1,67 @@
import dataclasses
import json
import math
import pathlib

import torch

from genparse.canonical_tokenization.canonicalizer import (
    NonCanonicalTokenizationError,
    CanonicalizerIterator,
    Canonicalizer,
)


@dataclasses.dataclass
class BPECanonicalizerIterator(CanonicalizerIterator):
    parent: 'BPECanonicalizer'
    state: int

    def next(self, token):
        new_state = self.parent.transitions.get((self.state, token))
        if new_state is None:
            raise NonCanonicalTokenizationError
        self.state = new_state

    def mask(self):
        return self.parent.mask_tensor[self.state]


@dataclasses.dataclass
class BPECanonicalizer(Canonicalizer):
    transitions: dict[tuple[int, int], int]
    mask_tensor: torch.Tensor
    _eos_token_id: int

    def iterator(self):
        return BPECanonicalizerIterator(self, 0)

    def eos_token_id(self):
        return self._eos_token_id

    @staticmethod
    def from_file(
        path: pathlib.Path, dtype: torch.dtype, device: torch.device
    ) -> 'BPECanonicalizer':
        data = torch.load(path)
        eos_token_id = data['eos_token_id']
        transition_list = data['transitions'].tolist()
        # Index the transitions by source state and symbol.
        transitions = {
            (state_from, symbol): state_to
            for state_from, symbol, state_to in transition_list
        }
        # Precompute the mask tensors.
        mask_tensor = torch.full(
            (data['num_states'], data['vocabulary_size']),
            -math.inf,
            dtype=dtype,
            device=device,
        )
        # Conveniently, the keys of the transitions dict are the coordinates
        # of the entries that should be 0.
        index_tensor = torch.tensor(list(transitions.keys()), device=device)
        mask_tensor[torch.unbind(index_tensor, dim=1)] = 0
        # All states are accept states, so EOS is allowed at every state.
        mask_tensor[:, eos_token_id] = 0
        return BPECanonicalizer(transitions, mask_tensor, eos_token_id)
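
A hedged sketch (not part of the commit) of how BPECanonicalizer might be wired into decoding: the greedy_decode helper, the model callable, and the file name are assumptions, while the keys expected in the saved file ('transitions', 'num_states', 'vocabulary_size', 'eos_token_id') are the ones read by from_file above.

import pathlib

import torch

from genparse.canonical_tokenization.bpe_canonicalizer import BPECanonicalizer


def greedy_decode(model, canonicalizer: BPECanonicalizer, max_steps: int) -> list[int]:
    # `model` is any callable mapping a token prefix to a logits tensor of
    # shape (vocabulary_size,); it stands in for a real language model.
    it = canonicalizer.iterator()
    tokens = []
    for _ in range(max_steps):
        # Adding the mask drives non-canonical continuations to -inf.
        token = int(torch.argmax(model(tokens) + it.mask()))
        if token == canonicalizer.eos_token_id():
            break
        it.next(token)  # advance the DFA; raises on a non-canonical token
        tokens.append(token)
    return tokens


# Hypothetical artifact produced offline in the format from_file expects:
canonicalizer = BPECanonicalizer.from_file(
    pathlib.Path('bpe_canonicalizer.pt'), dtype=torch.float32, device=torch.device('cpu')
)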
54 changes: 54 additions & 0 deletions genparse/canonical_tokenization/canonicalizer.py
@@ -0,0 +1,54 @@
import math
from collections.abc import Iterable

import torch


class NonCanonicalTokenizationError(ValueError):
    pass


class CanonicalizerIterator:
    def next(self, token: int) -> None:
        """Read a token and update the state of this iterator accordingly.
        Raise NonCanonicalTokenizationError if this token would result in a
        non-canonical token sequence."""
        # Raises if the token sequence is invalid
        raise NotImplementedError

    def mask(self) -> torch.Tensor:
        """Return a mask representing which tokens may validly be generated
        next given the current state of this iterator.
        The result is a tensor with 0 at indexes for valid tokens, and -inf at
        indexes for invalid tokens.
        """
        raise NotImplementedError


class Canonicalizer:
    def masks(self, tokens: Iterable[int]) -> Iterable[torch.Tensor]:
        """Read n tokens and generate a sequence of n+1 masks.
        The eos_token index is needed to check that the token sequence is
        allowed to end when it does.
        """
        it = self.iterator()
        mask = it.mask()
        yield mask
        for token in tokens:
            it.next(token)
            mask = it.mask()
            yield mask
        if mask[self.eos_token_id()].item() == -math.inf:
            raise NonCanonicalTokenizationError('this token sequence cannot end with EOS')

    def iterator(self) -> CanonicalizerIterator:
        """Return a CanonicalizerIterator that can be used to compute masks
        iteratively, e.g., during decoding."""
        raise NotImplementedError

    def eos_token_id(self) -> int:
        """Get the id of the EOS token."""
        raise NotImplementedError
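
A hedged sketch (not part of the commit) of how the Canonicalizer.masks generator could be used to score a fixed token sequence; the masked_log_probs helper and the (n+1) x vocabulary_size logits tensor are assumptions for illustration.

import torch

from genparse.canonical_tokenization.canonicalizer import Canonicalizer


def masked_log_probs(canonicalizer: Canonicalizer, logits: torch.Tensor, tokens: list[int]) -> torch.Tensor:
    # `logits` has shape (len(tokens) + 1, vocabulary_size): one row per
    # prediction step, including the final step that should produce EOS.
    masks = torch.stack(list(canonicalizer.masks(tokens)))
    # Consuming the generator also runs its final check, raising
    # NonCanonicalTokenizationError if `tokens` cannot end with EOS.
    return torch.log_softmax(logits + masks, dim=-1)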