diff --git a/notes/A-Tale-of-Two-Transducers.ipynb b/notes/A-Tale-of-Two-Transducers.ipynb
index 2cd62071..c10f9ba7 100644
--- a/notes/A-Tale-of-Two-Transducers.ipynb
+++ b/notes/A-Tale-of-Two-Transducers.ipynb
@@ -192,7 +192,7 @@
}
],
"source": [
- "D = Delta(\"ab\")\n",
+ "D = Delta('ab')\n",
"D"
]
},
@@ -331,7 +331,7 @@
}
],
"source": [
- "(D @ WFSA.from_string(\"ab\", Float)).trim"
+ "(D @ WFSA.from_string('ab', Float)).trim"
]
},
{
@@ -465,7 +465,7 @@
}
],
"source": [
- "(cfg @ Delta1(\"ab\")).nullaryremove(binarize=False).trim()"
+ "(cfg @ Delta1('ab')).nullaryremove(binarize=False).trim()"
]
},
{
@@ -664,7 +664,7 @@
}
],
"source": [
- "derivative(\"ab\", \"ab\").graphviz(fmt_node=lambda x: str(x))"
+ "derivative('ab', 'ab').graphviz(fmt_node=lambda x: str(x))"
]
},
{
@@ -688,7 +688,7 @@
}
],
"source": [
- "(cfg @ derivative(\"a\", \"ab\") @ derivative(\"a\", \"ab\")).language(10)"
+ "(cfg @ derivative('a', 'ab') @ derivative('a', 'ab')).language(10)"
]
},
{
@@ -712,7 +712,7 @@
}
],
"source": [
- "(cfg @ derivative(\"aa\", \"ab\")).language(10)"
+ "(cfg @ derivative('aa', 'ab')).language(10)"
]
},
{
@@ -748,7 +748,7 @@
}
],
"source": [
- "cfg.derivative(\"a\").trim()"
+ "cfg.derivative('a').trim()"
]
},
{
@@ -778,7 +778,7 @@
}
],
"source": [
- "(cfg @ derivative(\"a\", cfg.V)).nullaryremove(binarize=False).unaryremove().trim()"
+ "(cfg @ derivative('a', cfg.V)).nullaryremove(binarize=False).unaryremove().trim()"
]
},
{
diff --git a/notes/Character-at-a-Time.ipynb b/notes/Character-at-a-Time.ipynb
index e0a8b4b8..0be025d1 100644
--- a/notes/Character-at-a-Time.ipynb
+++ b/notes/Character-at-a-Time.ipynb
@@ -33,7 +33,7 @@
"metadata": {},
"outputs": [],
"source": [
- "llm = GreedilyTokenizedLLM(\"gpt2\")"
+ "llm = GreedilyTokenizedLLM('gpt2')"
]
},
{
@@ -57,7 +57,7 @@
}
],
"source": [
- "prompt = \"Hello my name is\"\n",
+ "prompt = 'Hello my name is'\n",
"pp = llm.p_next(prompt, 10).normalize()\n",
"pp"
]
@@ -174,7 +174,7 @@
}
],
"source": [
- "print(repr(\"\".join(pcfg.sample())))"
+ "print(repr(''.join(pcfg.sample())))"
]
},
{
@@ -206,7 +206,7 @@
"tracer = TraceSWOR()\n",
"for _ in range(1):\n",
" with tracer:\n",
- " print(\"----------------------------------\")\n",
+ " print('----------------------------------')\n",
" ys = token_trie_approx.sample(prompt, max_tokens=50, draw=tracer, verbosity=1)\n",
" print(ys)"
]
@@ -1486,7 +1486,7 @@
],
"source": [
"tracer.root.graphviz(\n",
- " fmt_node=lambda x: f\"{x._mass:.3g}\", fmt_edge=lambda i, a, j: repr(a)\n",
+ " fmt_node=lambda x: f'{x._mass:.3g}', fmt_edge=lambda i, a, j: repr(a)\n",
")"
]
},
@@ -1524,7 +1524,7 @@
"ADJ: \"fruit\"\n",
"\n",
"\"\"\"\n",
- " ).char_cfg(0.99, ignore=\"[ ]?\"),\n",
+ " ).char_cfg(0.99, ignore='[ ]?'),\n",
" tol=1e-100,\n",
" )\n",
")"
@@ -1548,7 +1548,7 @@
}
],
"source": [
- "\"\".join(fruit.sample())"
+ "''.join(fruit.sample())"
]
},
{
@@ -1579,7 +1579,7 @@
}
],
"source": [
- "fruit(\"fruit flies like a banana \" + EOS)"
+ "fruit('fruit flies like a banana ' + EOS)"
]
},
{
@@ -1599,12 +1599,12 @@
}
],
"source": [
- "prompt = \"The following is a favorite sentence among linguists:\"\n",
+ "prompt = 'The following is a favorite sentence among linguists:'\n",
"token_trie_approx = TokenTrieApproximation(llm, fruit)\n",
"tracer = TraceSWOR()\n",
"for _ in range(1):\n",
" with tracer:\n",
- " print(\"----------------------------------\")\n",
+ " print('----------------------------------')\n",
" ys = token_trie_approx.sample(prompt, max_tokens=50, draw=tracer)\n",
" print(ys)"
]
diff --git a/notes/FST.ipynb b/notes/FST.ipynb
index 322cc0db..2dcaaa18 100644
--- a/notes/FST.ipynb
+++ b/notes/FST.ipynb
@@ -155,7 +155,7 @@
"metadata": {},
"outputs": [],
"source": [
- "b2c = H.fst.prune_to_alphabet(None, foo.V | {\"\"}).renumber"
+ "b2c = H.fst.prune_to_alphabet(None, foo.V | {''}).renumber"
]
},
{
@@ -174,10 +174,10 @@
"outputs": [],
"source": [
"for x in foo.cnf.language(3):\n",
- " display(HTML(\"
\"))\n",
+ " display(HTML('
'))\n",
" print(x)\n",
" bpe_x = b2c(None, x).epsremove.trim\n",
- " print(\"total weight of BPE sequences (i.e., ambiguity):\", bpe_x.total_weight())\n",
+ " print('total weight of BPE sequences (i.e., ambiguity):', bpe_x.total_weight())\n",
" display(bpe_x)\n",
" print()"
]
@@ -253,7 +253,7 @@
"for x, w in foo.cnf.language(L + 2).items():\n",
" if len(x) > L:\n",
" continue\n",
- " cc[\"\".join(x)] += w\n",
+ " cc[''.join(x)] += w\n",
"# cc"
]
},
@@ -426,7 +426,7 @@
"df = []\n",
"for x, w in sorted(normalize(lm2.p_next(context)).items(), key=lambda kv: -kv[1]):\n",
" df.append((x, (H.tokenizer.decode([x]) if x != EOS else EOS), w))\n",
- "pd.DataFrame(df, columns=[\"token_id\", \"chars\", \"prob\"]).set_index(\"token_id\")"
+ "pd.DataFrame(df, columns=['token_id', 'chars', 'prob']).set_index('token_id')"
]
},
{
@@ -546,9 +546,9 @@
"source": [
"trace = TraceSWOR()\n",
"for _ in range(15):\n",
- " print(\"mass=\", trace.root.mass)\n",
+ " print('mass=', trace.root.mass)\n",
" with trace:\n",
- " print(\"\".join(lm.sample(draw=trace)))"
+ " print(''.join(lm.sample(draw=trace)))"
]
},
{
@@ -576,7 +576,7 @@
"metadata": {},
"outputs": [],
"source": [
- "c2t = lark_stuff.transducer(ignore=\"\", decay=0.0125)\n",
+ "c2t = lark_stuff.transducer(ignore='', decay=0.0125)\n",
"len(c2t.states)"
]
},
@@ -595,7 +595,7 @@
"metadata": {},
"outputs": [],
"source": [
- "x = \"SELECT * FROM data\""
+ "x = 'SELECT * FROM data'"
]
},
{
@@ -711,7 +711,7 @@
"metadata": {},
"outputs": [],
"source": [
- "cfg_t(\"SELECT * FROM data \")"
+ "cfg_t('SELECT * FROM data ')"
]
},
{
@@ -721,7 +721,7 @@
"metadata": {},
"outputs": [],
"source": [
- "cfg_t(\"SELECT * FROM data \")"
+ "cfg_t('SELECT * FROM data ')"
]
},
{
@@ -742,7 +742,7 @@
"outputs": [],
"source": [
"for _ in range(10):\n",
- " print(\"\".join(lm.sample()))"
+ " print(''.join(lm.sample()))"
]
},
{
@@ -752,7 +752,7 @@
"metadata": {},
"outputs": [],
"source": [
- "lm.p_next(\"SELECT * FROM \")"
+ "lm.p_next('SELECT * FROM ')"
]
},
{
@@ -800,7 +800,7 @@
"metadata": {},
"outputs": [],
"source": [
- "x = \"SELECT * FROM data\"\n",
+ "x = 'SELECT * FROM data'\n",
"b = tokenizer.encode(x)\n",
"b"
]
@@ -822,7 +822,7 @@
"metadata": {},
"outputs": [],
"source": [
- "with timeit(\"composition\"):\n",
+ "with timeit('composition'):\n",
" c = FST.from_string(tuple(b), Float) @ b2c\n",
"about(c)"
]
@@ -879,7 +879,7 @@
"metadata": {},
"outputs": [],
"source": [
- "x = x = \"SELECT * FROM data\""
+ "x = x = 'SELECT * FROM data'"
]
},
{
@@ -889,9 +889,9 @@
"metadata": {},
"outputs": [],
"source": [
- "with timeit(\"composition\"):\n",
+ "with timeit('composition'):\n",
" bs = b2c @ FST.from_string(x, Float)\n",
- "with timeit(\"trim\"):\n",
+ "with timeit('trim'):\n",
" bs.trim\n",
"about(bs)"
]
@@ -1025,7 +1025,7 @@
"metadata": {},
"outputs": [],
"source": [
- "print(\"\".join(lm.sample()))"
+ "print(''.join(lm.sample()))"
]
},
{
@@ -1053,9 +1053,7 @@
"metadata": {},
"outputs": [],
"source": [
- "bpe_lm = CharAlignedCFGLM(\n",
- " lm=lm, words={x for _, x in H.pairs}, eos=H.tokenizer.eos_token\n",
- ")"
+ "bpe_lm = CharAlignedCFGLM(lm=lm, words={x for _, x in H.pairs}, eos=H.tokenizer.eos_token)"
]
},
{
@@ -1065,7 +1063,7 @@
"metadata": {},
"outputs": [],
"source": [
- "lm.p_next(\"\")"
+ "lm.p_next('')"
]
},
{
@@ -1083,7 +1081,7 @@
"metadata": {},
"outputs": [],
"source": [
- "bpe_lm.p_next(\"\")"
+ "bpe_lm.p_next('')"
]
},
{
@@ -1093,7 +1091,7 @@
"metadata": {},
"outputs": [],
"source": [
- "lm.p_next(\"SELECT \")"
+ "lm.p_next('SELECT ')"
]
},
{
@@ -1103,7 +1101,7 @@
"metadata": {},
"outputs": [],
"source": [
- "bpe_lm.p_next(\"SELECT \")"
+ "bpe_lm.p_next('SELECT ')"
]
},
{
diff --git a/notes/Inference-Playground.ipynb b/notes/Inference-Playground.ipynb
index 48088b81..c93e0c13 100644
--- a/notes/Inference-Playground.ipynb
+++ b/notes/Inference-Playground.ipynb
@@ -142,7 +142,7 @@
],
"source": [
"ref = BruteForceGlobalProductOfExperts(lm1, lm2, MAX_LENGTH)\n",
- "ref.target.project(\"\".join)"
+ "ref.target.project(''.join)"
]
},
{
@@ -165,7 +165,7 @@
" lm2,\n",
" MAX_LENGTH=MAX_LENGTH,\n",
" n_particles=N_PARTICLES,\n",
- " METHOD=\"is\",\n",
+ " METHOD='is',\n",
" # METHOD = 'smc',\n",
")"
]
@@ -258,8 +258,8 @@
}
],
"source": [
- "w.project(\"\".join).trim().compare(ref.target.project(\"\".join).trim()).sort_values(\n",
- " \"key\", ascending=False\n",
+ "w.project(''.join).trim().compare(ref.target.project(''.join).trim()).sort_values(\n",
+ " 'key', ascending=False\n",
")"
]
},
@@ -328,7 +328,7 @@
"# truncate the reference distribution to the support set of the sample;\n",
"# renamed the keys to handle the minor discrepancy in the EOS symbol\n",
"tmp = ref.target.filter(lambda k: k[:-1] in R).normalize().sort()\n",
- "tmp.project(lambda k: k[:-1]).compare(R).sort_values(\"key\")"
+ "tmp.project(lambda k: k[:-1]).compare(R).sort_values('key')"
]
},
{
diff --git a/notes/LM-Fun.ipynb b/notes/LM-Fun.ipynb
index d901a61f..33ba4c51 100644
--- a/notes/LM-Fun.ipynb
+++ b/notes/LM-Fun.ipynb
@@ -22,7 +22,7 @@
"metadata": {},
"outputs": [],
"source": [
- "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")"
+ "tokenizer = AutoTokenizer.from_pretrained('gpt2')"
]
},
{
@@ -50,7 +50,7 @@
"metadata": {},
"outputs": [],
"source": [
- "lm = LLM(AutoModelForCausalLM.from_pretrained(\"gpt2\"))"
+ "lm = LLM(AutoModelForCausalLM.from_pretrained('gpt2'))"
]
},
{
@@ -98,7 +98,7 @@
"metadata": {},
"outputs": [],
"source": [
- "p = GreedilyTokenizedLLM(\"gpt2\").p_next(\"Once upon a time,\")"
+ "p = GreedilyTokenizedLLM('gpt2').p_next('Once upon a time,')"
]
},
{
@@ -108,8 +108,8 @@
"metadata": {},
"outputs": [],
"source": [
- "p = GreedilyTokenizedLLM(\"gpt2\").p_next(\n",
- " \"The following is some code that implements quick sort in Python:\"\n",
+ "p = GreedilyTokenizedLLM('gpt2').p_next(\n",
+ " 'The following is some code that implements quick sort in Python:'\n",
")"
]
},
@@ -146,7 +146,7 @@
"metadata": {},
"outputs": [],
"source": [
- "M = GreedilyTokenizedLLM(\"gpt2\")\n",
+ "M = GreedilyTokenizedLLM('gpt2')\n",
"# .p_terminal('The following is some code that implements quick sort in Python:')"
]
},
@@ -168,7 +168,7 @@
" return hash(self.xs)\n",
"\n",
" def __repr__(self):\n",
- " return f\"{self.xs}\"\n",
+ " return f'{self.xs}'\n",
"\n",
" def p(self):\n",
" P = M.p_next(self.xs)\n",
@@ -213,7 +213,7 @@
"outputs": [],
"source": [
"Particle(\n",
- " \"The following is some code that implements the quick sort algorithm in Python:\"\n",
+ " 'The following is some code that implements the quick sort algorithm in Python:'\n",
").p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample()"
]
},
@@ -225,7 +225,7 @@
"outputs": [],
"source": [
"Particle(\n",
- " \"Once upon a time\"\n",
+ " 'Once upon a time'\n",
").p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample()"
]
},
@@ -244,7 +244,7 @@
"metadata": {},
"outputs": [],
"source": [
- "p = Particle(\"Once upon a time,\")"
+ "p = Particle('Once upon a time,')"
]
},
{
diff --git a/notes/Lark-Interface.ipynb b/notes/Lark-Interface.ipynb
index b627c53e..52129d3a 100644
--- a/notes/Lark-Interface.ipynb
+++ b/notes/Lark-Interface.ipynb
@@ -263,10 +263,12 @@
"metadata": {},
"outputs": [],
"source": [
- "text = \"12 + 24 - 36 * 48 / 60 SELECT table.name AS thing WHERE table.potato IS NOT 'banana'\"\n",
+ "text = (\n",
+ " \"12 + 24 - 36 * 48 / 60 SELECT table.name AS thing WHERE table.potato IS NOT 'banana'\"\n",
+ ")\n",
"\n",
"for x, y in lark_stuff.simple_tokenizer(text):\n",
- " print(f\"{x:15s} -> {y!r}\")"
+ " print(f'{x:15s} -> {y!r}')"
]
},
{
@@ -284,7 +286,7 @@
"metadata": {},
"outputs": [],
"source": [
- "text = \"SELECT name FROM data \""
+ "text = 'SELECT name FROM data '"
]
},
{
@@ -331,7 +333,7 @@
"metadata": {},
"outputs": [],
"source": [
- "lark_stuff.parser.parse(tokens, \"start\")"
+ "lark_stuff.parser.parse(tokens, 'start')"
]
},
{
@@ -423,9 +425,11 @@
"metadata": {},
"outputs": [],
"source": [
- "len(g.cnf.rules), len(g.cnf.prefix_grammar.trim().rules), len(\n",
- " g.cnf.prefix_grammar.trim().rules\n",
- ") / len(g.cnf.rules)"
+ "(\n",
+ " len(g.cnf.rules),\n",
+ " len(g.cnf.prefix_grammar.trim().rules),\n",
+ " len(g.cnf.prefix_grammar.trim().rules) / len(g.cnf.rules),\n",
+ ")"
]
},
{
diff --git a/notes/SegTokenAligner.ipynb b/notes/SegTokenAligner.ipynb
index a758ba12..b87f8939 100644
--- a/notes/SegTokenAligner.ipynb
+++ b/notes/SegTokenAligner.ipynb
@@ -97,7 +97,7 @@
}
],
"source": [
- "prefix = \"aa#a#a#\"\n",
+ "prefix = 'aa#a#a#'\n",
"display(char_lm.p_next(prefix))\n",
"display(pullback(char_lm, prefix))"
]
@@ -1683,7 +1683,7 @@
}
],
"source": [
- "tracer.root.graphviz(fmt_node=lambda x: f\"{x.mass}/{x._mass:.2g}\")"
+ "tracer.root.graphviz(fmt_node=lambda x: f'{x.mass}/{x._mass:.2g}')"
]
},
{
diff --git a/notes/Segmentation-PFST.ipynb b/notes/Segmentation-PFST.ipynb
index 3880f084..bb306f94 100644
--- a/notes/Segmentation-PFST.ipynb
+++ b/notes/Segmentation-PFST.ipynb
@@ -50,8 +50,8 @@
"metadata": {},
"outputs": [],
"source": [
- "contexts = {\"a\", \"b\", \"c\", \"ab\", \"abc\"}\n",
- "alphabet = set(\"abc\")"
+ "contexts = {'a', 'b', 'c', 'ab', 'abc'}\n",
+ "alphabet = set('abc')"
]
},
{
@@ -257,8 +257,8 @@
"outputs": [],
"source": [
"test_strings = [\n",
- " \"aba\",\n",
- " \"ab\",\n",
+ " 'aba',\n",
+ " 'ab',\n",
" #'aa',\n",
" #'acab',\n",
" #'abc',\n",
@@ -478,7 +478,7 @@
"source": [
"C = cc\n",
"for x in test_strings:\n",
- " display(HTML(f\"
{x}
\"))\n",
+ " display(HTML(f'
{x}
'))\n",
" run_segmentation_test(C, x, contexts, verbose=2)"
]
},
@@ -726,7 +726,7 @@
],
"source": [
"for x in test_strings:\n",
- " display(HTML(f\"
{x}
\"))\n",
+ " display(HTML(f'
{x}
'))\n",
" run_segmentation_test(canonical, x, contexts, verbose=1)"
]
},
@@ -968,8 +968,8 @@
}
],
"source": [
- "contexts = {\"a\", \"b\", \"c\", \"abc\"} # not prefix closed!\n",
- "alphabet = set(\"abc\")\n",
+ "contexts = {'a', 'b', 'c', 'abc'} # not prefix closed!\n",
+ "alphabet = set('abc')\n",
"cc = construction(contexts, alphabet, canonical=True)\n",
"cc.graphviz(fmt_node=fmt)"
]
@@ -1023,7 +1023,7 @@
],
"source": [
"for x in test_strings:\n",
- " display(HTML(f\"
{x}
\"))\n",
+ " display(HTML(f'
{x}
'))\n",
" run_segmentation_test(cc, x, contexts, verbose=1)"
]
},
diff --git a/notes/Token-Alignment.ipynb b/notes/Token-Alignment.ipynb
index c2e89ff4..6f9dc21c 100644
--- a/notes/Token-Alignment.ipynb
+++ b/notes/Token-Alignment.ipynb
@@ -77,8 +77,8 @@
")\n",
"\n",
"A = p.cfg.V\n",
- "B = {\"a\", \"aa\", \"aaa\", EOS}\n",
- "ϕ = lambda b: \"\".join(b).strip(EOS)"
+ "B = {'a', 'aa', 'aaa', EOS}\n",
+ "ϕ = lambda b: ''.join(b).strip(EOS)"
]
},
{
@@ -252,7 +252,7 @@
"metadata": {},
"outputs": [],
"source": [
- "T(\"aaa\", None).epsremove.trim"
+ "T('aaa', None).epsremove.trim"
]
},
{
@@ -262,7 +262,7 @@
"metadata": {},
"outputs": [],
"source": [
- "T(\"aaa\", None).total_weight()"
+ "T('aaa', None).total_weight()"
]
},
{
@@ -315,7 +315,7 @@
"source": [
"display_table(\n",
" [[p.cfg.language(100).project(ϕ), generation_tree(graft).D, PL]],\n",
- " headings=[\"target\", \"grafting-heuristic\", \"composition\"],\n",
+ " headings=['target', 'grafting-heuristic', 'composition'],\n",
")"
]
},
@@ -386,9 +386,7 @@
"metadata": {},
"outputs": [],
"source": [
- "L_PB.assert_equal(\n",
- " p.cfg.language(100).project(ϕ)\n",
- ") # character-level distribution matches!"
+ "L_PB.assert_equal(p.cfg.language(100).project(ϕ)) # character-level distribution matches!"
]
},
{
@@ -400,7 +398,7 @@
"source": [
"display_table(\n",
" [[p.cfg.language(100).project(ϕ), L_PB, generation_tree(graft).D, PL]],\n",
- " headings=[\"target\", \"pfst\", \"grafting-heuristic\", \"composition\"],\n",
+ " headings=['target', 'pfst', 'grafting-heuristic', 'composition'],\n",
")"
]
},
diff --git a/notes/grammar_processing_issues.ipynb b/notes/grammar_processing_issues.ipynb
index 7f4afbed..ff6c5fcc 100644
--- a/notes/grammar_processing_issues.ipynb
+++ b/notes/grammar_processing_issues.ipynb
@@ -21,8 +21,8 @@
"import sys\n",
"import getpass\n",
"\n",
- "if getpass.getuser() == \"benjamin.lebrun\":\n",
- " sys.path.append(\"/home/mila/b/benjamin.lebrun/genparse\")"
+ "if getpass.getuser() == 'benjamin.lebrun':\n",
+ " sys.path.append('/home/mila/b/benjamin.lebrun/genparse')"
]
},
{
@@ -54,12 +54,12 @@
}
],
"source": [
- "cfg = LarkStuff(open(\"../benchmark/grammars/sql_case_sensitive.lark\").read()).char_cfg(\n",
- " 0.99, ignore=\"[ ]?\"\n",
+ "cfg = LarkStuff(open('../benchmark/grammars/sql_case_sensitive.lark').read()).char_cfg(\n",
+ " 0.99, ignore='[ ]?'\n",
")\n",
"cfg = locally_normalize(cfg, tol=1e-40, maxiter=np.inf)\n",
"\n",
- "with timeit(\"boolean EarleyCFGLM preprocessing\"):\n",
+ "with timeit('boolean EarleyCFGLM preprocessing'):\n",
" guide = EarleyBoolMaskCFGLM(cfg) # should take forever"
]
},
@@ -126,7 +126,7 @@
"metadata": {},
"outputs": [],
"source": [
- "guide.p_next(\"\")"
+ "guide.p_next('')"
]
}
],
diff --git a/notes/hfppl.ipynb b/notes/hfppl.ipynb
index 9a19b024..17ebdd56 100644
--- a/notes/hfppl.ipynb
+++ b/notes/hfppl.ipynb
@@ -70,7 +70,7 @@
"# LLM = CachedCausalLM.from_pretrained(\"meta-llama/Llama-2-7b-hf\", auth_token=os.environ['HF_AUTH_TOKEN'])\n",
"# LLM = CachedCausalLM.from_pretrained(\"lmsys/vicuna-7b-v1.5\")\n",
"# LLM = CachedCausalLM.from_pretrained(\"mistralai/Mistral-7B-v0.1\")\n",
- "MODEL_ID = \"codellama/CodeLlama-7b-Instruct-hf\""
+ "MODEL_ID = 'codellama/CodeLlama-7b-Instruct-hf'"
]
},
{
@@ -176,7 +176,7 @@
"metadata": {},
"outputs": [],
"source": [
- "context = prompt + \" SELECT state_color\""
+ "context = prompt + ' SELECT state_color'"
]
},
{
@@ -201,7 +201,6 @@
"\n",
"\n",
"class GreedilyTokenizedLLM:\n",
- "\n",
" def __init__(self, llm, tokenizer):\n",
" self.tokenizer = tokenizer\n",
" self._model = llm\n",
@@ -348,7 +347,7 @@
"WS: /[ \\n]/\n",
"\n",
"\"\"\"\n",
- " ).char_cfg(0.99, ignore=\"[ ]?\")\n",
+ " ).char_cfg(0.99, ignore='[ ]?')\n",
")"
]
},
@@ -583,7 +582,7 @@
}
],
"source": [
- "print(f\"{sum(len(p.context.tokens) for p in particles)/took:.2f} tokens/sec\")"
+ "print(f'{sum(len(p.context.tokens) for p in particles)/took:.2f} tokens/sec')"
]
},
{
diff --git a/notes/hfppl_benleb.ipynb b/notes/hfppl_benleb.ipynb
index c323b138..f513d280 100644
--- a/notes/hfppl_benleb.ipynb
+++ b/notes/hfppl_benleb.ipynb
@@ -30,10 +30,10 @@
"import os\n",
"import getpass\n",
"\n",
- "if getpass.getuser() == \"benjamin.lebrun\":\n",
- " sys.path.append(\"/home/mila/b/benjamin.lebrun/genparse\")\n",
- " os.environ[\"HF_HOME\"] = os.path.join(os.environ[\"SCRATCH\"], \"hf_cache\")\n",
- " print(\"HF cache set; path updated\")\n",
+ "if getpass.getuser() == 'benjamin.lebrun':\n",
+ " sys.path.append('/home/mila/b/benjamin.lebrun/genparse')\n",
+ " os.environ['HF_HOME'] = os.path.join(os.environ['SCRATCH'], 'hf_cache')\n",
+ " print('HF cache set; path updated')\n",
"\n",
"import nest_asyncio\n",
"\n",
@@ -102,7 +102,7 @@
"from genparse.lm import AsyncGreedilyTokenizedLLM\n",
"\n",
"genparse_llm = AsyncGreedilyTokenizedLLM.from_name(\n",
- " \"codellama/CodeLlama-7b-Instruct-hf\", batch_size=40\n",
+ " 'codellama/CodeLlama-7b-Instruct-hf', batch_size=40\n",
")"
]
},
@@ -116,9 +116,7 @@
"from genparse.cfglm import EarleyBoolMaskCFGLM\n",
"from genparse.util import LarkStuff\n",
"\n",
- "guide = EarleyBoolMaskCFGLM(\n",
- " LarkStuff(very_restricted_sql).char_cfg(0.99, ignore=\"[ ]?\")\n",
- ")"
+ "guide = EarleyBoolMaskCFGLM(LarkStuff(very_restricted_sql).char_cfg(0.99, ignore='[ ]?'))"
]
},
{
@@ -377,7 +375,7 @@
"particle_approx = sampler.run_inference(\n",
" prompt=prompt,\n",
" proposal=proposal,\n",
- " method=\"smc-standard\",\n",
+ " method='smc-standard',\n",
" n_particles=5,\n",
" max_tokens=50,\n",
" verbosity=1,\n",
diff --git a/notes/hfppl_for_tim_o.ipynb b/notes/hfppl_for_tim_o.ipynb
index 3ae97628..c7d13faf 100644
--- a/notes/hfppl_for_tim_o.ipynb
+++ b/notes/hfppl_for_tim_o.ipynb
@@ -42,14 +42,12 @@
"import os\n",
"import getpass\n",
"\n",
- "if (\n",
- " getpass.getuser() == \"benjamin.lebrun\"\n",
- "): # change to your user if you want to set these\n",
+ "if getpass.getuser() == 'benjamin.lebrun': # change to your user if you want to set these\n",
" # @TIMO you may need to set this to your local genparse repo\n",
- " sys.path.append(\"/home/mila/b/benjamin.lebrun/genparse\")\n",
+ " sys.path.append('/home/mila/b/benjamin.lebrun/genparse')\n",
" # @TIMO also set your cache IF you run into disk quota issues\n",
- " os.environ[\"HF_HOME\"] = os.path.join(os.environ[\"SCRATCH\"], \"hf_cache\")\n",
- " print(\"HF cache set; path updated\")"
+ " os.environ['HF_HOME'] = os.path.join(os.environ['SCRATCH'], 'hf_cache')\n",
+ " print('HF cache set; path updated')"
]
},
{
@@ -86,7 +84,7 @@
}
],
"source": [
- "MODEL_ID = \"codellama/CodeLlama-7b-Instruct-hf\"\n",
+ "MODEL_ID = 'codellama/CodeLlama-7b-Instruct-hf'\n",
"hfppl_llm = CachedCausalLM.from_pretrained(MODEL_ID, load_in_8bit=True)\n",
"tokenizer = AutoTokenizer.from_pretrained(\n",
" MODEL_ID,\n",
@@ -158,7 +156,7 @@
" WS: /[ ]/\n",
"\n",
" \"\"\"\n",
- " ).char_cfg(0.99, ignore=\"[ ]?\")\n",
+ " ).char_cfg(0.99, ignore='[ ]?')\n",
")"
]
},
@@ -196,12 +194,15 @@
" self.compare_time = compare_time\n",
"\n",
" async def step(self):\n",
- " (token, llm_prob, guide_prob, proposal_prob) = (\n",
- " await self.proposal.sample_next_token(\n",
- " context=\"\".join(self.context),\n",
- " prompt=self.prompt,\n",
- " compare_time=self.compare_time,\n",
- " )\n",
+ " (\n",
+ " token,\n",
+ " llm_prob,\n",
+ " guide_prob,\n",
+ " proposal_prob,\n",
+ " ) = await self.proposal.sample_next_token(\n",
+ " context=''.join(self.context),\n",
+ " prompt=self.prompt,\n",
+ " compare_time=self.compare_time,\n",
" )\n",
" self.context.append(token)\n",
" self.weight += np.log(llm_prob) + np.log(guide_prob) - np.log(proposal_prob)\n",
@@ -214,7 +215,7 @@
" return\n",
"\n",
" def immutable_properties(self):\n",
- " return [\"llm\", \"prompt\", \"guide\", \"compare_token\"]\n",
+ " return ['llm', 'prompt', 'guide', 'compare_token']\n",
"\n",
" def __repr__(self):\n",
" return f\"`{'' if not self.context else self.context[-1]}` : {''.join(self.context)} : {self.weight}\""
diff --git a/notes/hfppl_jac.ipynb b/notes/hfppl_jac.ipynb
index ef3d6081..a1d72e9a 100644
--- a/notes/hfppl_jac.ipynb
+++ b/notes/hfppl_jac.ipynb
@@ -103,14 +103,14 @@
"\n",
"if is_cuda_available():\n",
" genparse_llm = AsyncGreedilyTokenizedLLM.from_name(\n",
- " \"codellama/CodeLlama-7b-Instruct-hf\", batch_size=40\n",
+ " 'codellama/CodeLlama-7b-Instruct-hf', batch_size=40\n",
" )\n",
"\n",
"else:\n",
" import transformers\n",
" from genparse.lm import LLM\n",
"\n",
- " MODEL_ID = \"gpt2\"\n",
+ " MODEL_ID = 'gpt2'\n",
"\n",
" MAX_TOKENS = 100\n",
" BATCH_SIZE = 80\n",
@@ -133,9 +133,7 @@
"from genparse.cfglm import EarleyBoolMaskCFGLM\n",
"from genparse.util import LarkStuff\n",
"\n",
- "guide = EarleyBoolMaskCFGLM(\n",
- " LarkStuff(very_restricted_sql).char_cfg(0.99, ignore=\"[ ]?\")\n",
- ")\n",
+ "guide = EarleyBoolMaskCFGLM(LarkStuff(very_restricted_sql).char_cfg(0.99, ignore='[ ]?'))\n",
"\n",
"from genparse.steer import HFPPLSampler\n",
"\n",
@@ -162,7 +160,7 @@
"particle_approx, record = sampler.run_inference(\n",
" prompt=prompt,\n",
" proposal=proposal,\n",
- " method=\"smc-standard\",\n",
+ " method='smc-standard',\n",
" return_record=True, # use version of smc that keeps a record\n",
" n_particles=12,\n",
" max_tokens=60,\n",
@@ -581,10 +579,10 @@
"READ_EXAMPLE = 0\n",
"\n",
"if WRITE_EXAMPLE:\n",
- " with open(\"hfppl_jac_example_record.pickle\", \"wb\") as h:\n",
+ " with open('hfppl_jac_example_record.pickle', 'wb') as h:\n",
" pickle.dump(record, h, protocol=pickle.HIGHEST_PROTOCOL)\n",
"if READ_EXAMPLE:\n",
- " with open(\"hfppl_jac_example_record.pickle\", \"rb\") as h:\n",
+ " with open('hfppl_jac_example_record.pickle', 'rb') as h:\n",
" record = pickle.load(h)"
]
},
@@ -634,23 +632,23 @@
"color": "rgb(255.0, 0.0, 0.0)",
"opacity": 0.15,
"size": [
- 14.142135623730951,
+ 14.142135623730953,
14.142135623730953,
14.457921873170436,
11.873882958685238,
- 11.385552490469177,
- 10.840011287523199,
+ 11.385552490469175,
+ 10.8400112875232,
8.105809481441536,
6.851187449489261,
9.827383887592362,
5.318365856516926,
14.561503537917536,
12.287227365169722,
- 12.709088234399145,
- 12.416948732939439,
- 9.190073478275679,
+ 12.709088234399143,
+ 12.41694873293944,
+ 9.19007347827568,
0.35471287475231794,
- 12.314263348816201
+ 12.3142633488162
]
},
"mode": "markers+text",
@@ -744,7 +742,7 @@
"color": "rgb(231.8181818181818, 11.363636363636363, 23.181818181818183)",
"opacity": 0.15,
"size": [
- 14.142135623730951,
+ 14.142135623730953,
14.142135623730953,
14.457921873170436,
16.823401969323964,
@@ -758,14 +756,14 @@
23.59244301718731,
3.182691823394595,
14.037910474841624,
- 14.950528730813339,
+ 14.95052873081334,
15.48159936483578,
14.421468648915504,
- 18.945608596247574,
+ 18.945608596247578,
15.803314668437338,
1.3912279252057005,
0.33226044433526886,
- 12.314263348816201
+ 12.3142633488162
]
},
"mode": "markers+text",
@@ -886,20 +884,20 @@
"color": "rgb(208.63636363636363, 22.727272727272727, 46.36363636363637)",
"opacity": 0.15,
"size": [
- 14.142135623730951,
+ 14.142135623730953,
14.142135623730953,
14.913449281118025,
16.002365531418288,
15.182974673123722,
15.012323751039531,
- 14.649651227979243,
+ 14.649651227979245,
15.024711875517289,
16.003574076322664,
14.037910474841624,
- 14.950528730813339,
+ 14.95052873081334,
15.48159936483578,
14.421468648915504,
- 18.945608596247574,
+ 18.945608596247578,
9.965274409008163,
15.803314668437338,
21.903367784151108,
@@ -1062,7 +1060,7 @@
"color": "rgb(185.45454545454544, 34.09090909090909, 69.54545454545455)",
"opacity": 0.15,
"size": [
- 14.142135623730951,
+ 14.142135623730953,
14.142135623730953,
14.457921873170436,
16.823401969323964,
@@ -1074,14 +1072,14 @@
4.5267911018928455,
12.520238575555048,
14.037910474841624,
- 14.950528730813339,
+ 14.95052873081334,
15.48159936483578,
- 14.419778171926065,
- 10.286253829351077,
+ 14.419778171926064,
+ 10.286253829351075,
15.803314668437338,
0.47058298637143886,
0.47058298637143886,
- 12.314263348816201
+ 12.3142633488162
]
},
"mode": "markers+text",
@@ -1180,7 +1178,7 @@
"color": "rgb(162.27272727272725, 45.45454545454545, 92.72727272727273)",
"opacity": 0.15,
"size": [
- 14.142135623730951,
+ 14.142135623730953,
14.142135623730953,
6.250132732861252,
3.7720463352716895,
@@ -1190,14 +1188,14 @@
2.853313585048045,
2.878067976123044,
14.037910474841624,
- 14.950528730813339,
+ 14.95052873081334,
15.48159936483578,
- 14.419778171926065,
- 10.286253829351077,
+ 14.419778171926064,
+ 10.286253829351075,
15.803314668437338,
0.47058298637143886,
0.47058298637143886,
- 12.314263348816201
+ 12.3142633488162
]
},
"mode": "markers+text",
@@ -1293,27 +1291,27 @@
"color": "rgb(139.09090909090907, 56.81818181818181, 115.90909090909092)",
"opacity": 0.15,
"size": [
- 14.142135623730951,
+ 14.142135623730953,
14.142135623730953,
14.457921873170436,
11.873882958685238,
- 11.385552490469177,
+ 11.385552490469175,
10.810517375557936,
11.3129664414692,
11.337861545239877,
- 7.8391815325920025,
+ 7.839181532592003,
14.037910474841624,
- 14.950528730813339,
+ 14.95052873081334,
15.48159936483578,
- 14.419778171926065,
- 10.286253829351073,
+ 14.419778171926064,
+ 10.286253829351072,
6.981390596950686,
6.981390596950686,
0.6384270822213649,
15.803314668437338,
0.3287918566041185,
0.3287918566041185,
- 12.314263348816201
+ 12.3142633488162
]
},
"mode": "markers+text",
@@ -1420,20 +1418,20 @@
"color": "rgb(115.9090909090909, 68.18181818181819, 139.0909090909091)",
"opacity": 0.15,
"size": [
- 14.142135623730951,
+ 14.142135623730953,
14.142135623730953,
14.913449281118025,
16.002365531418288,
15.182974673123722,
15.012323751039531,
- 14.649651227979243,
+ 14.649651227979245,
15.598038648562746,
15.785040732464894,
14.037910474841624,
- 14.950528730813339,
+ 14.95052873081334,
15.48159936483578,
- 14.419778171926065,
- 10.286253829351073,
+ 14.419778171926064,
+ 10.286253829351072,
0.6384270822213649,
15.88558286313778,
8.725785645125885,
@@ -1442,7 +1440,7 @@
9.70660501864694,
9.70660501864694,
9.70660501864694,
- 12.314263348816201
+ 12.3142633488162
]
},
"mode": "markers+text",
@@ -1548,7 +1546,7 @@
"color": "rgb(92.72727272727272, 79.54545454545455, 162.27272727272728)",
"opacity": 0.15,
"size": [
- 14.142135623730951,
+ 14.142135623730953,
14.142135623730953,
8.257423500494252,
0.32435677673491853,
@@ -1557,13 +1555,13 @@
0.00042230151833260696,
0.0004499470989768492,
0.0004538704330765998,
- 14.471657855811605,
+ 14.471657855811603,
8.085033090031478,
8.738079047115395,
- 10.948421124988391,
+ 10.948421124988393,
1.5528481130671423,
8.725785645125885,
- 12.314263348816201
+ 12.3142633488162
]
},
"mode": "markers+text",
@@ -1651,12 +1649,12 @@
"color": "rgb(69.54545454545453, 90.9090909090909, 185.45454545454547)",
"opacity": 0.15,
"size": [
- 14.142135623730951,
+ 14.142135623730953,
14.142135623730953,
16.83717465239317,
20.099662725386494,
- 21.443980782304674,
- 18.876277927789282,
+ 21.44398078230467,
+ 18.87627792778928,
17.883132592668552,
17.590938425201635,
17.707023666499882,
@@ -1669,7 +1667,7 @@
15.88558286313778,
0.0477113839686589,
15.802042939131656,
- 12.314263348816201
+ 12.3142633488162
]
},
"mode": "markers+text",
@@ -1768,7 +1766,7 @@
"color": "rgb(46.363636363636346, 102.27272727272727, 208.63636363636365)",
"opacity": 0.15,
"size": [
- 14.142135623730951,
+ 14.142135623730953,
14.142135623730953,
16.83717465239317,
1.1335788721113462,
@@ -1788,7 +1786,7 @@
15.802042939131656,
34.23492522202181,
1.3094098924826658,
- 12.314263348816201
+ 12.3142633488162
]
},
"mode": "markers+text",
@@ -1893,27 +1891,27 @@
"color": "rgb(23.18181818181816, 113.63636363636363, 231.81818181818184)",
"opacity": 0.15,
"size": [
- 14.142135623730951,
+ 14.142135623730953,
14.142135623730953,
14.913449281118025,
17.60968072558521,
- 14.306220277371315,
- 14.646364274891125,
- 14.679009478542559,
+ 14.306220277371317,
+ 14.646364274891123,
+ 14.67900947854256,
15.878930100964638,
15.968444595340683,
20.956889930483197,
0.004795697456429562,
- 14.011882041391393,
+ 14.011882041391392,
14.922601814341911,
15.600288271375703,
- 14.546011331905781,
+ 14.54601133190578,
10.753511485362427,
0.01727711852679148,
15.802042939131656,
- 14.810845859606019,
- 14.810845859606019,
- 12.314263348816201
+ 14.81084585960602,
+ 14.81084585960602,
+ 12.3142633488162
]
},
"mode": "markers+text",
@@ -2017,26 +2015,26 @@
"color": "rgb(0.0, 125.0, 255.0)",
"opacity": 0.15,
"size": [
- 14.142135623730951,
+ 14.142135623730953,
14.142135623730953,
14.913449281118025,
17.60968072558521,
- 18.071068363285406,
+ 18.07106836328541,
19.202786969683103,
20.038008516758698,
17.99283310333559,
17.705088506092654,
12.34104206849142,
20.27856636536855,
- 14.011882041391393,
+ 14.011882041391392,
14.922601814341911,
7.906607755569021,
0.37212405200948695,
0.030751094037019283,
- 15.803735407440861,
- 14.810845859606019,
+ 15.80373540744086,
+ 14.81084585960602,
9.754253198735066,
- 12.314263348816201
+ 12.3142633488162
]
},
"mode": "markers+text",
@@ -6804,18 +6802,18 @@
-0.16685074819284473,
-0.16685074819284473,
-0.16685074819284473,
- -0.46702063795957116,
- -0.46702063795957116,
- -0.46702063795957116,
- -0.46702063795957116,
- -0.46702063795957116,
- -0.46702063795957116,
- -0.46702063795957116,
- -0.46702063795957116,
- -0.46702063795957116,
- -0.46702063795957116,
- -0.46702063795957116,
- -0.46702063795957116,
+ -0.4670206379595712,
+ -0.4670206379595712,
+ -0.4670206379595712,
+ -0.4670206379595712,
+ -0.4670206379595712,
+ -0.4670206379595712,
+ -0.4670206379595712,
+ -0.4670206379595712,
+ -0.4670206379595712,
+ -0.4670206379595712,
+ -0.4670206379595712,
+ -0.4670206379595712,
-0.5487402987194492,
-0.5487402987194492,
-0.5487402987194492,
@@ -7318,18 +7316,18 @@
-0.16685074819284473,
-0.16685074819284473,
-0.16685074819284473,
- -0.46702063795957116,
- -0.46702063795957116,
- -0.46702063795957116,
- -0.46702063795957116,
- -0.46702063795957116,
- -0.46702063795957116,
- -0.46702063795957116,
- -0.46702063795957116,
- -0.46702063795957116,
- -0.46702063795957116,
- -0.46702063795957116,
- -0.46702063795957116,
+ -0.4670206379595712,
+ -0.4670206379595712,
+ -0.4670206379595712,
+ -0.4670206379595712,
+ -0.4670206379595712,
+ -0.4670206379595712,
+ -0.4670206379595712,
+ -0.4670206379595712,
+ -0.4670206379595712,
+ -0.4670206379595712,
+ -0.4670206379595712,
+ -0.4670206379595712,
-0.5487402987194492,
-0.5487402987194492,
-0.5487402987194492,
@@ -8405,13 +8403,13 @@
"import os\n",
"\n",
"\n",
- "def write_images_scrollby(windowsize=10, outdir=\"figs\", height=800, width=1000):\n",
+ "def write_images_scrollby(windowsize=10, outdir='figs', height=800, width=1000):\n",
" record.plotly2(height=height, width=width).write_image(\n",
- " os.path.join(outdir, f\"EXAMPLE.png\")\n",
+ " os.path.join(outdir, f'EXAMPLE.png')\n",
" )\n",
- " for x_ in range(len(record[\"step\"])):\n",
+ " for x_ in range(len(record['step'])):\n",
" record.plotly2(xrange=[0, x_], height=height, width=width).write_image(\n",
- " os.path.join(outdir, f\"TMP{x_:02d}.png\")\n",
+ " os.path.join(outdir, f'TMP{x_:02d}.png')\n",
" )\n",
"\n",
"\n",
@@ -8483,7 +8481,7 @@
"color": "rgb(255.0, 0.0, 0.0)",
"opacity": 0.15,
"size": [
- 14.142135623730951,
+ 14.142135623730953,
14.142135623730953,
13.760187457708865,
14.025232200647846,
@@ -8495,7 +8493,7 @@
17.844293001325273,
18.670061600207696,
20.71745014148054,
- 21.078602184484026,
+ 21.078602184484023,
18.44380438188891,
24.845980647785268,
29.521086659072292,
@@ -8506,13 +8504,13 @@
12.846940967600313,
0.34818164365681875,
0.34818164365681875,
- 30.489249691855626,
+ 30.489249691855623,
0.2749380041770073,
0.2749380041770073,
0.2749380041770073,
0.2749380041770073,
14.145987065255758,
- 14.158995761951541,
+ 14.15899576195154,
21.597664959206565,
21.597664959206565,
21.597664959206565
@@ -8652,7 +8650,7 @@
"color": "rgb(231.8181818181818, 11.363636363636363, 23.181818181818183)",
"opacity": 0.15,
"size": [
- 14.142135623730951,
+ 14.142135623730953,
14.142135623730953,
14.193731267149062,
14.680732327858076,
@@ -8660,15 +8658,15 @@
15.75612883288789,
15.428455754335374,
15.633213078982251,
- 15.522114616511987,
+ 15.522114616511988,
16.881387097221744,
18.517338460472764,
13.815167021437944,
- 14.738392715010217,
+ 14.738392715010216,
17.44838009523997,
10.803526202530191,
14.145987065255758,
- 14.158995761951541
+ 14.15899576195154
]
},
"mode": "markers+text",
@@ -8760,7 +8758,7 @@
"color": "rgb(208.63636363636363, 22.727272727272727, 46.36363636363637)",
"opacity": 0.15,
"size": [
- 14.142135623730951,
+ 14.142135623730953,
14.142135623730953,
13.760187457708865,
14.025232200647846,
@@ -8772,13 +8770,13 @@
17.844293001325273,
18.670061600207696,
20.71745014148054,
- 21.078602184484026,
+ 21.078602184484023,
18.44380438188891,
10.803526202530197,
12.846940967600313,
4.44016009078639,
14.145987065255758,
- 14.158995761951541,
+ 14.15899576195154,
21.597664959206565
]
},
@@ -8884,23 +8882,23 @@
"color": "rgb(185.45454545454544, 34.09090909090909, 69.54545454545455)",
"opacity": 0.15,
"size": [
- 14.142135623730951,
+ 14.142135623730953,
14.142135623730953,
13.760187457708865,
14.025232200647846,
15.285762544554022,
- 1.3091077604896781,
+ 1.309107760489678,
1.3893459153490075,
1.4814736823266503,
1.4201993902173953,
1.5298910258680254,
1.5377711811412484,
- 1.6419978926288261,
+ 1.641997892628826,
1.7529894072396566,
1.9814612006972587,
10.803526202530191,
14.145987065255758,
- 14.158995761951541,
+ 14.15899576195154,
21.597664959206565,
3.109991767313499,
3.109991767313499,
@@ -9018,23 +9016,23 @@
"color": "rgb(162.27272727272725, 45.45454545454545, 92.72727272727273)",
"opacity": 0.15,
"size": [
- 14.142135623730951,
+ 14.142135623730953,
14.142135623730953,
15.636948613437644,
- 12.683416057537427,
+ 12.683416057537428,
8.940367197023463,
8.29112162695904,
3.213082656992462,
0.02479088685576839,
0.0023517727922828606,
- 7.674656973359717e-05,
- 7.842988502138061e-05,
- 8.515355953244925e-05,
- 4.9729046579742855e-05,
- 6.115419749076286e-08,
+ 0.00007674656973359717,
+ 0.00007842988502138061,
+ 0.00008515355953244925,
+ 0.000049729046579742855,
+ 6.115419749076286e-8,
10.803526202530191,
14.145987065255758,
- 14.158995761951541,
+ 14.15899576195154,
3.109991767313499
]
},
@@ -9127,7 +9125,7 @@
"color": "rgb(139.09090909090907, 56.81818181818181, 115.90909090909092)",
"opacity": 0.15,
"size": [
- 14.142135623730951,
+ 14.142135623730953,
14.142135623730953,
13.760187457708865,
14.025232200647846,
@@ -9233,7 +9231,7 @@
"color": "rgb(115.9090909090909, 68.18181818181819, 139.0909090909091)",
"opacity": 0.15,
"size": [
- 14.142135623730951,
+ 14.142135623730953,
14.142135623730953,
13.760187457708865,
14.025232200647846,
@@ -9245,7 +9243,7 @@
17.844293001325273,
18.670061600207696,
20.71745014148054,
- 21.078602184484026,
+ 21.078602184484023,
18.44380438188891,
17.587349907111197,
10.803526202530191,
@@ -9344,20 +9342,20 @@
"color": "rgb(92.72727272727272, 79.54545454545455, 162.27272727272728)",
"opacity": 0.15,
"size": [
- 14.142135623730951,
+ 14.142135623730953,
14.142135623730953,
14.193731267149062,
14.680732327858076,
- 10.767303527780113,
- 11.362749317687761,
+ 10.767303527780111,
+ 11.36274931768776,
11.08863577821977,
11.852253615725274,
11.548798088756843,
11.286983105524376,
- 11.346569663258295,
+ 11.346569663258297,
12.58895119525174,
12.051466346828228,
- 9.867192326354381,
+ 9.86719232635438,
10.803526202530191,
0.2749380041770073,
17.146513534967667,
@@ -9463,15 +9461,15 @@
"color": "rgb(69.54545454545453, 90.9090909090909, 185.45454545454547)",
"opacity": 0.15,
"size": [
- 14.142135623730951,
+ 14.142135623730953,
14.142135623730953,
14.193731267149062,
13.340755499216302,
12.625070349598664,
12.693353402446863,
12.557499814426562,
- 12.734033603701935,
- 12.882704181997719,
+ 12.734033603701937,
+ 12.88270418199772,
6.901263631086172,
7.492526579923725,
7.5617941623819585,
@@ -9605,7 +9603,7 @@
"color": "rgb(46.363636363636346, 102.27272727272727, 208.63636363636365)",
"opacity": 0.15,
"size": [
- 14.142135623730951,
+ 14.142135623730953,
14.142135623730953,
14.193731267149062,
14.680732327858076,
@@ -9617,12 +9615,12 @@
15.337821505252352,
16.017845068085414,
17.77414156136321,
- 18.390048709064374,
- 21.773526627488554,
- 3.4148849558775534,
+ 18.390048709064377,
+ 21.77352662748855,
+ 3.414884955877554,
17.587349907111193,
2.5699505429676868,
- 3.5948764828277963e-09,
+ 3.5948764828277963e-9,
2.1818119524102917,
17.587349907111193,
8.966869735242355,
@@ -9765,7 +9763,7 @@
"color": "rgb(23.18181818181816, 113.63636363636363, 231.81818181818184)",
"opacity": 0.15,
"size": [
- 14.142135623730951,
+ 14.142135623730953,
14.142135623730953,
14.193731267149062,
14.680732327858076,
@@ -9890,14 +9888,14 @@
"color": "rgb(0.0, 125.0, 255.0)",
"opacity": 0.15,
"size": [
- 14.142135623730951,
+ 14.142135623730953,
14.142135623730953,
14.193731267149062,
14.680732327858076,
14.87843464250492,
15.75612883288789,
15.428455754335374,
- 15.622299546289641,
+ 15.62229954628964,
15.909286726897806,
16.897293160605635,
13.44411451756642,
@@ -17079,12 +17077,12 @@
"WINDOWSIZE = 10\n",
"\n",
"d__ = d_\n",
- "d__[\"context\"].apply(lambda r: \"\".join(r[:-1]))\n",
- "d__[\"overflow_context\"] = d__.apply(\n",
- " lambda r: \"\".join(r[\"context\"][: r[\"step\"] - (WINDOWSIZE + 1)]), axis=1\n",
+ "d__['context'].apply(lambda r: ''.join(r[:-1]))\n",
+ "d__['overflow_context'] = d__.apply(\n",
+ " lambda r: ''.join(r['context'][: r['step'] - (WINDOWSIZE + 1)]), axis=1\n",
").to_list()\n",
- "d__[\"window_context\"] = d__.apply(\n",
- " lambda r: \"\".join(r[\"context\"][r[\"step\"] - WINDOWSIZE : r[\"step\"]]), axis=1\n",
+ "d__['window_context'] = d__.apply(\n",
+ " lambda r: ''.join(r['context'][r['step'] - WINDOWSIZE : r['step']]), axis=1\n",
").to_list()\n",
"d__"
]
@@ -17269,7 +17267,7 @@
"size": [
14.142135623730953,
15.636948613437644,
- 12.683416057537427
+ 12.683416057537428
]
},
"mode": "markers+text",
@@ -19371,7 +19369,7 @@
],
[
" SELECT",
- -12.044136238968767,
+ -12.044136238968768,
false
],
[
@@ -19386,7 +19384,7 @@
],
[
" SELECT age FROM data",
- -12.890518994730765,
+ -12.890518994730764,
false
],
[
@@ -19457,13 +19455,13 @@
1,
0.9999999999999996,
0.08345997096712078,
- 3.144573361448564e-05,
- 2.4956886495467274e-05,
- 2.1462501915942082e-05,
- 1.6268849815188344e-05,
- 1.4757300352087234e-05,
- 1.1698653730016005e-05,
- 1.4753454812917209e-05,
+ 0.00003144573361448564,
+ 0.00002495688649546727,
+ 0.00002146250191594208,
+ 0.000016268849815188344,
+ 0.000014757300352087234,
+ 0.000011698653730016005,
+ 0.000014753454812917207,
0.0017675764341983872,
0.3297392910493373,
0.43085236169137026,
@@ -19567,7 +19565,7 @@
],
[
" SELECT vote, zip",
- -0.9403565687817279,
+ -0.940356568781728,
false
],
[
@@ -19577,7 +19575,7 @@
],
[
" SELECT vote, zipcode FROM",
- -1.1918328268726919,
+ -1.191832826872692,
false
],
[
@@ -19722,7 +19720,7 @@
1.165703252685702,
1.341546047879836,
1.658581020724092,
- 1.8327405829377286,
+ 1.8327405829377288,
1.9151469672615904,
1.776976828408179,
1.4601356587730292,
@@ -19734,7 +19732,7 @@
0.000722149296949957,
0.005201221340267852,
1,
- 3.6494517184788635,
+ 3.649451718478863,
0.0004472917528217814,
9.995974374224588,
0.0004472917528217814,
@@ -19888,7 +19886,7 @@
],
[
" SELECT vote, vote, zip",
- -1.6492045188518663,
+ -1.6492045188518665,
false
],
[
@@ -19996,16 +19994,16 @@
1.0000000000000002,
1.165703252685702,
1.341546047879836,
- 1.1054502310927365,
+ 1.1054502310927363,
0.9546201682781362,
1.0306565283252629,
1.124728662169629,
- 0.9528559020218703,
+ 0.9528559020218704,
1.14477607748672,
1.717334448386174,
2.816856935026382,
0.7772670745842045,
- 2.349754756014147e-09,
+ 2.349754756014147e-9,
0.89473665260685,
1.0842783427146367,
0.7606984589293397,
@@ -20129,7 +20127,7 @@
],
[
" SELECT age FROM data",
- -1.4672590691116039,
+ -1.467259069111604,
false
],
[
@@ -20239,7 +20237,7 @@
1.0821068550913535,
1.0732693969922604,
1.052737328989911,
- 0.9468694342361799,
+ 0.94686943423618,
0.3440935907206618,
0.5734037698910691,
1.0842783427146367,
@@ -20355,7 +20353,7 @@
],
[
" SELECT vote FROM",
- -1.8979466102495361,
+ -1.897946610249536,
false
],
[
@@ -20470,9 +20468,9 @@
0.5642008988781192,
0.5268289341512757,
0.5341759231035361,
- 0.49153354750180955,
- 0.22901753434443822,
- 1.0996645104191986e-06,
+ 0.4915335475018095,
+ 0.22901753434443825,
+ 1.0996645104191986e-6,
1.0842783427146367,
0.7606984589293397,
1.461192198434121,
@@ -20690,7 +20688,7 @@
1,
1.0000000000000002,
1.2403163373321546,
- 1.4698766961409617,
+ 1.4698766961409615,
1.5713645707446335,
1.6673103705955654,
1.6201204504974391,
@@ -20802,7 +20800,7 @@
],
[
" SELECT state",
- -3.8397796375923234,
+ -3.839779637592323,
false
],
[
@@ -20847,7 +20845,7 @@
],
[
" SELECT age, gender FROM data ORDER BY age ASC ",
- -0.9458901196953537,
+ -0.9458901196953536,
false
],
[
@@ -21038,7 +21036,7 @@
],
[
" SELECT age, gender FROM data ORDER BY age ASC ",
- -1.6366235684784727,
+ -1.636623568478473,
false
],
[
@@ -21430,7 +21428,7 @@
],
[
" SELECT age, gender FROM data ORDER BY age ASC Context: {row['context_string']}
Weight: {row['weight']}
Resample?: {row['resample?']}\",\n",
" axis=1,\n",
" ),\n",
- " textposition=\"top center\",\n",
+ " textposition='top center',\n",
")\n",
"\n",
"fig = go.Figure(data=[scatter])\n",
"\n",
"# Define frames\n",
"frames = []\n",
- "steps = sorted(recs[\"step\"].unique())\n",
+ "steps = sorted(recs['step'].unique())\n",
"for i, step in enumerate(steps):\n",
" # Filter data up to the current step\n",
- " data_up_to_step = recs[recs[\"step\"] <= step]\n",
+ " data_up_to_step = recs[recs['step'] <= step]\n",
"\n",
" scatter = go.Scatter(\n",
- " x=data_up_to_step[\"step\"],\n",
- " y=data_up_to_step[\"particle\"],\n",
- " mode=\"markers+text\",\n",
+ " x=data_up_to_step['step'],\n",
+ " y=data_up_to_step['particle'],\n",
+ " mode='markers+text',\n",
" marker=dict(\n",
- " size=data_up_to_step[\"prop_exp_weight\"],\n",
- " color=data_up_to_step[\"resampled as\"].map(color_map),\n",
+ " size=data_up_to_step['prop_exp_weight'],\n",
+ " color=data_up_to_step['resampled as'].map(color_map),\n",
" opacity=0.6,\n",
" ),\n",
- " text=data_up_to_step[\"token\"],\n",
- " hoverinfo=\"text\",\n",
+ " text=data_up_to_step['token'],\n",
+ " hoverinfo='text',\n",
" hovertext=data_up_to_step.apply(\n",
" lambda row: f\"Token: {row['token']}
Context: {row['context_string']}
Weight: {row['weight']}
Resample?: {row['resample?']}\",\n",
" axis=1,\n",
" ),\n",
- " textposition=\"top center\",\n",
+ " textposition='top center',\n",
" )\n",
"\n",
" resampling_lines = []\n",
- " for resampled_as in recs[\"resampled as\"].unique():\n",
- " resampled_as_data = recs[recs[\"resampled as\"] == resampled_as]\n",
+ " for resampled_as in recs['resampled as'].unique():\n",
+ " resampled_as_data = recs[recs['resampled as'] == resampled_as]\n",
"\n",
" # Add resampling lines for the current step\n",
- " for _, row in resampled_as_data[resampled_as_data[\"step\"] == step].iterrows():\n",
+ " for _, row in resampled_as_data[resampled_as_data['step'] == step].iterrows():\n",
" resampling_lines.append(\n",
" go.Scatter(\n",
- " x=[row[\"step\"], row[\"step\"] + 1],\n",
- " y=[row[\"resampled as\"], row[\"particle\"]],\n",
- " mode=\"lines\",\n",
+ " x=[row['step'], row['step'] + 1],\n",
+ " y=[row['resampled as'], row['particle']],\n",
+ " mode='lines',\n",
" line=dict(color=color_map[resampled_as]),\n",
" opacity=0.3,\n",
" name=resampled_as,\n",
@@ -221311,13 +221311,13 @@
" )\n",
"\n",
" for s in steps[: i + 1]:\n",
- " for _, row in resampled_as_data[resampled_as_data[\"step\"] == s].iterrows():\n",
+ " for _, row in resampled_as_data[resampled_as_data['step'] == s].iterrows():\n",
" resampling_lines.append(\n",
" go.Scatter(\n",
- " x=[row[\"step\"], row[\"step\"] + 1],\n",
- " y=[row[\"particle\"], row[\"particle\"]],\n",
- " mode=\"lines\",\n",
- " line=dict(color=\"gray\"),\n",
+ " x=[row['step'], row['step'] + 1],\n",
+ " y=[row['particle'], row['particle']],\n",
+ " mode='lines',\n",
+ " line=dict(color='gray'),\n",
" opacity=0.2,\n",
" name=resampled_as,\n",
" showlegend=False,\n",
@@ -221329,12 +221329,12 @@
" if step in resample_steps:\n",
" vertical_lines.append(\n",
" go.layout.Shape(\n",
- " type=\"line\",\n",
+ " type='line',\n",
" x0=step,\n",
" x1=step,\n",
" y0=0,\n",
" y1=1,\n",
- " line=dict(width=4, color=\"gray\"),\n",
+ " line=dict(width=4, color='gray'),\n",
" opacity=0.15,\n",
" )\n",
" )\n",
@@ -221351,74 +221351,74 @@
" width=1200,\n",
" height=500,\n",
" xaxis=dict(range=[-1, 36]),\n",
- " yaxis=dict(type=\"category\"), # Ensure the y-axis is categorical\n",
- " plot_bgcolor=\"#fff\",\n",
+ " yaxis=dict(type='category'), # Ensure the y-axis is categorical\n",
+ " plot_bgcolor='#fff',\n",
" showlegend=False,\n",
" updatemenus=[\n",
" {\n",
- " \"buttons\": [\n",
+ " 'buttons': [\n",
" {\n",
- " \"args\": [\n",
+ " 'args': [\n",
" None,\n",
" {\n",
- " \"frame\": {\"duration\": 500, \"redraw\": True},\n",
- " \"fromcurrent\": True,\n",
+ " 'frame': {'duration': 500, 'redraw': True},\n",
+ " 'fromcurrent': True,\n",
" },\n",
" ],\n",
- " \"label\": \"Play\",\n",
- " \"method\": \"animate\",\n",
+ " 'label': 'Play',\n",
+ " 'method': 'animate',\n",
" },\n",
" {\n",
- " \"args\": [\n",
+ " 'args': [\n",
" [None],\n",
" {\n",
- " \"frame\": {\"duration\": 0, \"redraw\": True},\n",
- " \"mode\": \"immediate\",\n",
- " \"transition\": {\"duration\": 0},\n",
+ " 'frame': {'duration': 0, 'redraw': True},\n",
+ " 'mode': 'immediate',\n",
+ " 'transition': {'duration': 0},\n",
" },\n",
" ],\n",
- " \"label\": \"Pause\",\n",
- " \"method\": \"animate\",\n",
+ " 'label': 'Pause',\n",
+ " 'method': 'animate',\n",
" },\n",
" ],\n",
- " \"direction\": \"left\",\n",
- " \"pad\": {\"r\": 10, \"t\": 87},\n",
- " \"showactive\": False,\n",
- " \"type\": \"buttons\",\n",
- " \"x\": 0.1,\n",
- " \"xanchor\": \"right\",\n",
- " \"y\": 0,\n",
- " \"yanchor\": \"top\",\n",
+ " 'direction': 'left',\n",
+ " 'pad': {'r': 10, 't': 87},\n",
+ " 'showactive': False,\n",
+ " 'type': 'buttons',\n",
+ " 'x': 0.1,\n",
+ " 'xanchor': 'right',\n",
+ " 'y': 0,\n",
+ " 'yanchor': 'top',\n",
" }\n",
" ],\n",
" sliders=[\n",
" {\n",
- " \"steps\": [\n",
+ " 'steps': [\n",
" {\n",
- " \"args\": [\n",
+ " 'args': [\n",
" [str(step)],\n",
" {\n",
- " \"frame\": {\"duration\": 300, \"redraw\": True},\n",
- " \"mode\": \"immediate\",\n",
- " \"transition\": {\"duration\": 300},\n",
+ " 'frame': {'duration': 300, 'redraw': True},\n",
+ " 'mode': 'immediate',\n",
+ " 'transition': {'duration': 300},\n",
" },\n",
" ],\n",
- " \"label\": str(step),\n",
- " \"method\": \"animate\",\n",
+ " 'label': str(step),\n",
+ " 'method': 'animate',\n",
" }\n",
" for step in steps\n",
" ],\n",
- " \"x\": 0.1,\n",
- " \"xanchor\": \"left\",\n",
- " \"y\": 0,\n",
- " \"yanchor\": \"top\",\n",
- " \"currentvalue\": {\n",
- " \"font\": {\"size\": 20},\n",
- " \"prefix\": \"Step:\",\n",
- " \"visible\": True,\n",
- " \"xanchor\": \"right\",\n",
+ " 'x': 0.1,\n",
+ " 'xanchor': 'left',\n",
+ " 'y': 0,\n",
+ " 'yanchor': 'top',\n",
+ " 'currentvalue': {\n",
+ " 'font': {'size': 20},\n",
+ " 'prefix': 'Step:',\n",
+ " 'visible': True,\n",
+ " 'xanchor': 'right',\n",
" },\n",
- " \"transition\": {\"duration\": 300, \"easing\": \"cubic-in-out\"},\n",
+ " 'transition': {'duration': 300, 'easing': 'cubic-in-out'},\n",
" }\n",
" ],\n",
")\n",
diff --git a/notes/sql_debug.ipynb b/notes/sql_debug.ipynb
index 647a7ec6..dcb0112b 100644
--- a/notes/sql_debug.ipynb
+++ b/notes/sql_debug.ipynb
@@ -21,8 +21,8 @@
"import sys\n",
"import getpass\n",
"\n",
- "if getpass.getuser() == \"benjamin.lebrun\":\n",
- " sys.path.append(\"/home/mila/b/benjamin.lebrun/genparse\")"
+ "if getpass.getuser() == 'benjamin.lebrun':\n",
+ " sys.path.append('/home/mila/b/benjamin.lebrun/genparse')"
]
},
{
@@ -45,7 +45,7 @@
"source": [
"from genparse.evaluation.dataset import Dataset\n",
"\n",
- "dataset = Dataset(\"spider\", \"validation\")"
+ "dataset = Dataset('spider', 'validation')"
]
},
{
@@ -55,9 +55,9 @@
"metadata": {},
"outputs": [],
"source": [
- "cfg = LarkStuff(\n",
- " open(\"../benchmark/grammars/sql_case_insensitive.lark\").read()\n",
- ").char_cfg(0.99, ignore=\"[ ]?\")\n",
+ "cfg = LarkStuff(open('../benchmark/grammars/sql_case_insensitive.lark').read()).char_cfg(\n",
+ " 0.99, ignore='[ ]?'\n",
+ ")\n",
"guide = EarleyBoolMaskCFGLM(cfg)"
]
},