Skip to content

Commit

Permalink
improve test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
timvieira committed Jun 29, 2024
1 parent 4e0c6a2 commit 1a58991
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 98 deletions.
11 changes: 8 additions & 3 deletions genparse/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,10 @@ def trim(self, bottomup_only=False):
agenda.add(e.head)

if bottomup_only:
return self._trim(C)
val = self._trim(C)
self._trim_cache[bottomup_only] = val
val._trim_cache[bottomup_only] = val
return val

T = {self.S}
agenda.update(T)
Expand All @@ -316,8 +319,10 @@ def trim(self, bottomup_only=False):
T.add(b)
agenda.add(b)

self._trim_cache[bottomup_only] = self._trim(T)
return self._trim_cache[bottomup_only]
val = self._trim(T)
self._trim_cache[bottomup_only] = val
val._trim_cache[bottomup_only] = val
return val

def cotrim(self):
return self.trim(bottomup_only=True)
Expand Down
2 changes: 1 addition & 1 deletion genparse/proposal/trie_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def _update_trie_numba(
jump: List[numba.int32[:]],
token_id_to_leaf: numba.int32[:, :],
ordering: numba.int32[:],
):
): # pragma: no cover
# update leaves
M = token_id_to_leaf.shape[0]
for k in range(M):
Expand Down
12 changes: 2 additions & 10 deletions genparse/steer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,8 @@ def __init__(self, lm1, lm2, MAX_LENGTH):
# materializing the distrbution over strings up to a maximum length
self.lm1 = lm1
self.lm2 = lm2
self.p1 = (
lm1.cfg.cnf.language(MAX_LENGTH)
.filter(lambda x: len(x) <= MAX_LENGTH)
.normalize()
)
self.p2 = (
lm2.cfg.cnf.language(MAX_LENGTH)
.filter(lambda x: len(x) <= MAX_LENGTH)
.normalize()
)
self.p1 = lm1.cfg.cnf.materialize(MAX_LENGTH).normalize()
self.p2 = lm2.cfg.cnf.materialize(MAX_LENGTH).normalize()
self.target = (self.p1 * self.p2).normalize()


Expand Down
179 changes: 95 additions & 84 deletions tests/test_wcfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,28 +61,25 @@ def test_misc():

cfg = CFG.from_string(
"""
2: X -> a
3: Y -> b
""",
2: X -> a
3: Y -> b
""",
Float,
)
cfg['Y'].trim().assert_equal("""
3: Y -> b
""")
cfg['Y'].trim().assert_equal('3: Y -> b')

# call it twice to hit the trim cache
cfg.trim().trim()
cfg.trim().trim() # serial
cfg.trim() # parallel


def test_agenda_misc():
# test stopping early
g = CFG.from_string(
"""
0.5: S → a S
1: S → a
""",
0.5: S → a S
1: S → a
""",
Float,
)

Expand Down Expand Up @@ -243,11 +240,9 @@ def test_semirings():
def test_treesum():
cfg = CFG.from_string(
"""
0.25: S → S S
0.75: S → a
""",
0.25: S → S S
0.75: S → a
""",
Real,
)

Expand All @@ -269,18 +264,18 @@ def test_trim():
cfg = CFG.from_string(
"""
0.25: S → S S
0.75: S → a
0.25: S → S S
0.75: S → a
0.75: A → a
0.75: A → a
1: C → D
1: D → C
1: C → D
1: D → C
1: B → a
1: B → B
1: B → a
1: B → B
""",
""",
Real,
)

Expand All @@ -289,10 +284,10 @@ def test_trim():
want = CFG.from_string(
"""
0.25: S → S S
0.75: S → a
0.25: S → S S
0.75: S → a
""",
""",
Real,
)

Expand Down Expand Up @@ -320,20 +315,20 @@ def test_cnf():
cfg = CFG.from_string(
"""
1: S → S1
1: S → S1
1: S → A B C d
1: S → A B C d
0.5: S1 → S1
0.5: S1 → S1
0.1: S1 →
0.1: A →
0.1: S1 →
0.1: A →
1: A → a
1: B → d
1: C → c
1: A → a
1: B → d
1: C → c
""",
""",
Real,
)

Expand All @@ -349,18 +344,16 @@ def test_cnf():
def test_grammar_size_metrics():
cfg = CFG.from_string(
"""
1.0: S → A B C D
0.5: S → S
0.2: S →
0.1: A →
1: A → a
1: B → d
1: C → c
1: D → d
""",
1.0: S → A B C D
0.5: S → S
0.2: S →
0.1: A →
1: A → a
1: B → d
1: C → c
1: D → d
""",
Real,
)

Expand All @@ -371,12 +364,10 @@ def test_grammar_size_metrics():
def test_palindrome_derivations():
cfg = CFG.from_string(
"""
1: S → a S a
1: S → b S b
1: S → c
""",
1: S → a S a
1: S → b S b
1: S → c
""",
Real,
)

Expand All @@ -397,10 +388,10 @@ def test_palindrome_derivations():
def test_unfold():
cfg = CFG.from_string(
"""
1.0: S →
0.5: S → S a
0.5: B → b
""",
1.0: S →
0.5: S → S a
0.5: B → b
""",
Real,
)

Expand All @@ -413,13 +404,11 @@ def test_unfold():
new.assert_equal(
CFG.from_string(
"""
1.0: S →
0.5: S → a
0.25: S → S a a
0.5: B → b
""",
1.0: S →
0.5: S → a
0.25: S → S a a
0.5: B → b
""",
Real,
)
)
Expand All @@ -436,14 +425,14 @@ def test_unfold():
def test_cky():
cfg = CFG.from_string(
"""
1: S -> A B
0.1: A -> A B
0.4: A ->
0.5: A -> b
0.4: B -> a
0.5: B ->
0.1: B -> B A
""",
1: S -> A B
0.1: A -> A B
0.4: A ->
0.5: A -> b
0.4: B -> a
0.5: B ->
0.1: B -> B A
""",
Real,
)

Expand All @@ -467,21 +456,21 @@ def test_cky():
def test_unary_cycle_removal():
cfg = CFG.from_string(
"""
0.5: S → A1
0.5: S → A1
0.5: A1 → B1
0.5: B1 → C1
0.5: C1 → A1
0.5: A1 → B1
0.5: B1 → C1
0.5: C1 → A1
0.5: C1 → C
0.25: C1 → C1
0.5: C1 → C
0.25: C1 → C1
0.25: C1 → C0
1.0: C0 → C
0.25: C1 → C0
1.0: C0 → C
0.5: C → c
0.5: C → c
""",
""",
Float,
)

Expand All @@ -490,6 +479,28 @@ def test_unary_cycle_removal():
unaryfree.agenda().assert_equal(cfg.agenda(), domain=cfg.N, tol=1e-10, verbose=1)


def test_truncate_length():
cfg = CFG.from_string(
"""
1: S → a S a
1: S → b S b
1: S →
""",
Real,
)

max_length = 5

cfg_t = cfg.truncate_length(max_length)
have = cfg_t.language(max_length * 2)

want = cfg.materialize(max_length=max_length)

have.assert_equal(want)
print(have)
assert len(have) == 7 or max_length != 5


if __name__ == '__main__':
from arsenal import testing_framework

Expand Down

0 comments on commit 1a58991

Please sign in to comment.