Skip to content

Commit

Permalink
use EarleyLM in tests/test_inference.py
Browse files Browse the repository at this point in the history
  • Loading branch information
timvieira committed Jun 29, 2024
1 parent 09682a2 commit 4145d17
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 8 deletions.
4 changes: 1 addition & 3 deletions genparse/cfglm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from genparse.cfg import CFG, _gen_nt
from genparse.lm import LM
from genparse.semiring import Boolean, Float
from genparse.experimental.cky import IncrementalCKY

# EOS = '$EOS'
# EOS = '🛑'
Expand Down Expand Up @@ -105,9 +106,6 @@ def __init__(self, cfg, **kwargs):

self.cfg = cfg
self.pfg = self.cfg.cnf.prefix_grammar.cnf

from genparse.experimental.cky import IncrementalCKY

self.model = IncrementalCKY(self.pfg, **kwargs)

super().__init__(V=cfg.V, eos=EOS)
Expand Down
14 changes: 10 additions & 4 deletions genparse/experimental/earley.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,35 @@
# from arsenal.datastructures.pdict import pdict
from arsenal.datastructures.heap import LocatorMaxHeap

from genparse.cfglm import EOS, add_EOS
from genparse.cfglm import EOS, add_EOS, locally_normalize, CFG
from genparse.linear import WeightedGraph
from genparse.lm import LM
from genparse.semiring import Boolean
from genparse import Float


class EarleyLM(LM):
    """Language model that delegates to an `Earley` parser.

    The parser is built over `cfg.prefix_grammar`; every query method on
    this class forwards to that underlying `Earley` instance.
    """

    def __init__(self, cfg):
        """Wrap `cfg`, adding EOS first when it is not already in `cfg.V`."""
        grammar = cfg if EOS in cfg.V else add_EOS(cfg)
        # Note: `self.cfg` is the grammar *before* the prefix transform and
        # normalization; only the parser sees `prefix_grammar`.
        self.cfg = grammar
        self.model = Earley(grammar.prefix_grammar)
        super().__init__(V=grammar.V, eos=EOS)

    def p_next(self, context):
        """Return `self.model.p_next(context)` unchanged."""
        return self.model.p_next(context)

    def __call__(self, context):
        """Evaluate the underlying model on an EOS-terminated `context`."""
        assert context[-1] == EOS
        return self.model(context)

    def clear_cache(self):
        """Forward the cache reset to the underlying model."""
        self.model.clear_cache()

    @classmethod
    def from_string(cls, x, semiring=Float, **kwargs):
        """Alternate constructor: parse `x` into a `CFG` and normalize it.

        `x` and `semiring` go to `CFG.from_string`; any extra keyword
        arguments are passed to `locally_normalize`.
        """
        grammar = CFG.from_string(x, semiring)
        normalized = locally_normalize(grammar, **kwargs)
        return cls(normalized)


class Column:
__slots__ = ('k', 'i_chart', 'c_chart', 'waiting_for', 'Q')
Expand Down
3 changes: 2 additions & 1 deletion tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import pandas as pd
from arsenal import colors

from genparse.cfglm import CFGLM
# from genparse.cfglm import CFGLM
from genparse.experimental.earley import EarleyLM as CFGLM
from genparse.semiring import Float
from genparse.steer import BruteForceGlobalProductOfExperts, run

Expand Down

0 comments on commit 4145d17

Please sign in to comment.