From 66073d5f29c5b00338e33042a20464ba8f2479b5 Mon Sep 17 00:00:00 2001 From: Tim Vieira Date: Wed, 19 Jun 2024 19:23:59 -0400 Subject: [PATCH] whoops; forget to add file --- genparse/experimental/cky.py | 120 +++++++++++++++++++++++++++++++++++ 1 file changed, 120 insertions(+) create mode 100644 genparse/experimental/cky.py diff --git a/genparse/experimental/cky.py b/genparse/experimental/cky.py new file mode 100644 index 00000000..1a60b714 --- /dev/null +++ b/genparse/experimental/cky.py @@ -0,0 +1,120 @@ +from collections import defaultdict + + +class IncrementalCKY: + def __init__(self, cfg): + cfg = cfg.renumber() + self.cfg = cfg + self.S = cfg.S + + # cache columns of the chart indexed by prefix + self._chart = {} + + [self.nullary, self.terminal, binary] = cfg._cnf + r_y_xz = defaultdict(list) + for r in binary: # binary rules + r_y_xz[r.body[0]].append(r) + self.r_y_xz = r_y_xz + + def clear_cache(self): + self._chart.clear() + + def __call__(self, x): + return self.chart(x)[len(x)][0][self.S] + + def p_next(self, prefix): + return self.next_token_weights(self.chart(prefix), prefix) + + def chart(self, prefix): + c = self._chart.get(prefix) + if c is None: + c = self._compute_chart(prefix) + self._chart[prefix] = c + return c + + def _compute_chart(self, prefix): + if len(prefix) == 0: + tmp = [defaultdict(self.cfg.R.chart)] + tmp[0][0][self.cfg.S] = self.nullary + return tmp + else: + chart = self.chart(prefix[:-1]) + last_chart = self.extend_chart(chart, prefix) + return chart + [ + last_chart + ] # TODO: avoid list addition here as it is not constant time! + + def next_token_weights(self, chart, prefix): + """ + An O(N²) time algorithm to the total weight of a each next-token + extension of `prefix`. + """ + k = len(prefix) + 1 + + cfg = self.cfg + terminal = self.terminal + r_y_xz = self.r_y_xz + + # the code below is just backprop / outside algorithm + α = defaultdict(cfg.R.chart) + α[0][cfg.S] += cfg.R.one + + # Binary rules + for span in reversed(range(2, k + 1)): + i = k - span + α_i = α[i] + for j in range(i + 1, k): + chart_ij = chart[j][i] + + α_j = α[j] + for Y, y in chart_ij.items(): + for r in r_y_xz[Y]: + X = r.head + Z = r.body[1] + α_j[Z] += r.w * y * α_i[X] + + # Preterminal + q = cfg.R.chart() + tmp = α[k - 1] + for w in cfg.V: + for r in terminal[w]: + q[w] += r.w * tmp[r.head] + + return q + + def extend_chart(self, chart, prefix): + """ + An O(N²) time algorithm to extend to the `chart` with the last token + appearing at the end of `prefix`; returns a new chart column. + """ + k = len(prefix) + + cfg = self.cfg + r_y_xz = self.r_y_xz + + new = defaultdict(cfg.R.chart) + + # Nullary + new[k][cfg.S] += self.nullary + + # Preterminal + tmp = new[k - 1] + for r in self.terminal[prefix[k - 1]]: + tmp[r.head] += r.w + + # Binary rules + for span in range(2, k + 1): + i = k - span + new_i = new[i] + for j in range(i + 1, k): + chart_ij = chart[j][i] + new_j = new[j] + for Y, y in chart_ij.items(): + for r in r_y_xz[Y]: + X = r.head + Z = r.body[1] + z = new_j[Z] + x = r.w * y * z + new_i[X] += x + + return new