Commit

fix ruff format
timvieira committed Jun 30, 2024
1 parent b48f3ca commit aa253bd
Showing 5 changed files with 81 additions and 48 deletions.
74 changes: 49 additions & 25 deletions benchmark/Spider.ipynb
@@ -24,6 +24,7 @@
"from genparse import EOS\n",
"\n",
"import nest_asyncio\n",
"\n",
"nest_asyncio.apply()"
]
},
@@ -44,7 +45,7 @@
"metadata": {},
"outputs": [],
"source": [
"#for x in spider.dev_data[:2]:\n",
"# for x in spider.dev_data[:2]:\n",
"# print()\n",
"# print(x.text)\n",
"# print(x.gold_sql)\n",
@@ -169,13 +170,14 @@
"\n",
"import bench.spider.evaluation as E\n",
"from bench.spider.evaluation import (\n",
" #build_foreign_key_map_from_json,\n",
" # build_foreign_key_map_from_json,\n",
" build_valid_col_units,\n",
" rebuild_sql_val,\n",
" rebuild_sql_col,\n",
" eval_exec_match,\n",
")\n",
"\n",
"\n",
"def evaluate(self, gold: str, pred: str, db_name: str):\n",
" \"\"\"Returns: bool, Optional[str]\n",
"\n",
@@ -192,11 +194,15 @@
" p_sql = E.get_sql(schema, pred)\n",
" except Exception as e:\n",
" # sql is ill-formed (can't be parsed by sqlite engine)\n",
" print(colors.red % e.__class__.__name__, e, )\n",
" print(\n",
" colors.red % e.__class__.__name__,\n",
" e,\n",
" )\n",
"\n",
" import traceback\n",
"\n",
" traceback.print_exc()\n",
" \n",
"\n",
" return False, 'invalid'\n",
"\n",
" kmap = self.kmaps[db_name]\n",
@@ -212,7 +218,8 @@
"\n",
" return exec_match, reason\n",
"\n",
"#spider.evaluator = evaluate"
"\n",
"# spider.evaluator = evaluate"
]
},
{
@@ -222,7 +229,7 @@
"metadata": {},
"outputs": [],
"source": [
"#x.interface.evaluator = evaluate"
"# x.interface.evaluator = evaluate"
]
},
{
@@ -288,7 +295,9 @@
}
],
"source": [
"x.interface.evaluator.evaluate(x.interface.evaluator, gold = x.gold_sql, pred = junk_sql, db_name = x.db_name)"
"x.interface.evaluator.evaluate(\n",
" x.interface.evaluator, gold=x.gold_sql, pred=junk_sql, db_name=x.db_name\n",
")"
]
},
{
@@ -298,7 +307,7 @@
"metadata": {},
"outputs": [],
"source": [
"#x.db_schema.columns"
"# x.db_schema.columns"
]
},
{
@@ -308,7 +317,9 @@
"metadata": {},
"outputs": [],
"source": [
"grammar_text = open('/home/timv/projects/genparse/benchmark/grammars/sql_case_insensitive.lark').read()"
"grammar_text = open(\n",
" '/home/timv/projects/genparse/benchmark/grammars/sql_case_insensitive.lark'\n",
").read()"
]
},
{
@@ -347,8 +358,14 @@
}
],
"source": [
"#infer = InferenceSetup('codellama', grammar_text, proposal_name='character', guide_opts={'ignore': '\\s*'})\n",
"infer = InferenceSetupVLLM('codellama', grammar_text, proposal_name='character', guide_opts={'ignore': r'\\s*'}, batch_size=50)"
"# infer = InferenceSetup('codellama', grammar_text, proposal_name='character', guide_opts={'ignore': '\\s*'})\n",
"infer = InferenceSetupVLLM(\n",
" 'codellama',\n",
" grammar_text,\n",
" proposal_name='character',\n",
" guide_opts={'ignore': r'\\s*'},\n",
" batch_size=50,\n",
")"
]
},
{
@@ -438,7 +455,7 @@
"outputs": [],
"source": [
"def show_posterior_tables(x, p):\n",
" display(HTML(f'<h3>{x.text}</h3>')) \n",
" display(HTML(f'<h3>{x.text}</h3>'))\n",
" for y, py in sorted(p[0].posterior.items(), key=lambda ab: -ab[1]):\n",
" y = y[:-1] # remove EOS\n",
" print(f'{colors.mark(x.evaluate(y))} {py:-.6f} {y}')\n",
@@ -685,12 +702,14 @@
"outputs": [],
"source": [
"def EVAL(x, y1, y2):\n",
" if y1.endswith(EOS): y1 = y1[:-1] \n",
" if y2.endswith(EOS): y2 = y2[:-1]\n",
" if y1.endswith(EOS):\n",
" y1 = y1[:-1]\n",
" if y2.endswith(EOS):\n",
" y2 = y2[:-1]\n",
" try:\n",
" return x.interface.evaluate(y1, y2, x.db_name)[0]\n",
" except Exception:\n",
"# print(e)\n",
" # print(e)\n",
" return False"
]
},
@@ -702,8 +721,7 @@
"outputs": [],
"source": [
"def risk(x, candidate, particles):\n",
" return sum(p * EVAL(x, candidate, y) \n",
" for y, p in particles[0].posterior.items())"
" return sum(p * EVAL(x, candidate, y) for y, p in particles[0].posterior.items())"
]
},
{
@@ -714,17 +732,25 @@
"outputs": [],
"source": [
"def show_mbr_tables(x, particles):\n",
" display(HTML(f'<h3>{x.text}</h3>')) \n",
" risks = {candidate: risk(x, candidate, particles) for candidate in particles[0].posterior}\n",
" display(HTML(f'<h3>{x.text}</h3>'))\n",
" risks = {\n",
" candidate: risk(x, candidate, particles) for candidate in particles[0].posterior\n",
" }\n",
" for candidate in sorted(risks, key=risks.__getitem__, reverse=True):\n",
" print(colors.mark(x.evaluate(candidate[:-1] if candidate.endswith(EOS) else candidate)), \n",
" f'{risks[candidate]:f}', (colors.dark.magenta if risks[candidate] == 0 else colors.light.magenta) % candidate)\n",
" print(\n",
" colors.mark(\n",
" x.evaluate(candidate[:-1] if candidate.endswith(EOS) else candidate)\n",
" ),\n",
" f'{risks[candidate]:f}',\n",
" (colors.dark.magenta if risks[candidate] == 0 else colors.light.magenta)\n",
" % candidate,\n",
" )\n",
" try:\n",
" display(x.run_query(candidate[:-1] if candidate.endswith(EOS) else candidate))\n",
" except Exception as e:\n",
" print()\n",
" print(colors.dark.red % '💀 ERROR', e)\n",
" print()\n"
" print()"
]
},
{
@@ -878,9 +904,7 @@
"id": "4cbb631a",
"metadata": {},
"outputs": [],
"source": [
"\n"
]
"source": []
},
{
"cell_type": "code",
36 changes: 18 additions & 18 deletions benchmark/benchclamp_to_cfg.py
@@ -3,43 +3,42 @@
from genparse.semiring import Real


def convert_rules(
benchclamp_rule_set
):
def convert_rules(benchclamp_rule_set):
lhs, rhs_list = benchclamp_rule_set
return [Rule(
Real(1), # weight
"nt_" + lhs,
convert_rhs(rhs)
)
for rhs in rhs_list
return [
Rule(
Real(1), # weight
'nt_' + lhs,
convert_rhs(rhs),
)
for rhs in rhs_list
]


def convert_rhs_token(token):
print(token)

name = token['underlying']
if token['optional']:
# haven't coded this yet, assuming this doesn't happen in the grammars
raise NotImplementedError
if token['type'] == "nonterminal":
if token['type'] == 'nonterminal':
name = 'nt_' + name
return name

def convert_rhs(
benchclamp_rhs
):

def convert_rhs(benchclamp_rhs):
if benchclamp_rhs == [{}]:
return ()
return tuple(convert_rhs_token(token) for token in benchclamp_rhs)


def make_cfg_from_rules(rules):
terminals = [s for r in rules for s in r.body if s[:3] != 'nt_']


cfg = CFG(
R=Real,
S="nt_start",
S='nt_start',
V=set(terminals),
)

@@ -48,8 +47,9 @@ def make_cfg_from_rules(rules):

return cfg

if __name__ == "__main__":
grammars = json.load(open("benchmark/grammars/benchclamp_spider_grammars.json", "r"))

if __name__ == '__main__':
grammars = json.load(open('benchmark/grammars/benchclamp_spider_grammars.json', 'r'))
grammar = grammars['perpetrator']
rules = [r for rule in grammar.items() for r in convert_rules(rule)]
cfg = make_cfg_from_rules(rules)
cfg = make_cfg_from_rules(rules)
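
Not part of the commit, but as context for benchclamp_to_cfg.py: the sketch below runs a small hand-written rule set (a hypothetical stand-in for the JSON grammars loaded in the __main__ block) through convert_rules and make_cfg_from_rules. The token-dict keys (underlying, optional, type) and the [{}] empty-RHS convention follow what convert_rhs_token and convert_rhs read above; the toy symbols themselves are made up.

# Hypothetical BenchCLAMP-style rule set: {lhs: [rhs, ...]}, where each RHS
# is a list of token dicts and [{}] marks an empty right-hand side.
toy = {
    'start': [
        [{'underlying': 'greeting', 'optional': False, 'type': 'nonterminal'}],
    ],
    'greeting': [
        [
            {'underlying': 'hello', 'optional': False, 'type': 'terminal'},
            {'underlying': 'world', 'optional': False, 'type': 'terminal'},
        ],
        [{}],  # empty production
    ],
}

# Same driver pattern as the __main__ block above.
rules = [r for rule in toy.items() for r in convert_rules(rule)]
cfg = make_cfg_from_rules(rules)
print(cfg)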
2 changes: 1 addition & 1 deletion notes/Crunching.ipynb
@@ -98,7 +98,7 @@
"source": [
"items = []\n",
"start = time()\n",
"for item in take(10, q.posterior_enumerate(depth = 10)):\n",
"for item in take(10, q.posterior_enumerate(depth=10)):\n",
" print()\n",
" print(item.ps, (colors.red % '·').join(item.ys[1:]))\n",
" items.append(item)\n",
4 changes: 3 additions & 1 deletion notes/FST.ipynb
@@ -1059,7 +1059,9 @@
"metadata": {},
"outputs": [],
"source": [
"bpe_lm = TokenProposal(guide=lm, llm=MockLLM(V={x for _, x in H.pairs}, eos=H.tokenizer.eos_token))"
"bpe_lm = TokenProposal(\n",
" guide=lm, llm=MockLLM(V={x for _, x in H.pairs}, eos=H.tokenizer.eos_token)\n",
")"
]
},
{
13 changes: 10 additions & 3 deletions notes/Token-Alignment.ipynb
@@ -68,7 +68,7 @@
"outputs": [],
"source": [
"p = CFGLM.from_string(\n",
"\"\"\"\n",
" \"\"\"\n",
"\n",
"1: S -> a\n",
"1: S -> a a\n",
@@ -1621,7 +1621,14 @@
],
"source": [
"display_table(\n",
" [[p.cfg.language(100).project(ϕ).sort(), L_PB.sort(), generation_tree(graft).D.sort(), PL.sort()]],\n",
" [\n",
" [\n",
" p.cfg.language(100).project(ϕ).sort(),\n",
" L_PB.sort(),\n",
" generation_tree(graft).D.sort(),\n",
" PL.sort(),\n",
" ]\n",
" ],\n",
" headings=['target', 'pfst', 'grafting-heuristic', 'wfst'],\n",
")"
]
@@ -1673,7 +1680,7 @@
" out = trace({'a': 0.5, 'b': 0.5})\n",
" if out == 'a':\n",
" out = trace({'a': 0.9, 'b': 0.1})\n",
" else: \n",
" else:\n",
" out = trace({'a': 0.01, 'b': 0.99})\n",
" print(out, trace.root.mass)"
]