Merge branch 'main' of github.com:probcomp/genparse
timvieira committed Jun 28, 2024
2 parents 5f04401 + 1100e60 commit 77418c9
Showing 3 changed files with 190 additions and 2 deletions.
98 changes: 98 additions & 0 deletions genparse/canonical_tokenization/berglund.py
@@ -0,0 +1,98 @@
import collections
import dataclasses
import itertools
from collections.abc import Iterable

State = int
Symbol = int
MergeRule = tuple[Symbol, Symbol]


@dataclasses.dataclass
class DFA:
    num_states: int
    alphabet_size: int
    transitions: dict[State, collections.OrderedDict[Symbol, State]]
    # Note that in this construction, all states are accept states, so there
    # is no explicit set of accept states.

    def new_states(self, n: int) -> range:
        lo = self.num_states
        self.num_states += n
        return range(lo, self.num_states)

    def new_symbol(self) -> Symbol:
        a = self.alphabet_size
        self.alphabet_size += 1
        return a

    def get_state_to(self, state_from: State, symbol: Symbol) -> State | None:
        d = self.transitions.get(state_from)
        if d is not None:
            return d.get(symbol)

    def get_transitions_from_state(
        self, state_from: State
    ) -> Iterable[tuple[Symbol, State]]:
        d = self.transitions.get(state_from)
        if d is not None:
            return d.items()
        else:
            return ()

    def set_transition(self, state_from: State, symbol: Symbol, state_to: State) -> None:
        if state_from not in self.transitions:
            self.transitions[state_from] = collections.OrderedDict()
        self.transitions[state_from][symbol] = state_to

def construct_base_token_dfa(alphabet_size: int) -> DFA:
    # Before any merges, every token sequence is canonical: a single state
    # with a self-loop on each base symbol.
    return DFA(
        num_states=1,
        alphabet_size=alphabet_size,
        transitions={0: collections.OrderedDict.fromkeys(range(alphabet_size), 0)},
    )


def merge_rule_into_token_dfa(dfa: DFA, rule: MergeRule) -> None:
    # This implements Algorithm 2 from https://arxiv.org/pdf/2405.07671
    u, v = rule
    uv = dfa.new_symbol()
    # Use OrderedDict to ensure deterministic iteration order.
    S2 = collections.OrderedDict()
    # Wherever u can be read and then v, add a direct transition on the new
    # merged symbol uv, and remember the intermediate state.
    for s1 in range(dfa.num_states):
        s2 = dfa.get_state_to(s1, u)
        if s2 is not None:
            s3 = dfa.get_state_to(s2, v)
            if s3 is not None:
                dfa.set_transition(s1, uv, s3)
                S2[s2] = True
    # Clone each intermediate state, dropping the transitions that would
    # spell out the merge as u followed by v.
    fresh = dfa.new_states(len(S2))
    excluded = [v]
    if u == v:
        excluded.append(uv)
    for s2, fresh_s2 in zip(S2, fresh):
        for alpha, state_to in dfa.get_transitions_from_state(s2):
            if alpha not in excluded:
                dfa.set_transition(fresh_s2, alpha, state_to)
    # Redirect every u-transition into the clone, so that u can never be
    # followed by v (a canonical tokenization would have used uv instead).
    state_to_fresh = dict(zip(S2, fresh))
    for q in range(dfa.num_states):
        state_to = dfa.get_state_to(q, u)
        if state_to is not None:
            fresh_state_to = state_to_fresh.get(state_to)
            if fresh_state_to is not None:
                dfa.set_transition(q, u, fresh_state_to)


def construct_token_dfa(alphabet_size: int, dictionary: Iterable[MergeRule]) -> DFA:
    dfa = construct_base_token_dfa(alphabet_size)
    for rule in dictionary:
        merge_rule_into_token_dfa(dfa, rule)
    # TODO Add trimming? Find all states not reachable from start state.
    # This can probably be done during construction without needing to rescan
    # the automaton from scratch every time.
    return dfa


def get_int_mapping(
    alphabet: Iterable[str], dictionary: Iterable[tuple[str, str]]
) -> Iterable[str]:
    return itertools.chain(alphabet, (u + v for u, v in dictionary))
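
For orientation, here is a minimal usage sketch (not part of the commit) that mirrors Example 2 of the paper and the new test below: build the token DFA for the alphabet {'a', 'b'} with merges ('a', 'a') and ('b', 'a'), then check that the merged token is accepted while its two-token spelling is rejected.

from genparse.canonical_tokenization.berglund import (
    construct_token_dfa,
    get_int_mapping,
)

alphabet = ['a', 'b']
dictionary = [('a', 'a'), ('b', 'a')]

# Symbols are integers: the base alphabet first, then one fresh symbol per merge.
names = list(get_int_mapping(alphabet, dictionary))  # ['a', 'b', 'aa', 'ba']
tok = {s: i for i, s in enumerate(names)}

dfa = construct_token_dfa(len(alphabet), [(tok[u], tok[v]) for u, v in dictionary])

# After the merge ('a', 'a') -> 'aa', spelling that pair as two 'a' tokens is
# no longer canonical, so the DFA refuses the second 'a' ...
s = dfa.get_state_to(0, tok['a'])
assert dfa.get_state_to(s, tok['a']) is None
# ... while the merged token itself remains available from the start state.
assert dfa.get_state_to(0, tok['aa']) == 0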
5 changes: 3 additions & 2 deletions genparse/util.py
@@ -68,6 +68,7 @@ def __init__(
         grammar,
         proposal_name='character',
         seed=None,
+        batch_size=None,
         guide_opts=None,
         proposal_opts=None,
     ):
@@ -82,7 +83,7 @@ def __init__(
         if seed is not None:
             set_seed(seed)
 
-        llm = load_model_by_name(model_name)
+        llm = load_model_by_name(model_name, batch_size=batch_size)
         guide = lark_guide(grammar, **guide_opts)
         sampler = HFPPLSampler(llm=llm, guide=guide)
 
@@ -158,7 +159,7 @@ def __init__(
         import transformers
         from genparse.lm import AsyncGreedilyTokenizedLLM
 
-        if model_name == 'gpt':
+        if model_name == 'gpt2':
             MODEL_ID = 'gpt2'
             llm = AsyncGreedilyTokenizedLLM(
                 model=vllmpplLLM(MODEL_ID),
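
For context, a hedged sketch of the call pattern the first hunk enables: the new batch_size keyword defaults to None and is simply forwarded to load_model_by_name. The enclosing class's name lies outside this hunk, so InferenceSetup below is an assumption.

# Hypothetical usage; 'InferenceSetup' stands in for the enclosing class,
# whose name is not visible in this hunk.
from genparse.util import InferenceSetup

setup = InferenceSetup(
    'gpt2',              # model_name
    'start: "a" | "b"',  # grammar, handed to lark_guide
    proposal_name='character',
    batch_size=40,       # new keyword, forwarded to load_model_by_name
)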
89 changes: 89 additions & 0 deletions tests/test_berglund_bpe_dfa.py
@@ -0,0 +1,89 @@
from genparse.canonical_tokenization.berglund import (
    DFA,
    construct_base_token_dfa,
    merge_rule_into_token_dfa,
    get_int_mapping,
)


def are_isomorphic(dfa1, dfa2):
    # Search over paired states from the two start states, checking that
    # transitions agree symbol-for-symbol and that the induced state
    # correspondence is consistent.
    if not (
        dfa1.num_states == dfa2.num_states and dfa1.alphabet_size == dfa2.alphabet_size
    ):
        return False
    agenda = [0]
    state_mapping = {0: 0}
    while agenda:
        q1 = agenda.pop()
        for a in range(dfa1.alphabet_size):
            r1 = dfa1.get_state_to(q1, a)
            r2 = dfa2.get_state_to(state_mapping[q1], a)
            if r1 is None:
                if r2 is not None:
                    return False
            else:
                if r2 is None:
                    return False
                else:
                    if r1 in state_mapping:
                        if r2 != state_mapping[r1]:
                            return False
                    else:
                        state_mapping[r1] = r2
                        agenda.append(r1)
    return True


def test_example_2():
    # This tests Example 2 of https://arxiv.org/abs/2405.07671

    alphabet = ['a', 'b']
    dictionary = [('a', 'a'), ('b', 'a')]

    int_to_str = get_int_mapping(alphabet, dictionary)
    str_to_int = {s: i for i, s in enumerate(int_to_str)}
    dictionary_as_int = [(str_to_int[u], str_to_int[v]) for u, v in dictionary]

    def construct_dfa(num_states, alphabet_size, transitions):
        M = DFA(num_states=num_states, alphabet_size=alphabet_size, transitions={})
        for q, a, r in transitions:
            M.set_transition(q, str_to_int[a], r)
        return M

    M = construct_base_token_dfa(len(alphabet))
    N = construct_dfa(
        num_states=1, alphabet_size=2, transitions=[(0, a, 0) for a in alphabet]
    )
    assert are_isomorphic(M, N)

    merge_rule_into_token_dfa(M, dictionary_as_int[0])
    N = construct_dfa(
        num_states=2,
        alphabet_size=3,
        transitions=[(0, 'aa', 0), (0, 'b', 0), (0, 'a', 1), (1, 'b', 0)],
    )
    assert are_isomorphic(M, N)

    merge_rule_into_token_dfa(M, dictionary_as_int[1])
    N = construct_dfa(
        num_states=3,
        alphabet_size=4,
        transitions=[
            (0, 'aa', 0),
            (0, 'b', 2),
            (0, 'a', 1),
            (0, 'ba', 1),
            (1, 'ba', 1),
            (1, 'b', 2),
            (2, 'b', 2),
            (2, 'aa', 0),
            (2, 'ba', 1),
        ],
    )
    assert are_isomorphic(M, N)


if __name__ == '__main__':
    from arsenal import testing_framework

    testing_framework(globals())
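
Because of the __main__ hook, the new test can be run directly with `python tests/test_berglund_bpe_dfa.py`; the `test_` naming convention also lets the repository's usual test runner pick it up.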
