Skip to content

Commit

Permalink
improving coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
timvieira committed Jun 29, 2024
1 parent 161e262 commit 4e0c6a2
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 56 deletions.
6 changes: 3 additions & 3 deletions genparse/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ def nullaryremove(self, binarize=True, trim=True, **kwargs):
# A really wide rule can take a very long time because of the power set
# in this rule so it is really important to binarize.
if binarize:
self = self.binarize()
self = self.binarize() # pragma: no cover
self = self.separate_start()
tmp = self._push_null_weights(self.null_weight(), **kwargs)
return tmp.trim() if trim else tmp
Expand Down Expand Up @@ -643,8 +643,8 @@ def _find_invalid_cnf_rule(self):
else:
yield r

def has_nullary(self):
return any((len(p.body) == 0) for p in self if p.head != self.S)
# def has_nullary(self):
# return any((len(p.body) == 0) for p in self if p.head != self.S)

def has_unary_cycle(self):
f = self._unary_graph().buckets
Expand Down
2 changes: 1 addition & 1 deletion genparse/proposal/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def _make_guide(guide_spec):
elif isinstance(guide_spec, LM):
return guide_spec
else:
raise ValueError('Unknown guide specification')
raise ValueError('Unknown guide specification') # pragma: no cover


def _make_mock_llm(V, uniform):
Expand Down
12 changes: 4 additions & 8 deletions tests/test_fst.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,8 @@ def test_fst_cfg1():
have = (
CFG.from_string(
"""
1: S -> a b c
""",
1: S -> a b c
""",
Float,
)
@ fst
Expand All @@ -34,10 +32,8 @@ def test_fst_cfg1():
# apply from the right of the transducer
have = fst @ CFG.from_string(
"""
1: S -> A B C
""",
1: S -> A B C
""",
Float,
is_terminal=lambda X: X in 'ABC',
)
Expand Down
99 changes: 55 additions & 44 deletions tests/test_intersection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,29 @@ def assert_equal(have, want, tol=1e-5):


# reference implementation of the intersection algorithm
def intersect_slow(self, fsa):
fst = FST.diag(fsa)
def intersect_slow(self, fst):
# coerce something sequence like into a diagonal FST
if isinstance(fst, (str, list, tuple)):
fst = FST.from_string(fst, self.R)

# coerce something FSA-like into an FST, might throw an error
if not isinstance(fst, FST):
fst = FST.diag(fst)

return compose_naive_epsilon(self, fst)


def compose_naive_epsilon(self, fst):
"Reference implementation of the grammar-transducer composition."

# coerce something sequence like into a diagonal FST
if isinstance(fst, (str, list, tuple)):
fst = FST.from_string(fst, self.R)

# coerce something FSA-like into an FST, might throw an error
if not isinstance(fst, FST):
fst = FST.diag(fst)

new_start = self.S
new = self.spawn(S=new_start)

Expand Down Expand Up @@ -103,10 +116,10 @@ def check_fst(cfg, fst):
def test_palindrome1():
cfg = CFG.from_string(
"""
0.3: S -> a S a
0.4: S -> b S b
0.3: S ->
""",
0.3: S -> a S a
0.4: S -> b S b
0.3: S ->
""",
Float,
)

Expand All @@ -118,10 +131,10 @@ def test_palindrome1():
def test_palindrome2():
cfg = CFG.from_string(
"""
0.3: S -> a S a
0.4: S -> b S b
0.3: S ->
""",
0.3: S -> a S a
0.4: S -> b S b
0.3: S ->
""",
Real,
)

Expand All @@ -139,10 +152,10 @@ def test_palindrome2():
def test_palindrome3():
cfg = CFG.from_string(
"""
0.3: S -> a S a
0.4: S -> b S b
0.3: S ->
""",
0.3: S -> a S a
0.4: S -> b S b
0.3: S ->
""",
Real,
)

Expand All @@ -164,25 +177,25 @@ def test_palindrome3():
def test_catalan1():
cfg = CFG.from_string(
"""
0.4: S -> S S
0.3: S -> a
0.3: S -> b
""",
0.4: S -> S S
0.3: S -> a
0.3: S -> b
""",
Real,
)

fsa = WFSA.from_string('aa', cfg.R)
# fsa = WFSA.from_string('aa', cfg.R)

check(cfg, fsa)
check(cfg, 'aa')


def test_catalan2():
cfg = CFG.from_string(
"""
0.4: S -> S S
0.3: S -> a
0.3: S -> b
""",
0.4: S -> S S
0.3: S -> a
0.3: S -> b
""",
Real,
)

Expand Down Expand Up @@ -235,10 +248,10 @@ def check(cfg, fsa):
def test_catalan_fst():
cfg = CFG.from_string(
"""
0.4: S -> S S
0.3: S -> a
0.3: S -> b
""",
0.4: S -> S S
0.3: S -> a
0.3: S -> b
""",
Real,
)

Expand All @@ -258,10 +271,10 @@ def test_catalan_fst():
def test_palindrome_fst():
cfg = CFG.from_string(
"""
0.3: S -> a S a
0.4: S -> b S b
0.3: S ->
""",
0.3: S -> a S a
0.4: S -> b S b
0.3: S ->
""",
Real,
)

Expand All @@ -284,10 +297,10 @@ def test_palindrome_fst():
def test_epsilon_fst():
cfg = CFG.from_string(
"""
0.3: S -> a S a
0.4: S -> b S b
0.3: S ->
""",
0.3: S -> a S a
0.4: S -> b S b
0.3: S ->
""",
Real,
)

Expand Down Expand Up @@ -320,10 +333,10 @@ def test_epsilon_fst_2():
# This test case is a bit more complex as it contains epsilon cycles on the FST
cfg = CFG.from_string(
"""
0.3: S -> a S a
0.4: S -> b S b
0.3: S ->
""",
0.3: S -> a S a
0.4: S -> b S b
0.3: S ->
""",
Real,
)

Expand Down Expand Up @@ -356,10 +369,8 @@ def test_epsilon_fst_2():
def test_simple_epsilon():
g = CFG.from_string(
"""
1: S -> a
""",
1: S -> a
""",
Float,
)

Expand Down
14 changes: 14 additions & 0 deletions tests/test_wcfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,20 @@ def test_misc():
else:
raise AssertionError('test failed')

cfg = CFG.from_string(
"""
2: X -> a
3: Y -> b
""",
Float,
)
cfg['Y'].trim().assert_equal("""
3: Y -> b
""")

# call it twice to hit the trim cache
cfg.trim().trim()


def test_agenda_misc():
# test stopping early
Expand Down

0 comments on commit 4e0c6a2

Please sign in to comment.