From c922f37dbb3416de280f5c8265fe67750a4c20c8 Mon Sep 17 00:00:00 2001 From: Tim Vieira Date: Tue, 18 Jun 2024 13:45:01 -0400 Subject: [PATCH 1/5] minor tweak to earley --- genparse/experimental/earley.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/genparse/experimental/earley.py b/genparse/experimental/earley.py index 5d669588..fbcb1886 100644 --- a/genparse/experimental/earley.py +++ b/genparse/experimental/earley.py @@ -206,14 +206,7 @@ def PREDICT(self, col): rhs = self.rhs for X in reachable: for w, Ys in rhs.get(X, ()): - item = (k, X, Ys) - was = col_chart.get(item) - if was is None: - Y = self.first_Ys[Ys] - col_waiting_for[Y].add(item) - col_chart[item] = w - else: - col_chart[item] = was + w + self._update(col, k, X, Ys, w) def _update(self, col, I, X, Ys, value): K = col.k @@ -287,15 +280,18 @@ def next_token_weights(self, cols): for (I, X, Ys) in col_waiting_for[Y]: if self.unit_Ys[Ys]: node = (I, X) - value = q.get(node) - if value is None: - value = self._helper(node, cols, q) + value = self._helper(node, cols, q) total += col_i_chart[I, X, Ys] * value p[Y] = total return p def _helper(self, top, cols, q): + + value = q.get(top) + if value is not None: + return value + zero = self.cfg.R.zero stack = [Node(top, None, zero)] From 86b6d4e9d2fc365df04292128d0e79cb7fd74b9e Mon Sep 17 00:00:00 2001 From: Tim Vieira Date: Tue, 18 Jun 2024 13:49:53 -0400 Subject: [PATCH 2/5] yikes there are even more random seeds! --- genparse/steer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/genparse/steer.py b/genparse/steer.py index c718a390..006349b9 100644 --- a/genparse/steer.py +++ b/genparse/steer.py @@ -29,9 +29,11 @@ def set_seed(seed): random.seed(seed) + np.random.seed(seed) torch.manual_seed(seed) transformers.set_seed(seed) - np.random.seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) # ____________________________________________________________________________________ From a6c3bcb7916999af2ca6a78ad0f3b18a43373fdb Mon Sep 17 00:00:00 2001 From: benlipkin Date: Tue, 18 Jun 2024 16:24:25 -0400 Subject: [PATCH 3/5] linter ignore notes, bench, benchmark; only track genparse, tests --- genparse/experimental/earley.py | 40 ++++++++++++++++++++------------- genparse/lm.py | 1 + ruff.toml | 3 +++ 3 files changed, 29 insertions(+), 15 deletions(-) diff --git a/genparse/experimental/earley.py b/genparse/experimental/earley.py index fbcb1886..552b74ae 100644 --- a/genparse/experimental/earley.py +++ b/genparse/experimental/earley.py @@ -1,10 +1,9 @@ import numpy as np -from arsenal import Integerizer, colors +from arsenal import Integerizer from collections import defaultdict -from functools import lru_cache -#from arsenal.datastructures.pdict import pdict +# from arsenal.datastructures.pdict import pdict from arsenal.datastructures.heap import LocatorMaxHeap from genparse.cfglm import EOS, add_EOS @@ -14,7 +13,6 @@ class EarleyLM(LM): - def __init__(self, cfg): if EOS not in cfg.V: cfg = add_EOS(cfg) @@ -33,7 +31,7 @@ def clear_cache(self): class Column: - __slots__ = ("k", "i_chart", "c_chart", "waiting_for", "Q") + __slots__ = ('k', 'i_chart', 'c_chart', 'waiting_for', 'Q') def __init__(self, k): self.k = k @@ -45,7 +43,7 @@ def __init__(self, k): self.waiting_for = defaultdict(set) # priority queue used when first filling the column -# self.Q = pdict() + # self.Q = pdict() self.Q = LocatorMaxHeap() @@ -55,11 +53,23 @@ class Earley: Warning: Assumes that nullary rules and unary chain cycles have been removed """ - __slots__ = ("cfg", "order", "_chart", "V", "eos", "_initial_column", "R", 'rhs', - 'ORDER_MAX', 'intern_Ys', 'unit_Ys', 'first_Ys', 'rest_Ys') + __slots__ = ( + 'cfg', + 'order', + '_chart', + 'V', + 'eos', + '_initial_column', + 'R', + 'rhs', + 'ORDER_MAX', + 'intern_Ys', + 'unit_Ys', + 'first_Ys', + 'rest_Ys', + ) def __init__(self, cfg): - cfg = cfg.nullaryremove(binarize=True).unarycycleremove().renumber() self.cfg = cfg @@ -96,7 +106,8 @@ def __init__(self, cfg): for X in self.cfg.N: self.rhs[X] = [] for r in self.cfg.rhs[X]: - if r.body == (): continue + if r.body == (): + continue self.rhs[X].append((r.w, intern_Ys(r.body))) self.first_Ys = np.zeros(len(intern_Ys), dtype=object) @@ -104,7 +115,7 @@ def __init__(self, cfg): self.unit_Ys = np.zeros(len(intern_Ys), dtype=int) for Ys, code in list(self.intern_Ys.items()): - self.unit_Ys[code] = (len(Ys) == 1) + self.unit_Ys[code] = len(Ys) == 1 if len(Ys) > 0: self.first_Ys[code] = Ys[0] self.rest_Ys[code] = intern_Ys(Ys[1:]) @@ -152,7 +163,6 @@ def p_next(self, prefix): return self.next_token_weights(self.chart(prefix)) def next_column(self, prev_cols, token): - prev_col = prev_cols[-1] next_col = Column(prev_cols[-1].k + 1) next_col_c_chart = next_col.c_chart @@ -277,7 +287,7 @@ def next_token_weights(self, cols): for Y in col_waiting_for: if is_terminal(Y): total = zero - for (I, X, Ys) in col_waiting_for[Y]: + for I, X, Ys in col_waiting_for[Y]: if self.unit_Ys[Ys]: node = (I, X) value = self._helper(node, cols, q) @@ -287,7 +297,6 @@ def next_token_weights(self, cols): return p def _helper(self, top, cols, q): - value = q.get(top) if value is not None: return value @@ -296,7 +305,7 @@ def _helper(self, top, cols, q): stack = [Node(top, None, zero)] while stack: - node = stack[-1] # 👀 + node = stack[-1] # 👀 # place neighbors above the node on the stack (J, Y) = node.node @@ -330,6 +339,7 @@ def _helper(self, top, cols, q): class Node: __slots__ = ('value', 'node', 'edges', 'cursor') + def __init__(self, node, edges, value): self.node = node self.edges = edges diff --git a/genparse/lm.py b/genparse/lm.py index 3b9ea360..86dbab3c 100644 --- a/genparse/lm.py +++ b/genparse/lm.py @@ -311,6 +311,7 @@ def __repr__(self): from functools import lru_cache + @lru_cache(None) def make_mock_llm(**kwargs): from genparse.util import hf_tokenizer diff --git a/ruff.toml b/ruff.toml index ef616296..86adfdb4 100644 --- a/ruff.toml +++ b/ruff.toml @@ -17,6 +17,9 @@ exclude = [ "profile.html", "profile.json", "*.log", + "notes", + "bench", + "benchmark", ] line-length = 90 From cb897783c94de53a3ef9b26119c33515af104850 Mon Sep 17 00:00:00 2001 From: Tim Vieira Date: Tue, 18 Jun 2024 16:36:42 -0400 Subject: [PATCH 4/5] fixed most of the lint errors (issue #12) --- bench/run_spider_llama2_chat.py | 4 +- bench/spider/content_encoder.py | 2 +- bench/spider/evaluation.py | 448 ++++++++++++++++------- bench/spider/evaluator.py | 11 +- bench/spider/interface.py | 25 +- bench/spider/process_sql.py | 199 ++++++---- benchmark/benchmark_inference.py | 2 - benchmark/sql_parsing_speed.py | 16 +- genparse/cfg.py | 14 +- genparse/experimental/earley.py | 42 ++- genparse/experimental/earley_rescaled.py | 2 - genparse/lm.py | 3 +- genparse/proposal/trie_numba.py | 6 +- genparse/record.py | 4 +- genparse/steer.py | 10 +- genparse/tokenization.py | 18 +- notes/LM-Fun.ipynb | 343 ----------------- notes/benchmark_hfppl.py | 12 - notes/fst_pruned_composition.py | 5 +- notes/grammar_processing_issues.ipynb | 14 - notes/hfppl.ipynb | 1 + notes/hfppl_benleb.ipynb | 13 +- notes/sql_debug.ipynb | 18 +- notes/test_grammar_coverage.py | 10 +- ruff.toml | 3 +- tests/test_inference.py | 2 - tests/test_wcfg.py | 2 +- tests/test_wfsa_field.py | 2 +- 28 files changed, 559 insertions(+), 672 deletions(-) delete mode 100644 notes/LM-Fun.ipynb diff --git a/bench/run_spider_llama2_chat.py b/bench/run_spider_llama2_chat.py index 5577f500..c6d69af1 100644 --- a/bench/run_spider_llama2_chat.py +++ b/bench/run_spider_llama2_chat.py @@ -128,9 +128,9 @@ def main(): model = 'meta-llama/Llama-2-7b-chat-hf' model = model.replace('7b', args.model_size) access_token = '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' - logger.info(f"using model {model}") + logger.info(f'using model {model}') - tokenizer = AutoTokenizer.from_pretrained(model, token=access_token) + # tokenizer = AutoTokenizer.from_pretrained(model, token=access_token) pipe = pipeline( 'text-generation', model=model, diff --git a/bench/spider/content_encoder.py b/bench/spider/content_encoder.py index 8a6276df..7043dce7 100644 --- a/bench/spider/content_encoder.py +++ b/bench/spider/content_encoder.py @@ -46,7 +46,7 @@ def is_number(s: str) -> bool: try: float(s.replace(',', '')) return True - except: + except ValueError: return False diff --git a/bench/spider/evaluation.py b/bench/spider/evaluation.py index f84ecd7a..c5d81f7e 100644 --- a/bench/spider/evaluation.py +++ b/bench/spider/evaluation.py @@ -20,13 +20,12 @@ ################################ from __future__ import print_function -import os, sys +import os import json import sqlite3 -import traceback import argparse -from bench.spider.process_sql import tokenize, get_schema, get_tables_with_alias, Schema, get_sql +from bench.spider.process_sql import get_schema, Schema, get_sql # Flag to disable value evaluation DISABLE_VALUE = True @@ -34,15 +33,38 @@ DISABLE_DISTINCT = True -CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except') +CLAUSE_KEYWORDS = ( + 'select', + 'from', + 'where', + 'group', + 'order', + 'limit', + 'intersect', + 'union', + 'except', +) JOIN_KEYWORDS = ('join', 'on', 'as') -WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists') -UNIT_OPS = ('none', '-', '+', "*", '/') +WHERE_OPS = ( + 'not', + 'between', + '=', + '>', + '<', + '>=', + '<=', + '!=', + 'in', + 'like', + 'is', + 'exists', +) +UNIT_OPS = ('none', '-', '+', '*', '/') AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg') TABLE_TYPE = { - 'sql': "sql", - 'table_unit': "table_unit", + 'sql': 'sql', + 'table_unit': 'table_unit', } COND_OPS = ('and', 'or') @@ -51,8 +73,8 @@ HARDNESS = { - "component1": ('where', 'group', 'order', 'limit', 'join', 'or', 'like'), - "component2": ('except', 'union', 'intersect') + 'component1': ('where', 'group', 'order', 'limit', 'join', 'or', 'like'), + 'component2': ('except', 'union', 'intersect'), } @@ -67,9 +89,9 @@ def condition_has_like(conds): def condition_has_sql(conds): for cond_unit in conds[::2]: val1, val2 = cond_unit[3], cond_unit[4] - if val1 is not None and type(val1) is dict: + if val1 is not None and isinstance(val1, dict): return True - if val2 is not None and type(val2) is dict: + if val2 is not None and isinstance(val2, dict): return True return False @@ -97,15 +119,15 @@ def recall(count, total): def F1(acc, rec): if (acc + rec) == 0: return 0 - return (2. * acc * rec) / (acc + rec) + return (2.0 * acc * rec) / (acc + rec) def get_scores(count, pred_total, label_total): if pred_total != label_total: - return 0,0,0 + return 0, 0, 0 elif count == pred_total: - return 1,1,1 - return 0,0,0 + return 1, 1, 1 + return 0, 0, 0 def eval_sel(pred, label): @@ -154,8 +176,8 @@ def eval_group(pred, label): pred_total = len(pred_cols) label_total = len(label_cols) cnt = 0 - pred_cols = [pred.split(".")[1] if "." in pred else pred for pred in pred_cols] - label_cols = [label.split(".")[1] if "." in label else label for label in label_cols] + pred_cols = [pred.split('.')[1] if '.' in pred else pred for pred in pred_cols] + label_cols = [label.split('.')[1] if '.' in label else label for label in label_cols] for col in pred_cols: if col in label_cols: cnt += 1 @@ -172,9 +194,11 @@ def eval_having(pred, label): pred_cols = [unit[1] for unit in pred['groupBy']] label_cols = [unit[1] for unit in label['groupBy']] - if pred_total == label_total == 1 \ - and pred_cols == label_cols \ - and pred['having'] == label['having']: + if ( + pred_total == label_total == 1 + and pred_cols == label_cols + and pred['having'] == label['having'] + ): cnt = 1 return label_total, pred_total, cnt @@ -186,8 +210,14 @@ def eval_order(pred, label): pred_total = 1 if len(label['orderBy']) > 0: label_total = 1 - if len(label['orderBy']) > 0 and pred['orderBy'] == label['orderBy'] and \ - ((pred['limit'] is None and label['limit'] is None) or (pred['limit'] is not None and label['limit'] is not None)): + if ( + len(label['orderBy']) > 0 + and pred['orderBy'] == label['orderBy'] + and ( + (pred['limit'] is None and label['limit'] is None) + or (pred['limit'] is not None and label['limit'] is not None) + ) + ): cnt = 1 return label_total, pred_total, cnt @@ -199,16 +229,16 @@ def eval_and_or(pred, label): label_ao = set(label_ao) if pred_ao == label_ao: - return 1,1,1 - return len(pred_ao),len(label_ao),0 + return 1, 1, 1 + return len(pred_ao), len(label_ao), 0 def get_nestedSQL(sql): nested = [] for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]: - if type(cond_unit[3]) is dict: + if isinstance(cond_unit[3], dict): nested.append(cond_unit[3]) - if type(cond_unit[4]) is dict: + if isinstance(cond_unit[4], dict): nested.append(cond_unit[4]) if sql['intersect'] is not None: nested.append(sql['intersect']) @@ -273,11 +303,29 @@ def get_keywords(sql): res.add('not') # in keyword - if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('in')]) > 0: + if ( + len( + [ + cond_unit + for cond_unit in cond_units + if cond_unit[1] == WHERE_OPS.index('in') + ] + ) + > 0 + ): res.add('in') # like keyword - if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) > 0: + if ( + len( + [ + cond_unit + for cond_unit in cond_units + if cond_unit[1] == WHERE_OPS.index('like') + ] + ) + > 0 + ): res.add('like') return res @@ -316,7 +364,9 @@ def count_component1(sql): ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2] count += len([token for token in ao if token == 'or']) cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2] - count += len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) + count += len( + [cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')] + ) return count @@ -333,8 +383,10 @@ def count_others(sql): agg_count += count_agg(sql['where'][::2]) agg_count += count_agg(sql['groupBy']) if len(sql['orderBy']) > 0: - agg_count += count_agg([unit[1] for unit in sql['orderBy'][1] if unit[1]] + - [unit[2] for unit in sql['orderBy'][1] if unit[2]]) + agg_count += count_agg( + [unit[1] for unit in sql['orderBy'][1] if unit[1]] + + [unit[2] for unit in sql['orderBy'][1] if unit[2]] + ) agg_count += count_agg(sql['having']) if agg_count > 1: count += 1 @@ -356,6 +408,7 @@ def count_others(sql): class Evaluator: """A simple evaluator""" + def __init__(self): self.partial_scores = None @@ -365,16 +418,19 @@ def eval_hardness(self, sql): count_others_ = count_others(sql) if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0: - return "easy" - elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or \ - (count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0): - return "medium" - elif (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) or \ - (2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) or \ - (count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1): - return "hard" + return 'easy' + elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or ( + count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0 + ): + return 'medium' + elif ( + (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) + or (2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) + or (count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1) + ): + return 'hard' else: - return "extra" + return 'extra' def eval_exact_match(self, pred, label): partial_scores = self.eval_partial_match(pred, label) @@ -394,39 +450,99 @@ def eval_partial_match(self, pred, label): label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label) acc, rec, f1 = get_scores(cnt, pred_total, label_total) - res['select'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + res['select'] = { + 'acc': acc, + 'rec': rec, + 'f1': f1, + 'label_total': label_total, + 'pred_total': pred_total, + } acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total) - res['select(no AGG)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + res['select(no AGG)'] = { + 'acc': acc, + 'rec': rec, + 'f1': f1, + 'label_total': label_total, + 'pred_total': pred_total, + } label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label) acc, rec, f1 = get_scores(cnt, pred_total, label_total) - res['where'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + res['where'] = { + 'acc': acc, + 'rec': rec, + 'f1': f1, + 'label_total': label_total, + 'pred_total': pred_total, + } acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total) - res['where(no OP)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + res['where(no OP)'] = { + 'acc': acc, + 'rec': rec, + 'f1': f1, + 'label_total': label_total, + 'pred_total': pred_total, + } label_total, pred_total, cnt = eval_group(pred, label) acc, rec, f1 = get_scores(cnt, pred_total, label_total) - res['group(no Having)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + res['group(no Having)'] = { + 'acc': acc, + 'rec': rec, + 'f1': f1, + 'label_total': label_total, + 'pred_total': pred_total, + } label_total, pred_total, cnt = eval_having(pred, label) acc, rec, f1 = get_scores(cnt, pred_total, label_total) - res['group'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + res['group'] = { + 'acc': acc, + 'rec': rec, + 'f1': f1, + 'label_total': label_total, + 'pred_total': pred_total, + } label_total, pred_total, cnt = eval_order(pred, label) acc, rec, f1 = get_scores(cnt, pred_total, label_total) - res['order'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + res['order'] = { + 'acc': acc, + 'rec': rec, + 'f1': f1, + 'label_total': label_total, + 'pred_total': pred_total, + } label_total, pred_total, cnt = eval_and_or(pred, label) acc, rec, f1 = get_scores(cnt, pred_total, label_total) - res['and/or'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + res['and/or'] = { + 'acc': acc, + 'rec': rec, + 'f1': f1, + 'label_total': label_total, + 'pred_total': pred_total, + } label_total, pred_total, cnt = eval_IUEN(pred, label) acc, rec, f1 = get_scores(cnt, pred_total, label_total) - res['IUEN'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + res['IUEN'] = { + 'acc': acc, + 'rec': rec, + 'f1': f1, + 'label_total': label_total, + 'pred_total': pred_total, + } label_total, pred_total, cnt = eval_keywords(pred, label) acc, rec, f1 = get_scores(cnt, pred_total, label_total) - res['keywords'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + res['keywords'] = { + 'acc': acc, + 'rec': rec, + 'f1': f1, + 'label_total': label_total, + 'pred_total': pred_total, + } return res @@ -436,43 +552,73 @@ def isValidSQL(sql, db): cursor = conn.cursor() try: cursor.execute(sql) - except: + except Exception: return False return True def print_scores(scores, etype): levels = ['easy', 'medium', 'hard', 'extra', 'all'] - partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)', - 'group', 'order', 'and/or', 'IUEN', 'keywords'] - - print("{:20} {:20} {:20} {:20} {:20} {:20}".format("", *levels)) + partial_types = [ + 'select', + 'select(no AGG)', + 'where', + 'where(no OP)', + 'group(no Having)', + 'group', + 'order', + 'and/or', + 'IUEN', + 'keywords', + ] + + print('{:20} {:20} {:20} {:20} {:20} {:20}'.format('', *levels)) counts = [scores[level]['count'] for level in levels] - print("{:20} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d}".format("count", *counts)) + print('{:20} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d}'.format('count', *counts)) - if etype in ["all", "exec"]: + if etype in ['all', 'exec']: print('===================== EXECUTION ACCURACY =====================') this_scores = [scores[level]['exec'] for level in levels] - print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("execution", *this_scores)) + print( + '{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}'.format( + 'execution', *this_scores + ) + ) - if etype in ["all", "match"]: + if etype in ['all', 'match']: print('\n====================== EXACT MATCHING ACCURACY =====================') exact_scores = [scores[level]['exact'] for level in levels] - print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("exact match", *exact_scores)) + print( + '{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}'.format( + 'exact match', *exact_scores + ) + ) print('\n---------------------PARTIAL MATCHING ACCURACY----------------------') for type_ in partial_types: this_scores = [scores[level]['partial'][type_]['acc'] for level in levels] - print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores)) + print( + '{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}'.format( + type_, *this_scores + ) + ) print('---------------------- PARTIAL MATCHING RECALL ----------------------') for type_ in partial_types: this_scores = [scores[level]['partial'][type_]['rec'] for level in levels] - print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores)) + print( + '{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}'.format( + type_, *this_scores + ) + ) print('---------------------- PARTIAL MATCHING F1 --------------------------') for type_ in partial_types: this_scores = [scores[level]['partial'][type_]['f1'] for level in levels] - print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores)) + print( + '{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}'.format( + type_, *this_scores + ) + ) def evaluate(gold, predict, db_dir, etype, kmaps): @@ -486,23 +632,39 @@ def evaluate(gold, predict, db_dir, etype, kmaps): evaluator = Evaluator() levels = ['easy', 'medium', 'hard', 'extra', 'all'] - partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)', - 'group', 'order', 'and/or', 'IUEN', 'keywords'] + partial_types = [ + 'select', + 'select(no AGG)', + 'where', + 'where(no OP)', + 'group(no Having)', + 'group', + 'order', + 'and/or', + 'IUEN', + 'keywords', + ] entries = [] scores = {} for level in levels: - scores[level] = {'count': 0, 'partial': {}, 'exact': 0.} + scores[level] = {'count': 0, 'partial': {}, 'exact': 0.0} scores[level]['exec'] = 0 for type_ in partial_types: - scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0.,'acc_count':0,'rec_count':0} + scores[level]['partial'][type_] = { + 'acc': 0.0, + 'rec': 0.0, + 'f1': 0.0, + 'acc_count': 0, + 'rec_count': 0, + } eval_err_num = 0 for p, g in zip(plist, glist): p_str = p[0] g_str, db = g db_name = db - db = os.path.join(db_dir, db, db + ".sqlite") + db = os.path.join(db_dir, db, db + '.sqlite') schema = Schema(get_schema(db)) g_sql = get_sql(schema, g_str) hardness = evaluator.eval_hardness(g_sql) @@ -511,28 +673,22 @@ def evaluate(gold, predict, db_dir, etype, kmaps): try: p_sql = get_sql(schema, p_str) - except: + except Exception: # If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql p_sql = { - "except": None, - "from": { - "conds": [], - "table_units": [] - }, - "groupBy": [], - "having": [], - "intersect": None, - "limit": None, - "orderBy": [], - "select": [ - False, - [] - ], - "union": None, - "where": [] + 'except': None, + 'from': {'conds': [], 'table_units': []}, + 'groupBy': [], + 'having': [], + 'intersect': None, + 'limit': None, + 'orderBy': [], + 'select': [False, []], + 'union': None, + 'where': [], } eval_err_num += 1 - print("eval_err_num:{}".format(eval_err_num)) + print('eval_err_num:{}'.format(eval_err_num)) # rebuild sql for value evaluation kmap = kmaps[db_name] @@ -543,27 +699,31 @@ def evaluate(gold, predict, db_dir, etype, kmaps): p_sql = rebuild_sql_val(p_sql) p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap) - if etype in ["all", "exec"]: + if etype in ['all', 'exec']: exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql) if exec_score: scores[hardness]['exec'] += 1.0 scores['all']['exec'] += 1.0 - if etype in ["all", "match"]: + if etype in ['all', 'match']: exact_score = evaluator.eval_exact_match(p_sql, g_sql) partial_scores = evaluator.partial_scores if exact_score == 0: - print("{} pred: {}".format(hardness,p_str)) - print("{} gold: {}".format(hardness,g_str)) - print("") + print('{} pred: {}'.format(hardness, p_str)) + print('{} gold: {}'.format(hardness, g_str)) + print('') scores[hardness]['exact'] += exact_score scores['all']['exact'] += exact_score for type_ in partial_types: if partial_scores[type_]['pred_total'] > 0: - scores[hardness]['partial'][type_]['acc'] += partial_scores[type_]['acc'] + scores[hardness]['partial'][type_]['acc'] += partial_scores[type_][ + 'acc' + ] scores[hardness]['partial'][type_]['acc_count'] += 1 if partial_scores[type_]['label_total'] > 0: - scores[hardness]['partial'][type_]['rec'] += partial_scores[type_]['rec'] + scores[hardness]['partial'][type_]['rec'] += partial_scores[type_][ + 'rec' + ] scores[hardness]['partial'][type_]['rec_count'] += 1 scores[hardness]['partial'][type_]['f1'] += partial_scores[type_]['f1'] if partial_scores[type_]['pred_total'] > 0: @@ -574,39 +734,56 @@ def evaluate(gold, predict, db_dir, etype, kmaps): scores['all']['partial'][type_]['rec_count'] += 1 scores['all']['partial'][type_]['f1'] += partial_scores[type_]['f1'] - entries.append({ - 'predictSQL': p_str, - 'goldSQL': g_str, - 'hardness': hardness, - 'exact': exact_score, - 'partial': partial_scores - }) + entries.append( + { + 'predictSQL': p_str, + 'goldSQL': g_str, + 'hardness': hardness, + 'exact': exact_score, + 'partial': partial_scores, + } + ) for level in levels: if scores[level]['count'] == 0: continue - if etype in ["all", "exec"]: + if etype in ['all', 'exec']: scores[level]['exec'] /= scores[level]['count'] - if etype in ["all", "match"]: + if etype in ['all', 'match']: scores[level]['exact'] /= scores[level]['count'] for type_ in partial_types: if scores[level]['partial'][type_]['acc_count'] == 0: scores[level]['partial'][type_]['acc'] = 0 else: - scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \ - scores[level]['partial'][type_]['acc_count'] * 1.0 + scores[level]['partial'][type_]['acc'] = ( + scores[level]['partial'][type_]['acc'] + / scores[level]['partial'][type_]['acc_count'] + * 1.0 + ) if scores[level]['partial'][type_]['rec_count'] == 0: scores[level]['partial'][type_]['rec'] = 0 else: - scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \ - scores[level]['partial'][type_]['rec_count'] * 1.0 - if scores[level]['partial'][type_]['acc'] == 0 and scores[level]['partial'][type_]['rec'] == 0: + scores[level]['partial'][type_]['rec'] = ( + scores[level]['partial'][type_]['rec'] + / scores[level]['partial'][type_]['rec_count'] + * 1.0 + ) + if ( + scores[level]['partial'][type_]['acc'] == 0 + and scores[level]['partial'][type_]['rec'] == 0 + ): scores[level]['partial'][type_]['f1'] = 1 else: - scores[level]['partial'][type_]['f1'] = \ - 2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / ( - scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc']) + scores[level]['partial'][type_]['f1'] = ( + 2.0 + * scores[level]['partial'][type_]['acc'] + * scores[level]['partial'][type_]['rec'] + / ( + scores[level]['partial'][type_]['rec'] + + scores[level]['partial'][type_]['acc'] + ) + ) print_scores(scores, etype) @@ -621,7 +798,7 @@ def eval_exec_match(db, p_str, g_str, pred, gold): try: cursor.execute(p_str) p_res = cursor.fetchall() - except: + except Exception: return False cursor.execute(g_str) @@ -630,7 +807,11 @@ def eval_exec_match(db, p_str, g_str, pred, gold): def res_map(res, val_units): rmap = {} for idx, val_unit in enumerate(val_units): - key = tuple(val_unit[1]) if not val_unit[2] else (val_unit[0], tuple(val_unit[1]), tuple(val_unit[2])) + key = ( + tuple(val_unit[1]) + if not val_unit[2] + else (val_unit[0], tuple(val_unit[1]), tuple(val_unit[2])) + ) rmap[key] = [r[idx] for r in res] return rmap @@ -645,11 +826,11 @@ def rebuild_cond_unit_val(cond_unit): return cond_unit not_op, op_id, val_unit, val1, val2 = cond_unit - if type(val1) is not dict: + if not isinstance(val1, dict): val1 = None else: val1 = rebuild_sql_val(val1) - if type(val2) is not dict: + if not isinstance(val2, dict): val2 = None else: val2 = rebuild_sql_val(val2) @@ -685,11 +866,15 @@ def rebuild_sql_val(sql): # Rebuild SQL functions for foreign key evaluation def build_valid_col_units(table_units, schema): - col_ids = [table_unit[1] for table_unit in table_units if table_unit[0] == TABLE_TYPE['table_unit']] + col_ids = [ + table_unit[1] + for table_unit in table_units + if table_unit[0] == TABLE_TYPE['table_unit'] + ] prefixs = [col_id[:-2] for col_id in col_ids] - valid_col_units= [] + valid_col_units = [] for value in schema.idMap.values(): - if '.' in value and value[:value.index('.')] in prefixs: + if '.' in value and value[: value.index('.')] in prefixs: valid_col_units.append(value) return valid_col_units @@ -759,7 +944,10 @@ def rebuild_from_col(valid_col_units, from_, kmap): if from_ is None: return from_ - from_['table_units'] = [rebuild_table_unit_col(valid_col_units, table_unit, kmap) for table_unit in from_['table_units']] + from_['table_units'] = [ + rebuild_table_unit_col(valid_col_units, table_unit, kmap) + for table_unit in from_['table_units'] + ] from_['conds'] = rebuild_condition_col(valid_col_units, from_['conds'], kmap) return from_ @@ -768,7 +956,9 @@ def rebuild_group_by_col(valid_col_units, group_by, kmap): if group_by is None: return group_by - return [rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by] + return [ + rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by + ] def rebuild_order_by_col(valid_col_units, order_by, kmap): @@ -776,7 +966,9 @@ def rebuild_order_by_col(valid_col_units, order_by, kmap): return order_by direction, val_units = order_by - new_val_units = [rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units] + new_val_units = [ + rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units + ] return direction, new_val_units @@ -798,8 +990,8 @@ def rebuild_sql_col(valid_col_units, sql, kmap): def build_foreign_key_map(entry): - cols_orig = entry["column_names_original"] - tables_orig = entry["table_names_original"] + cols_orig = entry['column_names_original'] + tables_orig = entry['table_names_original'] # rebuild cols corresponding to idmap in Schema cols = [] @@ -807,9 +999,9 @@ def build_foreign_key_map(entry): if col_orig[0] >= 0: t = tables_orig[col_orig[0]] c = col_orig[1] - cols.append("__" + t.lower() + "." + c.lower() + "__") + cols.append('__' + t.lower() + '.' + c.lower() + '__') else: - cols.append("__all__") + cols.append('__all__') def keyset_in_list(k1, k2, k_list): for k_set in k_list: @@ -820,7 +1012,7 @@ def keyset_in_list(k1, k2, k_list): return new_k_set foreign_key_list = [] - foreign_keys = entry["foreign_keys"] + foreign_keys = entry['foreign_keys'] for fkey in foreign_keys: key1, key2 = fkey key_set = keyset_in_list(key1, key2, foreign_key_list) @@ -846,7 +1038,7 @@ def build_foreign_key_map_from_json(table): return tables -if __name__ == "__main__": +if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--gold', dest='gold', type=str) parser.add_argument('--pred', dest='pred', type=str) @@ -861,7 +1053,7 @@ def build_foreign_key_map_from_json(table): table = args.table etype = args.etype - assert etype in ["all", "exec", "match"], "Unknown evaluation method" + assert etype in ['all', 'exec', 'match'], 'Unknown evaluation method' kmaps = build_foreign_key_map_from_json(table) diff --git a/bench/spider/evaluator.py b/bench/spider/evaluator.py index 4fa46f07..95d3ca1d 100644 --- a/bench/spider/evaluator.py +++ b/bench/spider/evaluator.py @@ -9,6 +9,7 @@ eval_exec_match, ) + class Evaluator: def __init__(self, spider_dir: Path): self.tables_path = spider_dir / 'tables.json' @@ -23,15 +24,15 @@ def evaluate(self, gold: str, pred: str, db_name: str): * `invalid` if `pred` sql is not a well-formed sql statement that can be parsed by sqlite * `mismatch` if `pred` is a well-formed sql but the execution result is different from that of the `gold`. """ - db = self.db_path / db_name / (db_name + ".sqlite") + db = self.db_path / db_name / (db_name + '.sqlite') schema = E.Schema(E.get_schema(db)) g_sql = E.get_sql(schema, gold) try: p_sql = E.get_sql(schema, pred) - except: + except Exception: # sql is ill-formed (can't be parsed by sqlite engine) - return False, "invalid" + return False, 'invalid' kmap = self.kmaps[db_name] g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema) @@ -42,6 +43,6 @@ def evaluate(self, gold: str, pred: str, db_name: str): p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap) exec_match = eval_exec_match(db, pred, gold, p_sql, g_sql) - reason = None if exec_match else "mismatch" + reason = None if exec_match else 'mismatch' - return exec_match, reason \ No newline at end of file + return exec_match, reason diff --git a/bench/spider/interface.py b/bench/spider/interface.py index 2e9aee1c..a4d18f01 100644 --- a/bench/spider/interface.py +++ b/bench/spider/interface.py @@ -1,6 +1,5 @@ from path import Path -import bench from bench.spider.dialogue import load_spider_data from bench.spider.schema import load_schemas from bench.spider.evaluator import Evaluator @@ -9,16 +8,24 @@ class SpiderInterface: - def __init__(self, root=None): - if root is None: root = Path(genparse.__file__).dirname() / '..' / 'bench/spider/data/spider' + if root is None: + root = Path(genparse.__file__).dirname() / '..' / 'bench/spider/data/spider' if not root.exists(): - raise AssertionError('spider dataset not found follow the instruction in bench/README') - self.schemas = load_schemas(schemas_path=root / 'tables.json', db_path=root / 'database') + raise AssertionError( + 'spider dataset not found follow the instruction in bench/README' + ) + self.schemas = load_schemas( + schemas_path=root / 'tables.json', db_path=root / 'database' + ) self.evaluator = Evaluator(root) self.evaluate = self.evaluator.evaluate - self.dev_data = [SpiderExample(self, x) for x in load_spider_data(root / 'dev.json')] - self.train_data = [SpiderExample(self, x) for x in load_spider_data(root / 'train_spider.json')] + self.dev_data = [ + SpiderExample(self, x) for x in load_spider_data(root / 'dev.json') + ] + self.train_data = [ + SpiderExample(self, x) for x in load_spider_data(root / 'train_spider.json') + ] class SpiderExample: @@ -41,7 +48,9 @@ def describe_schema(self): for table in self.db_schema.tables: column_strs = [] for column in table.columns: - column_strs.append(f'* {column.name} ({column.tpe.value}): {column.nl_name}') + column_strs.append( + f'* {column.name} ({column.tpe.value}): {column.nl_name}' + ) table_str = '\n'.join([table.name] + column_strs) table_strs.append(table_str) return '\n\n'.join(table_strs) diff --git a/bench/spider/process_sql.py b/bench/spider/process_sql.py index 839612e6..2becb21c 100644 --- a/bench/spider/process_sql.py +++ b/bench/spider/process_sql.py @@ -28,15 +28,38 @@ import sqlite3 from nltk import word_tokenize -CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except') +CLAUSE_KEYWORDS = ( + 'select', + 'from', + 'where', + 'group', + 'order', + 'limit', + 'intersect', + 'union', + 'except', +) JOIN_KEYWORDS = ('join', 'on', 'as') -WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists') -UNIT_OPS = ('none', '-', '+', "*", '/') +WHERE_OPS = ( + 'not', + 'between', + '=', + '>', + '<', + '>=', + '<=', + '!=', + 'in', + 'like', + 'is', + 'exists', +) +UNIT_OPS = ('none', '-', '+', '*', '/') AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg') TABLE_TYPE = { - 'sql': "sql", - 'table_unit': "table_unit", + 'sql': 'sql', + 'table_unit': 'table_unit', } COND_OPS = ('and', 'or') @@ -44,11 +67,11 @@ ORDER_OPS = ('desc', 'asc') - class Schema: """ Simple schema which maps table&column to a unique identifier """ + def __init__(self, schema): self._schema = schema self._idMap = self._map(self._schema) @@ -62,15 +85,17 @@ def idMap(self): return self._idMap def _map(self, schema): - idMap = {'*': "__all__"} + idMap = {'*': '__all__'} id = 1 for key, vals in schema.items(): for val in vals: - idMap[key.lower() + "." + val.lower()] = "__" + key.lower() + "." + val.lower() + "__" + idMap[key.lower() + '.' + val.lower()] = ( + '__' + key.lower() + '.' + val.lower() + '__' + ) id += 1 for key in schema: - idMap[key.lower()] = "__" + key.lower() + "__" + idMap[key.lower()] = '__' + key.lower() + '__' id += 1 return idMap @@ -94,7 +119,7 @@ def get_schema(db): # fetch table info for table in tables: - cursor.execute("PRAGMA table_info({})".format(table)) + cursor.execute('PRAGMA table_info({})'.format(table)) schema[table] = [str(col[1].lower()) for col in cursor.fetchall()] return schema @@ -115,18 +140,18 @@ def get_schema_from_json(fpath): def tokenize(string): string = str(string) - string = string.replace("\'", "\"") # ensures all string values wrapped by "" problem?? + string = string.replace("'", '"') # ensures all string values wrapped by "" problem?? quote_idxs = [idx for idx, char in enumerate(string) if char == '"'] - assert len(quote_idxs) % 2 == 0, "Unexpected quote" + assert len(quote_idxs) % 2 == 0, 'Unexpected quote' # keep string value as token vals = {} - for i in range(len(quote_idxs)-1, -1, -2): - qidx1 = quote_idxs[i-1] + for i in range(len(quote_idxs) - 1, -1, -2): + qidx1 = quote_idxs[i - 1] qidx2 = quote_idxs[i] - val = string[qidx1: qidx2+1] - key = "__val_{}_{}__".format(qidx1, qidx2) - string = string[:qidx1] + key + string[qidx2+1:] + val = string[qidx1 : qidx2 + 1] + key = '__val_{}_{}__'.format(qidx1, qidx2) + string = string[:qidx1] + key + string[qidx2 + 1 :] vals[key] = val toks = [word.lower() for word in word_tokenize(string)] @@ -136,13 +161,13 @@ def tokenize(string): toks[i] = vals[toks[i]] # find if there exists !=, >=, <= - eq_idxs = [idx for idx, tok in enumerate(toks) if tok == "="] + eq_idxs = [idx for idx, tok in enumerate(toks) if tok == '='] eq_idxs.reverse() prefix = ('!', '>', '<') for eq_idx in eq_idxs: - pre_tok = toks[eq_idx-1] + pre_tok = toks[eq_idx - 1] if pre_tok in prefix: - toks = toks[:eq_idx-1] + [pre_tok + "="] + toks[eq_idx+1: ] + toks = toks[: eq_idx - 1] + [pre_tok + '='] + toks[eq_idx + 1 :] return toks @@ -152,45 +177,47 @@ def scan_alias(toks): as_idxs = [idx for idx, tok in enumerate(toks) if tok == 'as'] alias = {} for idx in as_idxs: - alias[toks[idx+1]] = toks[idx-1] + alias[toks[idx + 1]] = toks[idx - 1] return alias def get_tables_with_alias(schema, toks): tables = scan_alias(toks) for key in schema: - assert key not in tables, "Alias {} has the same name in table".format(key) + assert key not in tables, 'Alias {} has the same name in table'.format(key) tables[key] = key return tables def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None): """ - :returns next idx, column id + :returns next idx, column id """ tok = toks[start_idx] - if tok == "*": + if tok == '*': return start_idx + 1, schema.idMap[tok] if '.' in tok: # if token is a composite alias, col = tok.split('.') - key = tables_with_alias[alias] + "." + col - return start_idx+1, schema.idMap[key] + key = tables_with_alias[alias] + '.' + col + return start_idx + 1, schema.idMap[key] - assert default_tables is not None and len(default_tables) > 0, "Default tables should not be None or empty" + assert ( + default_tables is not None and len(default_tables) > 0 + ), 'Default tables should not be None or empty' for alias in default_tables: table = tables_with_alias[alias] if tok in schema.schema[table]: - key = table + "." + tok - return start_idx+1, schema.idMap[key] + key = table + '.' + tok + return start_idx + 1, schema.idMap[key] - assert False, "Error col: {}".format(tok) + assert False, 'Error col: {}'.format(tok) def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=None): """ - :returns next idx, (agg_op id, col_id) + :returns next idx, (agg_op id, col_id) """ idx = start_idx len_ = len(toks) @@ -205,7 +232,7 @@ def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=No idx += 1 assert idx < len_ and toks[idx] == '(' idx += 1 - if toks[idx] == "distinct": + if toks[idx] == 'distinct': idx += 1 isDistinct = True idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables) @@ -213,10 +240,10 @@ def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=No idx += 1 return idx, (agg_id, col_id, isDistinct) - if toks[idx] == "distinct": + if toks[idx] == 'distinct': idx += 1 isDistinct = True - agg_id = AGG_OPS.index("none") + agg_id = AGG_OPS.index('none') idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables) if isBlock: @@ -242,7 +269,9 @@ def parse_val_unit(toks, start_idx, tables_with_alias, schema, default_tables=No if idx < len_ and toks[idx] in UNIT_OPS: unit_op = UNIT_OPS.index(toks[idx]) idx += 1 - idx, col_unit2 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables) + idx, col_unit2 = parse_col_unit( + toks, idx, tables_with_alias, schema, default_tables + ) if isBlock: assert toks[idx] == ')' @@ -253,13 +282,13 @@ def parse_val_unit(toks, start_idx, tables_with_alias, schema, default_tables=No def parse_table_unit(toks, start_idx, tables_with_alias, schema): """ - :returns next idx, table id, table name + :returns next idx, table id, table name """ idx = start_idx len_ = len(toks) key = tables_with_alias[toks[idx]] - if idx + 1 < len_ and toks[idx+1] == "as": + if idx + 1 < len_ and toks[idx + 1] == 'as': idx += 3 else: idx += 1 @@ -278,20 +307,28 @@ def parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None) if toks[idx] == 'select': idx, val = parse_sql(toks, idx, tables_with_alias, schema) - elif "\"" in toks[idx]: # token is a string value + elif '"' in toks[idx]: # token is a string value val = toks[idx] idx += 1 else: try: val = float(toks[idx]) idx += 1 - except: + except Exception: end_idx = idx - while end_idx < len_ and toks[end_idx] != ',' and toks[end_idx] != ')'\ - and toks[end_idx] != 'and' and toks[end_idx] not in CLAUSE_KEYWORDS and toks[end_idx] not in JOIN_KEYWORDS: - end_idx += 1 - - idx, val = parse_col_unit(toks[start_idx: end_idx], 0, tables_with_alias, schema, default_tables) + while ( + end_idx < len_ + and toks[end_idx] != ',' + and toks[end_idx] != ')' + and toks[end_idx] != 'and' + and toks[end_idx] not in CLAUSE_KEYWORDS + and toks[end_idx] not in JOIN_KEYWORDS + ): + end_idx += 1 + + idx, val = parse_col_unit( + toks[start_idx:end_idx], 0, tables_with_alias, schema, default_tables + ) idx = end_idx if isBlock: @@ -307,17 +344,23 @@ def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=N conds = [] while idx < len_: - idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables) + idx, val_unit = parse_val_unit( + toks, idx, tables_with_alias, schema, default_tables + ) not_op = False if toks[idx] == 'not': not_op = True idx += 1 - assert idx < len_ and toks[idx] in WHERE_OPS, "Error condition: idx: {}, tok: {}".format(idx, toks[idx]) + assert ( + idx < len_ and toks[idx] in WHERE_OPS + ), 'Error condition: idx: {}, tok: {}'.format(idx, toks[idx]) op_id = WHERE_OPS.index(toks[idx]) idx += 1 val1 = val2 = None - if op_id == WHERE_OPS.index('between'): # between..and... special case: dual values + if op_id == WHERE_OPS.index( + 'between' + ): # between..and... special case: dual values idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables) assert toks[idx] == 'and' idx += 1 @@ -328,7 +371,11 @@ def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=N conds.append((not_op, op_id, val_unit, val1, val2)) - if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";") or toks[idx] in JOIN_KEYWORDS): + if idx < len_ and ( + toks[idx] in CLAUSE_KEYWORDS + or toks[idx] in (')', ';') + or toks[idx] in JOIN_KEYWORDS + ): break if idx < len_ and toks[idx] in COND_OPS: @@ -351,11 +398,13 @@ def parse_select(toks, start_idx, tables_with_alias, schema, default_tables=None val_units = [] while idx < len_ and toks[idx] not in CLAUSE_KEYWORDS: - agg_id = AGG_OPS.index("none") + agg_id = AGG_OPS.index('none') if toks[idx] in AGG_OPS: agg_id = AGG_OPS.index(toks[idx]) idx += 1 - idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables) + idx, val_unit = parse_val_unit( + toks, idx, tables_with_alias, schema, default_tables + ) val_units.append((agg_id, val_unit)) if idx < len_ and toks[idx] == ',': idx += 1 # skip ',' @@ -387,12 +436,16 @@ def parse_from(toks, start_idx, tables_with_alias, schema): else: if idx < len_ and toks[idx] == 'join': idx += 1 # skip join - idx, table_unit, table_name = parse_table_unit(toks, idx, tables_with_alias, schema) - table_units.append((TABLE_TYPE['table_unit'],table_unit)) + idx, table_unit, table_name = parse_table_unit( + toks, idx, tables_with_alias, schema + ) + table_units.append((TABLE_TYPE['table_unit'], table_unit)) default_tables.append(table_name) - if idx < len_ and toks[idx] == "on": + if idx < len_ and toks[idx] == 'on': idx += 1 # skip on - idx, this_conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables) + idx, this_conds = parse_condition( + toks, idx, tables_with_alias, schema, default_tables + ) if len(conds) > 0: conds.append('and') conds.extend(this_conds) @@ -400,7 +453,7 @@ def parse_from(toks, start_idx, tables_with_alias, schema): if isBlock: assert toks[idx] == ')' idx += 1 - if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): + if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (')', ';')): break return idx, table_units, conds, default_tables @@ -430,8 +483,10 @@ def parse_group_by(toks, start_idx, tables_with_alias, schema, default_tables): assert toks[idx] == 'by' idx += 1 - while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): - idx, col_unit = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables) + while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (')', ';')): + idx, col_unit = parse_col_unit( + toks, idx, tables_with_alias, schema, default_tables + ) col_units.append(col_unit) if idx < len_ and toks[idx] == ',': idx += 1 # skip ',' @@ -445,7 +500,7 @@ def parse_order_by(toks, start_idx, tables_with_alias, schema, default_tables): idx = start_idx len_ = len(toks) val_units = [] - order_type = 'asc' # default type is 'asc' + order_type = 'asc' # default type is 'asc' if idx >= len_ or toks[idx] != 'order': return idx, val_units @@ -454,8 +509,10 @@ def parse_order_by(toks, start_idx, tables_with_alias, schema, default_tables): assert toks[idx] == 'by' idx += 1 - while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): - idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables) + while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (')', ';')): + idx, val_unit = parse_val_unit( + toks, idx, tables_with_alias, schema, default_tables + ) val_units.append(val_unit) if idx < len_ and toks[idx] in ORDER_OPS: order_type = toks[idx] @@ -486,13 +543,13 @@ def parse_limit(toks, start_idx): if idx < len_ and toks[idx] == 'limit': idx += 2 - return idx, int(toks[idx-1]) + return idx, int(toks[idx - 1]) return idx, None def parse_sql(toks, start_idx, tables_with_alias, schema): - isBlock = False # indicate whether this is a block of sql/sub-sql + isBlock = False # indicate whether this is a block of sql/sub-sql len_ = len(toks) idx = start_idx @@ -502,23 +559,31 @@ def parse_sql(toks, start_idx, tables_with_alias, schema): idx += 1 # parse from clause in order to get default tables - from_end_idx, table_units, conds, default_tables = parse_from(toks, start_idx, tables_with_alias, schema) + from_end_idx, table_units, conds, default_tables = parse_from( + toks, start_idx, tables_with_alias, schema + ) sql['from'] = {'table_units': table_units, 'conds': conds} # select clause - _, select_col_units = parse_select(toks, idx, tables_with_alias, schema, default_tables) + _, select_col_units = parse_select( + toks, idx, tables_with_alias, schema, default_tables + ) idx = from_end_idx sql['select'] = select_col_units # where clause idx, where_conds = parse_where(toks, idx, tables_with_alias, schema, default_tables) sql['where'] = where_conds # group by clause - idx, group_col_units = parse_group_by(toks, idx, tables_with_alias, schema, default_tables) + idx, group_col_units = parse_group_by( + toks, idx, tables_with_alias, schema, default_tables + ) sql['groupBy'] = group_col_units # having clause idx, having_conds = parse_having(toks, idx, tables_with_alias, schema, default_tables) sql['having'] = having_conds # order by clause - idx, order_col_units = parse_order_by(toks, idx, tables_with_alias, schema, default_tables) + idx, order_col_units = parse_order_by( + toks, idx, tables_with_alias, schema, default_tables + ) sql['orderBy'] = order_col_units # limit clause idx, limit_val = parse_limit(toks, idx) @@ -557,6 +622,6 @@ def get_sql(schema, query): def skip_semicolon(toks, start_idx): idx = start_idx - while idx < len(toks) and toks[idx] == ";": + while idx < len(toks) and toks[idx] == ';': idx += 1 return idx diff --git a/benchmark/benchmark_inference.py b/benchmark/benchmark_inference.py index 2e28d028..5e4974aa 100644 --- a/benchmark/benchmark_inference.py +++ b/benchmark/benchmark_inference.py @@ -120,7 +120,6 @@ def main(): guide = EarleyBoolMaskCFGLM(character_cfg) - MAX_TOKENS = 100 BATCH_SIZE = 80 hfppl_llm.batch_size = BATCH_SIZE @@ -137,7 +136,6 @@ def main(): else: raise ValueError(f'invalid proposal name {args.proposal!r}') - Particles = [] for _ in range(args.reps): for sql_prompt in prompts: prompt = prompt_template % sql_prompt diff --git a/benchmark/sql_parsing_speed.py b/benchmark/sql_parsing_speed.py index c646ed1c..08864cc6 100644 --- a/benchmark/sql_parsing_speed.py +++ b/benchmark/sql_parsing_speed.py @@ -2,24 +2,24 @@ import genparse import pylab as pl import argparse -import logging from time import time from arsenal import iterview, timers, timeit, colors from arsenal.iterextras import unique from genparse.segmentation import prefixes -#from genparse.cfglm import EarleyBoolMaskCFGLM +# from genparse.cfglm import EarleyBoolMaskCFGLM from genparse.util import LarkStuff -from genparse.cfglm import CFGLM from genparse.experimental.earley import EarleyLM def load_examples(example_path): - return unique(map(str.strip, open(example_path, 'r'))) # XXX: why are there duplicates? + return unique( + map(str.strip, open(example_path, 'r')) + ) # XXX: why are there duplicates? -def main(): +def main(): root = Path(genparse.__file__).dirname() / '..' parser = argparse.ArgumentParser( @@ -44,7 +44,7 @@ def main(): with timeit('preprocessing'): cfg = LarkStuff(open(args.grammar).read()).char_cfg(0.9, ignore='[ ]?') guide['earley'] = EarleyLM(cfg) -# guide['cfglm'] = CFGLM(cfg) + # guide['cfglm'] = CFGLM(cfg) T = timers() @@ -56,12 +56,12 @@ def main(): guide[name].clear_cache() for prefix in prefixes(example): - for name in guide: with T[name](n=len(prefix)): p = guide[name].p_next(prefix) - if not p: print(colors.light.red % f'FAILED {i}: {prefix}') + if not p: + print(colors.light.red % f'FAILED {i}: {prefix}') print('total time:', time() - start, 'seconds') diff --git a/genparse/cfg.py b/genparse/cfg.py index 6265eb1e..6be78318 100644 --- a/genparse/cfg.py +++ b/genparse/cfg.py @@ -95,7 +95,19 @@ def _repr_html_(self): class CFG: - def __init__(self, R: 'semiring', S: 'start symbol', V: 'terminal vocabulary'): # type: ignore + """ + Weighted Context-free Grammar + + R: semiring + S: start symbol + V: terminal vocabulary + + N: nonterminal set + rules: set of weighted rules (`list[Rule]`). + + """ + + def __init__(self, R, S, V): self.R = R # semiring self.V = V # alphabet self.N = {S} # nonterminals diff --git a/genparse/experimental/earley.py b/genparse/experimental/earley.py index fbcb1886..399fc9c8 100644 --- a/genparse/experimental/earley.py +++ b/genparse/experimental/earley.py @@ -1,10 +1,9 @@ import numpy as np -from arsenal import Integerizer, colors +from arsenal import Integerizer from collections import defaultdict -from functools import lru_cache -#from arsenal.datastructures.pdict import pdict +# from arsenal.datastructures.pdict import pdict from arsenal.datastructures.heap import LocatorMaxHeap from genparse.cfglm import EOS, add_EOS @@ -14,7 +13,6 @@ class EarleyLM(LM): - def __init__(self, cfg): if EOS not in cfg.V: cfg = add_EOS(cfg) @@ -33,7 +31,7 @@ def clear_cache(self): class Column: - __slots__ = ("k", "i_chart", "c_chart", "waiting_for", "Q") + __slots__ = ('k', 'i_chart', 'c_chart', 'waiting_for', 'Q') def __init__(self, k): self.k = k @@ -45,7 +43,7 @@ def __init__(self, k): self.waiting_for = defaultdict(set) # priority queue used when first filling the column -# self.Q = pdict() + # self.Q = pdict() self.Q = LocatorMaxHeap() @@ -55,11 +53,23 @@ class Earley: Warning: Assumes that nullary rules and unary chain cycles have been removed """ - __slots__ = ("cfg", "order", "_chart", "V", "eos", "_initial_column", "R", 'rhs', - 'ORDER_MAX', 'intern_Ys', 'unit_Ys', 'first_Ys', 'rest_Ys') + __slots__ = ( + 'cfg', + 'order', + '_chart', + 'V', + 'eos', + '_initial_column', + 'R', + 'rhs', + 'ORDER_MAX', + 'intern_Ys', + 'unit_Ys', + 'first_Ys', + 'rest_Ys', + ) def __init__(self, cfg): - cfg = cfg.nullaryremove(binarize=True).unarycycleremove().renumber() self.cfg = cfg @@ -96,7 +106,8 @@ def __init__(self, cfg): for X in self.cfg.N: self.rhs[X] = [] for r in self.cfg.rhs[X]: - if r.body == (): continue + if r.body == (): + continue self.rhs[X].append((r.w, intern_Ys(r.body))) self.first_Ys = np.zeros(len(intern_Ys), dtype=object) @@ -104,7 +115,7 @@ def __init__(self, cfg): self.unit_Ys = np.zeros(len(intern_Ys), dtype=int) for Ys, code in list(self.intern_Ys.items()): - self.unit_Ys[code] = (len(Ys) == 1) + self.unit_Ys[code] = len(Ys) == 1 if len(Ys) > 0: self.first_Ys[code] = Ys[0] self.rest_Ys[code] = intern_Ys(Ys[1:]) @@ -152,7 +163,6 @@ def p_next(self, prefix): return self.next_token_weights(self.chart(prefix)) def next_column(self, prev_cols, token): - prev_col = prev_cols[-1] next_col = Column(prev_cols[-1].k + 1) next_col_c_chart = next_col.c_chart @@ -181,8 +191,6 @@ def next_column(self, prev_cols, token): def PREDICT(self, col): # PREDICT: phrase(K, X/Ys, K) += rule(X -> Ys) with some filtering heuristics k = col.k - col_chart = col.i_chart - col_waiting_for = col.waiting_for # Filtering heuristic: Don't create the predicted item (K, X, [...], K) # unless there exists an item that wants the X item that it may @@ -277,7 +285,7 @@ def next_token_weights(self, cols): for Y in col_waiting_for: if is_terminal(Y): total = zero - for (I, X, Ys) in col_waiting_for[Y]: + for I, X, Ys in col_waiting_for[Y]: if self.unit_Ys[Ys]: node = (I, X) value = self._helper(node, cols, q) @@ -287,7 +295,6 @@ def next_token_weights(self, cols): return p def _helper(self, top, cols, q): - value = q.get(top) if value is not None: return value @@ -296,7 +303,7 @@ def _helper(self, top, cols, q): stack = [Node(top, None, zero)] while stack: - node = stack[-1] # 👀 + node = stack[-1] # 👀 # place neighbors above the node on the stack (J, Y) = node.node @@ -330,6 +337,7 @@ def _helper(self, top, cols, q): class Node: __slots__ = ('value', 'node', 'edges', 'cursor') + def __init__(self, node, edges, value): self.node = node self.edges = edges diff --git a/genparse/experimental/earley_rescaled.py b/genparse/experimental/earley_rescaled.py index fdf4bb87..ecc5af55 100644 --- a/genparse/experimental/earley_rescaled.py +++ b/genparse/experimental/earley_rescaled.py @@ -143,8 +143,6 @@ def next_column(self, prev_cols, token): self.PREDICT(next_col) - k = next_col.k - num = prev_col.chart[0, self.cfg.S] den = next_col.chart[0, self.cfg.S] diff --git a/genparse/lm.py b/genparse/lm.py index 3b9ea360..e6a616ef 100644 --- a/genparse/lm.py +++ b/genparse/lm.py @@ -5,6 +5,7 @@ import numpy as np import torch from arsenal.maths import sample_dict +from functools import lru_cache from transformers import AutoModelForCausalLM, AutoTokenizer from genparse.semiring import Float @@ -309,8 +310,6 @@ def __repr__(self): return repr(self.materialize()) -from functools import lru_cache - @lru_cache(None) def make_mock_llm(**kwargs): from genparse.util import hf_tokenizer diff --git a/genparse/proposal/trie_numba.py b/genparse/proposal/trie_numba.py index 83329b1d..944bf181 100644 --- a/genparse/proposal/trie_numba.py +++ b/genparse/proposal/trie_numba.py @@ -1,7 +1,6 @@ import numba import numpy as np - -# from typing import Dict, List +from numba.typed import List class TokenCharacterTrie: @@ -119,9 +118,6 @@ def _order_full(self, node): yield node -from numba.typed import List - - @numba.jit(nopython=True) def _update_trie_numba( mass: numba.float64[:], diff --git a/genparse/record.py b/genparse/record.py index fd1007c4..c8b13584 100644 --- a/genparse/record.py +++ b/genparse/record.py @@ -225,8 +225,8 @@ def plotlyx(self, xrange=None, opts=dict(), layout_opts=dict()): ) # Get resample or no-resample steps, add vline on the resample ones - resample_steps = d_[d_['resample?'] == True]['step'].unique() - no_resample_steps = d_[d_['resample?'] == False]['step'].unique() + resample_steps = d_[d_['resample?']]['step'].unique() + no_resample_steps = d_[~d_['resample?']]['step'].unique() for step in resample_steps: fig.add_vline(x=step, line_width=4, opacity=0.15, line_color='gray') diff --git a/genparse/steer.py b/genparse/steer.py index 006349b9..16bfa814 100644 --- a/genparse/steer.py +++ b/genparse/steer.py @@ -11,7 +11,9 @@ import transformers from arsenal.maths import logsumexp, sample_dict -from genparse.cfglm import EOS +from hfppl import Model + +from genparse import EOS from genparse.inference import ( TraceSWOR, importance_sampling, @@ -107,7 +109,7 @@ def __call__(self, ys): p *= self.p_next(ys[:t])[ys[t]] return p - def p_next(self, prefix): + def p_next(self, ys): p1 = self.lm1.p_next(ys) p2 = self.lm2.p_next(ys) @@ -204,10 +206,6 @@ def __repr__(self): # This code is still experimental and actively being developed # TODO: write tests -from hfppl import Model - -from genparse import EOS - class HFPPLParticle(Model): """ diff --git a/genparse/tokenization.py b/genparse/tokenization.py index b1703e75..5ed2121b 100644 --- a/genparse/tokenization.py +++ b/genparse/tokenization.py @@ -8,6 +8,18 @@ from typing import Dict, List +def ints2bytes(sequence: List[int]) -> bytes: + # check in the range of 0-255 + for item in sequence: + if not 0 <= item <= 255: + raise ValueError(f'item: {item} is not in the range [0, 255]') + return bytes(sequence) + + +def bytes2ints(byte_sequence: bytes) -> List[int]: + return list(byte_sequence) + + def get_tokenizer_mapping(tokenizer): """ Very similar to get_mapping in transformers_cfg.tokenization.mapping @@ -65,8 +77,6 @@ def _map(self, token_id: int) -> str: def map(self, token_id: int, verbose=False) -> bytes: token = self._map(token_id) - if verbose: - log.debug(f'token_id: {token_id}, token: {token}') return bytes(token, 'utf-8') @@ -100,8 +110,6 @@ def _map(self, token_id: int, verbose=False) -> str: def map(self, token_id: int, verbose=False) -> bytes: raw_token = self._map(token_id, verbose) - if verbose: - log.debug(f'token_id: {token_id}, raw_token: {raw_token}') return self.intermediate_encoding.token2bytes(raw_token) @staticmethod @@ -188,6 +196,8 @@ class ByteEncoding: def __init__(self, tokenizer): # check if the tokenizer is fast, if so, convert it to slow if tokenizer.is_fast: + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained( tokenizer.name_or_path, use_fast=False ) diff --git a/notes/LM-Fun.ipynb b/notes/LM-Fun.ipynb deleted file mode 100644 index 5afecef3..00000000 --- a/notes/LM-Fun.ipynb +++ /dev/null @@ -1,343 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "ba3f7044-7797-46ab-a98e-e36962f9ef2a", - "metadata": {}, - "outputs": [], - "source": [ - "from genparse.lm import LLM, GreedilyTokenizedLLM\n", - "from transformers import AutoTokenizer, AutoModelForCausalLM\n", - "from arsenal.maths import sample\n", - "from collections import Counter" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6bcbac34-a10d-4004-9968-3a16eb34dcc5", - "metadata": {}, - "outputs": [], - "source": [ - "tokenizer = AutoTokenizer.from_pretrained('gpt2')" - ] - }, - { - "cell_type": "markdown", - "id": "4fdb8aaa-8a8c-4f24-a772-111c6ecad167", - "metadata": {}, - "source": [ - "Note that the tokenizer might have multiple tokens that map to the same string." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c8b074a8-3409-4f00-8e87-2cc5da2803a8", - "metadata": {}, - "outputs": [], - "source": [ - "# Counter(tokenizer.decode([k]) for k in range(tokenizer.vocab_size)).most_common(20)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "03a0efa3-b6e5-416c-9b99-4de01f18f393", - "metadata": {}, - "outputs": [], - "source": [ - "lm = LLM(AutoModelForCausalLM.from_pretrained('gpt2'))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b55478f1-91c3-4cfa-b874-321165ca8ecc", - "metadata": {}, - "outputs": [], - "source": [ - "lm([0, 1, 2])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "32c1dd70-f983-4bae-930f-d955ce4a5161", - "metadata": {}, - "outputs": [], - "source": [ - "lm.p_next([0, 1, 2])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6afc295b-b212-443f-b8ec-0d4a351fd8e7", - "metadata": {}, - "outputs": [], - "source": [ - "lm.p_next([0, 1, 2, 3])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "443b9174-a20b-4887-929e-c28390fd3990", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "da8bf971-2628-4a4a-8bfc-d417000fce79", - "metadata": {}, - "outputs": [], - "source": [ - "p = GreedilyTokenizedLLM('gpt2').p_next('Once upon a time,')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d90506e8-44c8-4b16-806f-92d7c9a95f60", - "metadata": {}, - "outputs": [], - "source": [ - "p = GreedilyTokenizedLLM('gpt2').p_next(\n", - " 'The following is some code that implements quick sort in Python:'\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d13a0edd-b4de-4899-a72e-3c9cac8ea17a", - "metadata": {}, - "outputs": [], - "source": [ - "pd.DataFrame(p.items()).set_index(0).sort_values(1, ascending=False).head(10)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "356ef417-d1ca-4aab-bf59-f6e3a5a7f386", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f472a36b-25aa-4f54-a450-a7b9ffe272c0", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c17f77f2-3f34-4278-9069-f1c137ad24d0", - "metadata": {}, - "outputs": [], - "source": [ - "M = GreedilyTokenizedLLM('gpt2')\n", - "# .p_terminal('The following is some code that implements quick sort in Python:')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3e98527f-4397-433e-8c12-ad7aab79df2c", - "metadata": {}, - "outputs": [], - "source": [ - "class Particle:\n", - " def __init__(self, xs):\n", - " self.xs = xs\n", - "\n", - " def __eq__(self, other):\n", - " return isinstance(other, Particle) and self.xs == other.xs\n", - "\n", - " def __hash__(self):\n", - " return hash(self.xs)\n", - "\n", - " def __repr__(self):\n", - " return f'{self.xs}'\n", - "\n", - " def p(self):\n", - " P = M.p_next(self.xs)\n", - " return Particles({Particle(self.xs + x): w for x, w in P.items()})" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "507bb1fc-3a60-45ee-ab56-67fc76251f36", - "metadata": {}, - "outputs": [], - "source": [ - "class Particles(dict):\n", - " def __repr__(self):\n", - " return repr(Counter(self).most_common(10))\n", - "\n", - " def sample(self):\n", - " ks, ws = np.array(list(self.keys())), np.array(list(self.values()))\n", - " # ws[np.argsort(-ws)[50:]] = 0\n", - " # return ks[sample(softmax(ws))]\n", - " return ks[sample(ws)]\n", - "\n", - " def greedy(self):\n", - " ks, ws = list(self.keys()), np.array(list(self.values()))\n", - " return ks[ws.argmax()]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "baab553a-d1d7-4ca7-a369-e2bda044367a", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fe3a1d60-4184-47f9-9ac3-8a04a8bea8ec", - "metadata": {}, - "outputs": [], - "source": [ - "Particle(\n", - " 'The following is some code that implements the quick sort algorithm in Python:'\n", - ").p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f3fffc5c-6d0a-477b-85d3-518dcea6edcc", - "metadata": {}, - "outputs": [], - "source": [ - "Particle(\n", - " 'Once upon a time'\n", - ").p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "51ccd3c4-405c-4f1b-b4b3-f7eecf4b80d3", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3ce6b2da-ecaf-4dba-a8ab-9c5ce0b428c2", - "metadata": {}, - "outputs": [], - "source": [ - "p = Particle('Once upon a time,')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6cca1245-818e-44cb-9bcc-a20362c4a623", - "metadata": {}, - "outputs": [], - "source": [ - "p.p()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "30f58a29-2cf3-49b2-8abb-b3353190a64d", - "metadata": {}, - "outputs": [], - "source": [ - "p.p().sample()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "00573763-8e4c-43d5-ab26-a3d1a07ba0b0", - "metadata": {}, - "outputs": [], - "source": [ - "(\n", - " p.p()\n", - " .greedy()\n", - " .p()\n", - " .greedy()\n", - " .p()\n", - " .greedy()\n", - " .p()\n", - " .greedy()\n", - " .p()\n", - " .greedy()\n", - " .p()\n", - " .greedy()\n", - " .p()\n", - " .greedy()\n", - " .p()\n", - " .greedy()\n", - " .p()\n", - " .greedy()\n", - " .p()\n", - " .greedy()\n", - " .p()\n", - " .greedy()\n", - " .p()\n", - " .greedy()\n", - " # .p().greedy().p().greedy().p().greedy().p().greedy().p().greedy().p().greedy()\n", - " # .p().greedy().p().greedy().p().greedy().p().greedy().p().greedy().p().greedy()\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a564a9d2-6a4e-4626-88d9-fc025436b677", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ec332a04-ee19-440a-9294-bbd360775a95", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/notes/benchmark_hfppl.py b/notes/benchmark_hfppl.py index a39433ac..f441097c 100644 --- a/notes/benchmark_hfppl.py +++ b/notes/benchmark_hfppl.py @@ -1,7 +1,4 @@ import argparse -import getpass -import os -import sys PROMPT = """ You have access to a political survey data table named "data", which includes the following columns: @@ -39,18 +36,9 @@ """ -def set_environment(): - if getpass.getuser() == 'benjamin.lebrun': - sys.path.append('/home/mila/b/benjamin.lebrun/genparse') - os.environ['HF_HOME'] = os.path.join(os.environ['SCRATCH'], 'hf_cache') - print('HF cache set; path updated') - - def main( model_name, proposal_name, batch_size, n_particles, method, max_tokens, verbosity ): - set_environment() - from genparse.cfglm import EarleyBoolMaskCFGLM from genparse.lm import AsyncGreedilyTokenizedLLM from genparse.proposal import CharacterProposal, TokenProposal diff --git a/notes/fst_pruned_composition.py b/notes/fst_pruned_composition.py index af97b338..f5bde51d 100644 --- a/notes/fst_pruned_composition.py +++ b/notes/fst_pruned_composition.py @@ -1,5 +1,5 @@ +import numpy as np from time import time - from arsenal import colors, timers from genparse import Float @@ -310,7 +310,7 @@ def fidelity(a2b, b2c, keep, filter_time): # perfectly trimmed states T = fine.trim.states - precision = len(F & T) / len(F) if len(F) > 0 else 1 + # precision = len(F & T) / len(F) if len(F) > 0 else 1 recall = len(F & T) / len(T) if len(T) > 0 else 1 print() @@ -370,7 +370,6 @@ def __call__(self, state): # XXX: Be careful - epsilon cannot be merged like the other labels (that should # probably be enforced in the coarsen method). -import numpy as np def random_hash(domain): diff --git a/notes/grammar_processing_issues.ipynb b/notes/grammar_processing_issues.ipynb index ff6c5fcc..73fbf6ff 100644 --- a/notes/grammar_processing_issues.ipynb +++ b/notes/grammar_processing_issues.ipynb @@ -11,20 +11,6 @@ "%autoreload 2" ] }, - { - "cell_type": "code", - "execution_count": 2, - "id": "9851ed1e", - "metadata": {}, - "outputs": [], - "source": [ - "import sys\n", - "import getpass\n", - "\n", - "if getpass.getuser() == 'benjamin.lebrun':\n", - " sys.path.append('/home/mila/b/benjamin.lebrun/genparse')" - ] - }, { "cell_type": "code", "execution_count": 8, diff --git a/notes/hfppl.ipynb b/notes/hfppl.ipynb index 312b8109..4decd94a 100644 --- a/notes/hfppl.ipynb +++ b/notes/hfppl.ipynb @@ -307,6 +307,7 @@ "metadata": {}, "outputs": [], "source": [ + "from genparse.cfglm import EarleyBoolMaskCFGLM\n", "from genparse.proposal import CharacterProposal, TokenProposal\n", "from genparse.util import LarkStuff" ] diff --git a/notes/hfppl_benleb.ipynb b/notes/hfppl_benleb.ipynb index bf4cd5e1..e313ad16 100644 --- a/notes/hfppl_benleb.ipynb +++ b/notes/hfppl_benleb.ipynb @@ -26,15 +26,6 @@ } ], "source": [ - "import sys\n", - "import os\n", - "import getpass\n", - "\n", - "if getpass.getuser() == 'benjamin.lebrun':\n", - " sys.path.append('/home/mila/b/benjamin.lebrun/genparse')\n", - " os.environ['HF_HOME'] = os.path.join(os.environ['SCRATCH'], 'hf_cache')\n", - " print('HF cache set; path updated')\n", - "\n", "import nest_asyncio\n", "\n", "nest_asyncio.apply()" @@ -438,7 +429,7 @@ ], "metadata": { "kernelspec": { - "display_name": "genparse", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -452,7 +443,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/notes/sql_debug.ipynb b/notes/sql_debug.ipynb index dcb0112b..9e71e430 100644 --- a/notes/sql_debug.ipynb +++ b/notes/sql_debug.ipynb @@ -11,20 +11,6 @@ "%autoreload 2" ] }, - { - "cell_type": "code", - "execution_count": 2, - "id": "9851ed1e", - "metadata": {}, - "outputs": [], - "source": [ - "import sys\n", - "import getpass\n", - "\n", - "if getpass.getuser() == 'benjamin.lebrun':\n", - " sys.path.append('/home/mila/b/benjamin.lebrun/genparse')" - ] - }, { "cell_type": "code", "execution_count": 3, @@ -90,7 +76,7 @@ ], "metadata": { "kernelspec": { - "display_name": "genparse", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -104,7 +90,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.14" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/notes/test_grammar_coverage.py b/notes/test_grammar_coverage.py index 99775007..38d95d0f 100644 --- a/notes/test_grammar_coverage.py +++ b/notes/test_grammar_coverage.py @@ -1,16 +1,10 @@ import argparse -import getpass import logging -import sys - -logger = logging.getLogger(__name__) - -if getpass.getuser() == 'benjamin.lebrun': - sys.path.append('/home/mila/b/benjamin.lebrun/genparse') - from genparse.cfglm import EarleyBoolMaskCFGLM from genparse.util import LarkStuff +logger = logging.getLogger(__name__) + def load_guide(grammar_name): cfg = LarkStuff(open(grammar_name).read()).char_cfg(0.99, ignore='[ ]?') diff --git a/ruff.toml b/ruff.toml index ef616296..8c71425e 100644 --- a/ruff.toml +++ b/ruff.toml @@ -17,6 +17,7 @@ exclude = [ "profile.html", "profile.json", "*.log", + "bench/spider-eval" ] line-length = 90 @@ -26,7 +27,7 @@ extend-include = ["*.ipynb"] [lint] select = ["E4", "E7", "E9", "F"] -ignore = [] +ignore = ["E731", "E701", "F401", "E741", "E743"] fixable = ["ALL"] unfixable = [] diff --git a/tests/test_inference.py b/tests/test_inference.py index 0cfc5f85..aedfbc4c 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -51,8 +51,6 @@ def run_test(lm1, lm2): # some diagnostics to characterize the quality of the approximation. class CheckParticles(BruteForceGlobalProductOfExperts): def check(self, particles): - n_particles = len(particles) - # TODO: weight finalization should be part of the inference algorithm! w = Float.chart() for p in particles: diff --git a/tests/test_wcfg.py b/tests/test_wcfg.py index a421f80f..b0168ae9 100644 --- a/tests/test_wcfg.py +++ b/tests/test_wcfg.py @@ -359,7 +359,7 @@ def test_palindrome_derivations(): Real, ) - s = 'a b c b a'.split() + # s = 'a b c b a'.split() n = 0 print(colors.yellow % 'Derivations:') diff --git a/tests/test_wfsa_field.py b/tests/test_wfsa_field.py index ed6ea898..13c8c2bd 100644 --- a/tests/test_wfsa_field.py +++ b/tests/test_wfsa_field.py @@ -12,7 +12,7 @@ def test_misc(): b = WFSA.lift('b', 1) c = WFSA.lift('c', 1) - M = one + a + a * b + a * b * c + M = zero + one + a + a * b + a * b * c # dry run M.graphviz() From d31399fc671fdd3e9009117d2bdc0844d152572f Mon Sep 17 00:00:00 2001 From: benlipkin Date: Tue, 18 Jun 2024 16:40:26 -0400 Subject: [PATCH 5/5] Revert "linter ignore notes, bench, benchmark; only track genparse, tests" This reverts commit a6c3bcb7916999af2ca6a78ad0f3b18a43373fdb. --- genparse/experimental/earley.py | 40 +++++++++++++-------------------- genparse/lm.py | 1 - ruff.toml | 3 --- 3 files changed, 15 insertions(+), 29 deletions(-) diff --git a/genparse/experimental/earley.py b/genparse/experimental/earley.py index 552b74ae..fbcb1886 100644 --- a/genparse/experimental/earley.py +++ b/genparse/experimental/earley.py @@ -1,9 +1,10 @@ import numpy as np -from arsenal import Integerizer +from arsenal import Integerizer, colors from collections import defaultdict +from functools import lru_cache -# from arsenal.datastructures.pdict import pdict +#from arsenal.datastructures.pdict import pdict from arsenal.datastructures.heap import LocatorMaxHeap from genparse.cfglm import EOS, add_EOS @@ -13,6 +14,7 @@ class EarleyLM(LM): + def __init__(self, cfg): if EOS not in cfg.V: cfg = add_EOS(cfg) @@ -31,7 +33,7 @@ def clear_cache(self): class Column: - __slots__ = ('k', 'i_chart', 'c_chart', 'waiting_for', 'Q') + __slots__ = ("k", "i_chart", "c_chart", "waiting_for", "Q") def __init__(self, k): self.k = k @@ -43,7 +45,7 @@ def __init__(self, k): self.waiting_for = defaultdict(set) # priority queue used when first filling the column - # self.Q = pdict() +# self.Q = pdict() self.Q = LocatorMaxHeap() @@ -53,23 +55,11 @@ class Earley: Warning: Assumes that nullary rules and unary chain cycles have been removed """ - __slots__ = ( - 'cfg', - 'order', - '_chart', - 'V', - 'eos', - '_initial_column', - 'R', - 'rhs', - 'ORDER_MAX', - 'intern_Ys', - 'unit_Ys', - 'first_Ys', - 'rest_Ys', - ) + __slots__ = ("cfg", "order", "_chart", "V", "eos", "_initial_column", "R", 'rhs', + 'ORDER_MAX', 'intern_Ys', 'unit_Ys', 'first_Ys', 'rest_Ys') def __init__(self, cfg): + cfg = cfg.nullaryremove(binarize=True).unarycycleremove().renumber() self.cfg = cfg @@ -106,8 +96,7 @@ def __init__(self, cfg): for X in self.cfg.N: self.rhs[X] = [] for r in self.cfg.rhs[X]: - if r.body == (): - continue + if r.body == (): continue self.rhs[X].append((r.w, intern_Ys(r.body))) self.first_Ys = np.zeros(len(intern_Ys), dtype=object) @@ -115,7 +104,7 @@ def __init__(self, cfg): self.unit_Ys = np.zeros(len(intern_Ys), dtype=int) for Ys, code in list(self.intern_Ys.items()): - self.unit_Ys[code] = len(Ys) == 1 + self.unit_Ys[code] = (len(Ys) == 1) if len(Ys) > 0: self.first_Ys[code] = Ys[0] self.rest_Ys[code] = intern_Ys(Ys[1:]) @@ -163,6 +152,7 @@ def p_next(self, prefix): return self.next_token_weights(self.chart(prefix)) def next_column(self, prev_cols, token): + prev_col = prev_cols[-1] next_col = Column(prev_cols[-1].k + 1) next_col_c_chart = next_col.c_chart @@ -287,7 +277,7 @@ def next_token_weights(self, cols): for Y in col_waiting_for: if is_terminal(Y): total = zero - for I, X, Ys in col_waiting_for[Y]: + for (I, X, Ys) in col_waiting_for[Y]: if self.unit_Ys[Ys]: node = (I, X) value = self._helper(node, cols, q) @@ -297,6 +287,7 @@ def next_token_weights(self, cols): return p def _helper(self, top, cols, q): + value = q.get(top) if value is not None: return value @@ -305,7 +296,7 @@ def _helper(self, top, cols, q): stack = [Node(top, None, zero)] while stack: - node = stack[-1] # 👀 + node = stack[-1] # 👀 # place neighbors above the node on the stack (J, Y) = node.node @@ -339,7 +330,6 @@ def _helper(self, top, cols, q): class Node: __slots__ = ('value', 'node', 'edges', 'cursor') - def __init__(self, node, edges, value): self.node = node self.edges = edges diff --git a/genparse/lm.py b/genparse/lm.py index 86dbab3c..3b9ea360 100644 --- a/genparse/lm.py +++ b/genparse/lm.py @@ -311,7 +311,6 @@ def __repr__(self): from functools import lru_cache - @lru_cache(None) def make_mock_llm(**kwargs): from genparse.util import hf_tokenizer diff --git a/ruff.toml b/ruff.toml index 86adfdb4..ef616296 100644 --- a/ruff.toml +++ b/ruff.toml @@ -17,9 +17,6 @@ exclude = [ "profile.html", "profile.json", "*.log", - "notes", - "bench", - "benchmark", ] line-length = 90