From 491137e0a1b9f4e6099db914915cd6543288949b Mon Sep 17 00:00:00 2001 From: benlipkin Date: Sat, 15 Jun 2024 11:13:02 -0400 Subject: [PATCH] format ipynb --- notes/A-Tale-of-Two-Transducers.ipynb | 16 +- notes/Character-at-a-Time.ipynb | 20 +- notes/FST.ipynb | 50 +- notes/Inference-Playground.ipynb | 10 +- notes/LM-Fun.ipynb | 20 +- notes/Lark-Interface.ipynb | 18 +- notes/SegTokenAligner.ipynb | 4 +- notes/Segmentation-PFST.ipynb | 18 +- notes/Token-Alignment.ipynb | 16 +- notes/grammar_processing_issues.ipynb | 12 +- notes/hfppl.ipynb | 9 +- notes/hfppl_benleb.ipynb | 16 +- notes/hfppl_for_tim_o.ipynb | 31 +- notes/hfppl_jac.ipynb | 3822 ++++++++++++------------- notes/sql_debug.ipynb | 12 +- 15 files changed, 2036 insertions(+), 2038 deletions(-) diff --git a/notes/A-Tale-of-Two-Transducers.ipynb b/notes/A-Tale-of-Two-Transducers.ipynb index 2cd62071..c10f9ba7 100644 --- a/notes/A-Tale-of-Two-Transducers.ipynb +++ b/notes/A-Tale-of-Two-Transducers.ipynb @@ -192,7 +192,7 @@ } ], "source": [ - "D = Delta(\"ab\")\n", + "D = Delta('ab')\n", "D" ] }, @@ -331,7 +331,7 @@ } ], "source": [ - "(D @ WFSA.from_string(\"ab\", Float)).trim" + "(D @ WFSA.from_string('ab', Float)).trim" ] }, { @@ -465,7 +465,7 @@ } ], "source": [ - "(cfg @ Delta1(\"ab\")).nullaryremove(binarize=False).trim()" + "(cfg @ Delta1('ab')).nullaryremove(binarize=False).trim()" ] }, { @@ -664,7 +664,7 @@ } ], "source": [ - "derivative(\"ab\", \"ab\").graphviz(fmt_node=lambda x: str(x))" + "derivative('ab', 'ab').graphviz(fmt_node=lambda x: str(x))" ] }, { @@ -688,7 +688,7 @@ } ], "source": [ - "(cfg @ derivative(\"a\", \"ab\") @ derivative(\"a\", \"ab\")).language(10)" + "(cfg @ derivative('a', 'ab') @ derivative('a', 'ab')).language(10)" ] }, { @@ -712,7 +712,7 @@ } ], "source": [ - "(cfg @ derivative(\"aa\", \"ab\")).language(10)" + "(cfg @ derivative('aa', 'ab')).language(10)" ] }, { @@ -748,7 +748,7 @@ } ], "source": [ - "cfg.derivative(\"a\").trim()" + "cfg.derivative('a').trim()" ] }, { @@ -778,7 +778,7 @@ } ], "source": [ - "(cfg @ derivative(\"a\", cfg.V)).nullaryremove(binarize=False).unaryremove().trim()" + "(cfg @ derivative('a', cfg.V)).nullaryremove(binarize=False).unaryremove().trim()" ] }, { diff --git a/notes/Character-at-a-Time.ipynb b/notes/Character-at-a-Time.ipynb index e0a8b4b8..0be025d1 100644 --- a/notes/Character-at-a-Time.ipynb +++ b/notes/Character-at-a-Time.ipynb @@ -33,7 +33,7 @@ "metadata": {}, "outputs": [], "source": [ - "llm = GreedilyTokenizedLLM(\"gpt2\")" + "llm = GreedilyTokenizedLLM('gpt2')" ] }, { @@ -57,7 +57,7 @@ } ], "source": [ - "prompt = \"Hello my name is\"\n", + "prompt = 'Hello my name is'\n", "pp = llm.p_next(prompt, 10).normalize()\n", "pp" ] @@ -174,7 +174,7 @@ } ], "source": [ - "print(repr(\"\".join(pcfg.sample())))" + "print(repr(''.join(pcfg.sample())))" ] }, { @@ -206,7 +206,7 @@ "tracer = TraceSWOR()\n", "for _ in range(1):\n", " with tracer:\n", - " print(\"----------------------------------\")\n", + " print('----------------------------------')\n", " ys = token_trie_approx.sample(prompt, max_tokens=50, draw=tracer, verbosity=1)\n", " print(ys)" ] @@ -1486,7 +1486,7 @@ ], "source": [ "tracer.root.graphviz(\n", - " fmt_node=lambda x: f\"{x._mass:.3g}\", fmt_edge=lambda i, a, j: repr(a)\n", + " fmt_node=lambda x: f'{x._mass:.3g}', fmt_edge=lambda i, a, j: repr(a)\n", ")" ] }, @@ -1524,7 +1524,7 @@ "ADJ: \"fruit\"\n", "\n", "\"\"\"\n", - " ).char_cfg(0.99, ignore=\"[ ]?\"),\n", + " ).char_cfg(0.99, ignore='[ ]?'),\n", " tol=1e-100,\n", " )\n", ")" @@ -1548,7 +1548,7 @@ } ], "source": [ - "\"\".join(fruit.sample())" + "''.join(fruit.sample())" ] }, { @@ -1579,7 +1579,7 @@ } ], "source": [ - "fruit(\"fruit flies like a banana \" + EOS)" + "fruit('fruit flies like a banana ' + EOS)" ] }, { @@ -1599,12 +1599,12 @@ } ], "source": [ - "prompt = \"The following is a favorite sentence among linguists:\"\n", + "prompt = 'The following is a favorite sentence among linguists:'\n", "token_trie_approx = TokenTrieApproximation(llm, fruit)\n", "tracer = TraceSWOR()\n", "for _ in range(1):\n", " with tracer:\n", - " print(\"----------------------------------\")\n", + " print('----------------------------------')\n", " ys = token_trie_approx.sample(prompt, max_tokens=50, draw=tracer)\n", " print(ys)" ] diff --git a/notes/FST.ipynb b/notes/FST.ipynb index 322cc0db..2dcaaa18 100644 --- a/notes/FST.ipynb +++ b/notes/FST.ipynb @@ -155,7 +155,7 @@ "metadata": {}, "outputs": [], "source": [ - "b2c = H.fst.prune_to_alphabet(None, foo.V | {\"\"}).renumber" + "b2c = H.fst.prune_to_alphabet(None, foo.V | {''}).renumber" ] }, { @@ -174,10 +174,10 @@ "outputs": [], "source": [ "for x in foo.cnf.language(3):\n", - " display(HTML(\"
\"))\n", + " display(HTML('
'))\n", " print(x)\n", " bpe_x = b2c(None, x).epsremove.trim\n", - " print(\"total weight of BPE sequences (i.e., ambiguity):\", bpe_x.total_weight())\n", + " print('total weight of BPE sequences (i.e., ambiguity):', bpe_x.total_weight())\n", " display(bpe_x)\n", " print()" ] @@ -253,7 +253,7 @@ "for x, w in foo.cnf.language(L + 2).items():\n", " if len(x) > L:\n", " continue\n", - " cc[\"\".join(x)] += w\n", + " cc[''.join(x)] += w\n", "# cc" ] }, @@ -426,7 +426,7 @@ "df = []\n", "for x, w in sorted(normalize(lm2.p_next(context)).items(), key=lambda kv: -kv[1]):\n", " df.append((x, (H.tokenizer.decode([x]) if x != EOS else EOS), w))\n", - "pd.DataFrame(df, columns=[\"token_id\", \"chars\", \"prob\"]).set_index(\"token_id\")" + "pd.DataFrame(df, columns=['token_id', 'chars', 'prob']).set_index('token_id')" ] }, { @@ -546,9 +546,9 @@ "source": [ "trace = TraceSWOR()\n", "for _ in range(15):\n", - " print(\"mass=\", trace.root.mass)\n", + " print('mass=', trace.root.mass)\n", " with trace:\n", - " print(\"\".join(lm.sample(draw=trace)))" + " print(''.join(lm.sample(draw=trace)))" ] }, { @@ -576,7 +576,7 @@ "metadata": {}, "outputs": [], "source": [ - "c2t = lark_stuff.transducer(ignore=\"\", decay=0.0125)\n", + "c2t = lark_stuff.transducer(ignore='', decay=0.0125)\n", "len(c2t.states)" ] }, @@ -595,7 +595,7 @@ "metadata": {}, "outputs": [], "source": [ - "x = \"SELECT * FROM data\"" + "x = 'SELECT * FROM data'" ] }, { @@ -711,7 +711,7 @@ "metadata": {}, "outputs": [], "source": [ - "cfg_t(\"SELECT * FROM data \")" + "cfg_t('SELECT * FROM data ')" ] }, { @@ -721,7 +721,7 @@ "metadata": {}, "outputs": [], "source": [ - "cfg_t(\"SELECT * FROM data \")" + "cfg_t('SELECT * FROM data ')" ] }, { @@ -742,7 +742,7 @@ "outputs": [], "source": [ "for _ in range(10):\n", - " print(\"\".join(lm.sample()))" + " print(''.join(lm.sample()))" ] }, { @@ -752,7 +752,7 @@ "metadata": {}, "outputs": [], "source": [ - "lm.p_next(\"SELECT * FROM \")" + "lm.p_next('SELECT * FROM ')" ] }, { @@ -800,7 +800,7 @@ "metadata": {}, "outputs": [], "source": [ - "x = \"SELECT * FROM data\"\n", + "x = 'SELECT * FROM data'\n", "b = tokenizer.encode(x)\n", "b" ] @@ -822,7 +822,7 @@ "metadata": {}, "outputs": [], "source": [ - "with timeit(\"composition\"):\n", + "with timeit('composition'):\n", " c = FST.from_string(tuple(b), Float) @ b2c\n", "about(c)" ] @@ -879,7 +879,7 @@ "metadata": {}, "outputs": [], "source": [ - "x = x = \"SELECT * FROM data\"" + "x = x = 'SELECT * FROM data'" ] }, { @@ -889,9 +889,9 @@ "metadata": {}, "outputs": [], "source": [ - "with timeit(\"composition\"):\n", + "with timeit('composition'):\n", " bs = b2c @ FST.from_string(x, Float)\n", - "with timeit(\"trim\"):\n", + "with timeit('trim'):\n", " bs.trim\n", "about(bs)" ] @@ -1025,7 +1025,7 @@ "metadata": {}, "outputs": [], "source": [ - "print(\"\".join(lm.sample()))" + "print(''.join(lm.sample()))" ] }, { @@ -1053,9 +1053,7 @@ "metadata": {}, "outputs": [], "source": [ - "bpe_lm = CharAlignedCFGLM(\n", - " lm=lm, words={x for _, x in H.pairs}, eos=H.tokenizer.eos_token\n", - ")" + "bpe_lm = CharAlignedCFGLM(lm=lm, words={x for _, x in H.pairs}, eos=H.tokenizer.eos_token)" ] }, { @@ -1065,7 +1063,7 @@ "metadata": {}, "outputs": [], "source": [ - "lm.p_next(\"\")" + "lm.p_next('')" ] }, { @@ -1083,7 +1081,7 @@ "metadata": {}, "outputs": [], "source": [ - "bpe_lm.p_next(\"\")" + "bpe_lm.p_next('')" ] }, { @@ -1093,7 +1091,7 @@ "metadata": {}, "outputs": [], "source": [ - "lm.p_next(\"SELECT \")" + "lm.p_next('SELECT ')" ] }, { @@ -1103,7 +1101,7 @@ "metadata": {}, "outputs": [], "source": [ - "bpe_lm.p_next(\"SELECT \")" + "bpe_lm.p_next('SELECT ')" ] }, { diff --git a/notes/Inference-Playground.ipynb b/notes/Inference-Playground.ipynb index 48088b81..c93e0c13 100644 --- a/notes/Inference-Playground.ipynb +++ b/notes/Inference-Playground.ipynb @@ -142,7 +142,7 @@ ], "source": [ "ref = BruteForceGlobalProductOfExperts(lm1, lm2, MAX_LENGTH)\n", - "ref.target.project(\"\".join)" + "ref.target.project(''.join)" ] }, { @@ -165,7 +165,7 @@ " lm2,\n", " MAX_LENGTH=MAX_LENGTH,\n", " n_particles=N_PARTICLES,\n", - " METHOD=\"is\",\n", + " METHOD='is',\n", " # METHOD = 'smc',\n", ")" ] @@ -258,8 +258,8 @@ } ], "source": [ - "w.project(\"\".join).trim().compare(ref.target.project(\"\".join).trim()).sort_values(\n", - " \"key\", ascending=False\n", + "w.project(''.join).trim().compare(ref.target.project(''.join).trim()).sort_values(\n", + " 'key', ascending=False\n", ")" ] }, @@ -328,7 +328,7 @@ "# truncate the reference distribution to the support set of the sample;\n", "# renamed the keys to handle the minor discrepancy in the EOS symbol\n", "tmp = ref.target.filter(lambda k: k[:-1] in R).normalize().sort()\n", - "tmp.project(lambda k: k[:-1]).compare(R).sort_values(\"key\")" + "tmp.project(lambda k: k[:-1]).compare(R).sort_values('key')" ] }, { diff --git a/notes/LM-Fun.ipynb b/notes/LM-Fun.ipynb index d901a61f..33ba4c51 100644 --- a/notes/LM-Fun.ipynb +++ b/notes/LM-Fun.ipynb @@ -22,7 +22,7 @@ "metadata": {}, "outputs": [], "source": [ - "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")" + "tokenizer = AutoTokenizer.from_pretrained('gpt2')" ] }, { @@ -50,7 +50,7 @@ "metadata": {}, "outputs": [], "source": [ - "lm = LLM(AutoModelForCausalLM.from_pretrained(\"gpt2\"))" + "lm = LLM(AutoModelForCausalLM.from_pretrained('gpt2'))" ] }, { @@ -98,7 +98,7 @@ "metadata": {}, "outputs": [], "source": [ - "p = GreedilyTokenizedLLM(\"gpt2\").p_next(\"Once upon a time,\")" + "p = GreedilyTokenizedLLM('gpt2').p_next('Once upon a time,')" ] }, { @@ -108,8 +108,8 @@ "metadata": {}, "outputs": [], "source": [ - "p = GreedilyTokenizedLLM(\"gpt2\").p_next(\n", - " \"The following is some code that implements quick sort in Python:\"\n", + "p = GreedilyTokenizedLLM('gpt2').p_next(\n", + " 'The following is some code that implements quick sort in Python:'\n", ")" ] }, @@ -146,7 +146,7 @@ "metadata": {}, "outputs": [], "source": [ - "M = GreedilyTokenizedLLM(\"gpt2\")\n", + "M = GreedilyTokenizedLLM('gpt2')\n", "# .p_terminal('The following is some code that implements quick sort in Python:')" ] }, @@ -168,7 +168,7 @@ " return hash(self.xs)\n", "\n", " def __repr__(self):\n", - " return f\"{self.xs}\"\n", + " return f'{self.xs}'\n", "\n", " def p(self):\n", " P = M.p_next(self.xs)\n", @@ -213,7 +213,7 @@ "outputs": [], "source": [ "Particle(\n", - " \"The following is some code that implements the quick sort algorithm in Python:\"\n", + " 'The following is some code that implements the quick sort algorithm in Python:'\n", ").p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample()" ] }, @@ -225,7 +225,7 @@ "outputs": [], "source": [ "Particle(\n", - " \"Once upon a time\"\n", + " 'Once upon a time'\n", ").p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample().p().sample()" ] }, @@ -244,7 +244,7 @@ "metadata": {}, "outputs": [], "source": [ - "p = Particle(\"Once upon a time,\")" + "p = Particle('Once upon a time,')" ] }, { diff --git a/notes/Lark-Interface.ipynb b/notes/Lark-Interface.ipynb index b627c53e..52129d3a 100644 --- a/notes/Lark-Interface.ipynb +++ b/notes/Lark-Interface.ipynb @@ -263,10 +263,12 @@ "metadata": {}, "outputs": [], "source": [ - "text = \"12 + 24 - 36 * 48 / 60 SELECT table.name AS thing WHERE table.potato IS NOT 'banana'\"\n", + "text = (\n", + " \"12 + 24 - 36 * 48 / 60 SELECT table.name AS thing WHERE table.potato IS NOT 'banana'\"\n", + ")\n", "\n", "for x, y in lark_stuff.simple_tokenizer(text):\n", - " print(f\"{x:15s} -> {y!r}\")" + " print(f'{x:15s} -> {y!r}')" ] }, { @@ -284,7 +286,7 @@ "metadata": {}, "outputs": [], "source": [ - "text = \"SELECT name FROM data \"" + "text = 'SELECT name FROM data '" ] }, { @@ -331,7 +333,7 @@ "metadata": {}, "outputs": [], "source": [ - "lark_stuff.parser.parse(tokens, \"start\")" + "lark_stuff.parser.parse(tokens, 'start')" ] }, { @@ -423,9 +425,11 @@ "metadata": {}, "outputs": [], "source": [ - "len(g.cnf.rules), len(g.cnf.prefix_grammar.trim().rules), len(\n", - " g.cnf.prefix_grammar.trim().rules\n", - ") / len(g.cnf.rules)" + "(\n", + " len(g.cnf.rules),\n", + " len(g.cnf.prefix_grammar.trim().rules),\n", + " len(g.cnf.prefix_grammar.trim().rules) / len(g.cnf.rules),\n", + ")" ] }, { diff --git a/notes/SegTokenAligner.ipynb b/notes/SegTokenAligner.ipynb index a758ba12..b87f8939 100644 --- a/notes/SegTokenAligner.ipynb +++ b/notes/SegTokenAligner.ipynb @@ -97,7 +97,7 @@ } ], "source": [ - "prefix = \"aa#a#a#\"\n", + "prefix = 'aa#a#a#'\n", "display(char_lm.p_next(prefix))\n", "display(pullback(char_lm, prefix))" ] @@ -1683,7 +1683,7 @@ } ], "source": [ - "tracer.root.graphviz(fmt_node=lambda x: f\"{x.mass}/{x._mass:.2g}\")" + "tracer.root.graphviz(fmt_node=lambda x: f'{x.mass}/{x._mass:.2g}')" ] }, { diff --git a/notes/Segmentation-PFST.ipynb b/notes/Segmentation-PFST.ipynb index 3880f084..bb306f94 100644 --- a/notes/Segmentation-PFST.ipynb +++ b/notes/Segmentation-PFST.ipynb @@ -50,8 +50,8 @@ "metadata": {}, "outputs": [], "source": [ - "contexts = {\"a\", \"b\", \"c\", \"ab\", \"abc\"}\n", - "alphabet = set(\"abc\")" + "contexts = {'a', 'b', 'c', 'ab', 'abc'}\n", + "alphabet = set('abc')" ] }, { @@ -257,8 +257,8 @@ "outputs": [], "source": [ "test_strings = [\n", - " \"aba\",\n", - " \"ab\",\n", + " 'aba',\n", + " 'ab',\n", " #'aa',\n", " #'acab',\n", " #'abc',\n", @@ -478,7 +478,7 @@ "source": [ "C = cc\n", "for x in test_strings:\n", - " display(HTML(f\"

{x}

\"))\n", + " display(HTML(f'

{x}

'))\n", " run_segmentation_test(C, x, contexts, verbose=2)" ] }, @@ -726,7 +726,7 @@ ], "source": [ "for x in test_strings:\n", - " display(HTML(f\"

{x}

\"))\n", + " display(HTML(f'

{x}

'))\n", " run_segmentation_test(canonical, x, contexts, verbose=1)" ] }, @@ -968,8 +968,8 @@ } ], "source": [ - "contexts = {\"a\", \"b\", \"c\", \"abc\"} # not prefix closed!\n", - "alphabet = set(\"abc\")\n", + "contexts = {'a', 'b', 'c', 'abc'} # not prefix closed!\n", + "alphabet = set('abc')\n", "cc = construction(contexts, alphabet, canonical=True)\n", "cc.graphviz(fmt_node=fmt)" ] @@ -1023,7 +1023,7 @@ ], "source": [ "for x in test_strings:\n", - " display(HTML(f\"

{x}

\"))\n", + " display(HTML(f'

{x}

'))\n", " run_segmentation_test(cc, x, contexts, verbose=1)" ] }, diff --git a/notes/Token-Alignment.ipynb b/notes/Token-Alignment.ipynb index c2e89ff4..6f9dc21c 100644 --- a/notes/Token-Alignment.ipynb +++ b/notes/Token-Alignment.ipynb @@ -77,8 +77,8 @@ ")\n", "\n", "A = p.cfg.V\n", - "B = {\"a\", \"aa\", \"aaa\", EOS}\n", - "ϕ = lambda b: \"\".join(b).strip(EOS)" + "B = {'a', 'aa', 'aaa', EOS}\n", + "ϕ = lambda b: ''.join(b).strip(EOS)" ] }, { @@ -252,7 +252,7 @@ "metadata": {}, "outputs": [], "source": [ - "T(\"aaa\", None).epsremove.trim" + "T('aaa', None).epsremove.trim" ] }, { @@ -262,7 +262,7 @@ "metadata": {}, "outputs": [], "source": [ - "T(\"aaa\", None).total_weight()" + "T('aaa', None).total_weight()" ] }, { @@ -315,7 +315,7 @@ "source": [ "display_table(\n", " [[p.cfg.language(100).project(ϕ), generation_tree(graft).D, PL]],\n", - " headings=[\"target\", \"grafting-heuristic\", \"composition\"],\n", + " headings=['target', 'grafting-heuristic', 'composition'],\n", ")" ] }, @@ -386,9 +386,7 @@ "metadata": {}, "outputs": [], "source": [ - "L_PB.assert_equal(\n", - " p.cfg.language(100).project(ϕ)\n", - ") # character-level distribution matches!" + "L_PB.assert_equal(p.cfg.language(100).project(ϕ)) # character-level distribution matches!" ] }, { @@ -400,7 +398,7 @@ "source": [ "display_table(\n", " [[p.cfg.language(100).project(ϕ), L_PB, generation_tree(graft).D, PL]],\n", - " headings=[\"target\", \"pfst\", \"grafting-heuristic\", \"composition\"],\n", + " headings=['target', 'pfst', 'grafting-heuristic', 'composition'],\n", ")" ] }, diff --git a/notes/grammar_processing_issues.ipynb b/notes/grammar_processing_issues.ipynb index 7f4afbed..ff6c5fcc 100644 --- a/notes/grammar_processing_issues.ipynb +++ b/notes/grammar_processing_issues.ipynb @@ -21,8 +21,8 @@ "import sys\n", "import getpass\n", "\n", - "if getpass.getuser() == \"benjamin.lebrun\":\n", - " sys.path.append(\"/home/mila/b/benjamin.lebrun/genparse\")" + "if getpass.getuser() == 'benjamin.lebrun':\n", + " sys.path.append('/home/mila/b/benjamin.lebrun/genparse')" ] }, { @@ -54,12 +54,12 @@ } ], "source": [ - "cfg = LarkStuff(open(\"../benchmark/grammars/sql_case_sensitive.lark\").read()).char_cfg(\n", - " 0.99, ignore=\"[ ]?\"\n", + "cfg = LarkStuff(open('../benchmark/grammars/sql_case_sensitive.lark').read()).char_cfg(\n", + " 0.99, ignore='[ ]?'\n", ")\n", "cfg = locally_normalize(cfg, tol=1e-40, maxiter=np.inf)\n", "\n", - "with timeit(\"boolean EarleyCFGLM preprocessing\"):\n", + "with timeit('boolean EarleyCFGLM preprocessing'):\n", " guide = EarleyBoolMaskCFGLM(cfg) # should take forever" ] }, @@ -126,7 +126,7 @@ "metadata": {}, "outputs": [], "source": [ - "guide.p_next(\"\")" + "guide.p_next('')" ] } ], diff --git a/notes/hfppl.ipynb b/notes/hfppl.ipynb index 9a19b024..17ebdd56 100644 --- a/notes/hfppl.ipynb +++ b/notes/hfppl.ipynb @@ -70,7 +70,7 @@ "# LLM = CachedCausalLM.from_pretrained(\"meta-llama/Llama-2-7b-hf\", auth_token=os.environ['HF_AUTH_TOKEN'])\n", "# LLM = CachedCausalLM.from_pretrained(\"lmsys/vicuna-7b-v1.5\")\n", "# LLM = CachedCausalLM.from_pretrained(\"mistralai/Mistral-7B-v0.1\")\n", - "MODEL_ID = \"codellama/CodeLlama-7b-Instruct-hf\"" + "MODEL_ID = 'codellama/CodeLlama-7b-Instruct-hf'" ] }, { @@ -176,7 +176,7 @@ "metadata": {}, "outputs": [], "source": [ - "context = prompt + \" SELECT state_color\"" + "context = prompt + ' SELECT state_color'" ] }, { @@ -201,7 +201,6 @@ "\n", "\n", "class GreedilyTokenizedLLM:\n", - "\n", " def __init__(self, llm, tokenizer):\n", " self.tokenizer = tokenizer\n", " self._model = llm\n", @@ -348,7 +347,7 @@ "WS: /[ \\n]/\n", "\n", "\"\"\"\n", - " ).char_cfg(0.99, ignore=\"[ ]?\")\n", + " ).char_cfg(0.99, ignore='[ ]?')\n", ")" ] }, @@ -583,7 +582,7 @@ } ], "source": [ - "print(f\"{sum(len(p.context.tokens) for p in particles)/took:.2f} tokens/sec\")" + "print(f'{sum(len(p.context.tokens) for p in particles)/took:.2f} tokens/sec')" ] }, { diff --git a/notes/hfppl_benleb.ipynb b/notes/hfppl_benleb.ipynb index c323b138..f513d280 100644 --- a/notes/hfppl_benleb.ipynb +++ b/notes/hfppl_benleb.ipynb @@ -30,10 +30,10 @@ "import os\n", "import getpass\n", "\n", - "if getpass.getuser() == \"benjamin.lebrun\":\n", - " sys.path.append(\"/home/mila/b/benjamin.lebrun/genparse\")\n", - " os.environ[\"HF_HOME\"] = os.path.join(os.environ[\"SCRATCH\"], \"hf_cache\")\n", - " print(\"HF cache set; path updated\")\n", + "if getpass.getuser() == 'benjamin.lebrun':\n", + " sys.path.append('/home/mila/b/benjamin.lebrun/genparse')\n", + " os.environ['HF_HOME'] = os.path.join(os.environ['SCRATCH'], 'hf_cache')\n", + " print('HF cache set; path updated')\n", "\n", "import nest_asyncio\n", "\n", @@ -102,7 +102,7 @@ "from genparse.lm import AsyncGreedilyTokenizedLLM\n", "\n", "genparse_llm = AsyncGreedilyTokenizedLLM.from_name(\n", - " \"codellama/CodeLlama-7b-Instruct-hf\", batch_size=40\n", + " 'codellama/CodeLlama-7b-Instruct-hf', batch_size=40\n", ")" ] }, @@ -116,9 +116,7 @@ "from genparse.cfglm import EarleyBoolMaskCFGLM\n", "from genparse.util import LarkStuff\n", "\n", - "guide = EarleyBoolMaskCFGLM(\n", - " LarkStuff(very_restricted_sql).char_cfg(0.99, ignore=\"[ ]?\")\n", - ")" + "guide = EarleyBoolMaskCFGLM(LarkStuff(very_restricted_sql).char_cfg(0.99, ignore='[ ]?'))" ] }, { @@ -377,7 +375,7 @@ "particle_approx = sampler.run_inference(\n", " prompt=prompt,\n", " proposal=proposal,\n", - " method=\"smc-standard\",\n", + " method='smc-standard',\n", " n_particles=5,\n", " max_tokens=50,\n", " verbosity=1,\n", diff --git a/notes/hfppl_for_tim_o.ipynb b/notes/hfppl_for_tim_o.ipynb index 3ae97628..c7d13faf 100644 --- a/notes/hfppl_for_tim_o.ipynb +++ b/notes/hfppl_for_tim_o.ipynb @@ -42,14 +42,12 @@ "import os\n", "import getpass\n", "\n", - "if (\n", - " getpass.getuser() == \"benjamin.lebrun\"\n", - "): # change to your user if you want to set these\n", + "if getpass.getuser() == 'benjamin.lebrun': # change to your user if you want to set these\n", " # @TIMO you may need to set this to your local genparse repo\n", - " sys.path.append(\"/home/mila/b/benjamin.lebrun/genparse\")\n", + " sys.path.append('/home/mila/b/benjamin.lebrun/genparse')\n", " # @TIMO also set your cache IF you run into disk quota issues\n", - " os.environ[\"HF_HOME\"] = os.path.join(os.environ[\"SCRATCH\"], \"hf_cache\")\n", - " print(\"HF cache set; path updated\")" + " os.environ['HF_HOME'] = os.path.join(os.environ['SCRATCH'], 'hf_cache')\n", + " print('HF cache set; path updated')" ] }, { @@ -86,7 +84,7 @@ } ], "source": [ - "MODEL_ID = \"codellama/CodeLlama-7b-Instruct-hf\"\n", + "MODEL_ID = 'codellama/CodeLlama-7b-Instruct-hf'\n", "hfppl_llm = CachedCausalLM.from_pretrained(MODEL_ID, load_in_8bit=True)\n", "tokenizer = AutoTokenizer.from_pretrained(\n", " MODEL_ID,\n", @@ -158,7 +156,7 @@ " WS: /[ ]/\n", "\n", " \"\"\"\n", - " ).char_cfg(0.99, ignore=\"[ ]?\")\n", + " ).char_cfg(0.99, ignore='[ ]?')\n", ")" ] }, @@ -196,12 +194,15 @@ " self.compare_time = compare_time\n", "\n", " async def step(self):\n", - " (token, llm_prob, guide_prob, proposal_prob) = (\n", - " await self.proposal.sample_next_token(\n", - " context=\"\".join(self.context),\n", - " prompt=self.prompt,\n", - " compare_time=self.compare_time,\n", - " )\n", + " (\n", + " token,\n", + " llm_prob,\n", + " guide_prob,\n", + " proposal_prob,\n", + " ) = await self.proposal.sample_next_token(\n", + " context=''.join(self.context),\n", + " prompt=self.prompt,\n", + " compare_time=self.compare_time,\n", " )\n", " self.context.append(token)\n", " self.weight += np.log(llm_prob) + np.log(guide_prob) - np.log(proposal_prob)\n", @@ -214,7 +215,7 @@ " return\n", "\n", " def immutable_properties(self):\n", - " return [\"llm\", \"prompt\", \"guide\", \"compare_token\"]\n", + " return ['llm', 'prompt', 'guide', 'compare_token']\n", "\n", " def __repr__(self):\n", " return f\"`{'' if not self.context else self.context[-1]}` : {''.join(self.context)} : {self.weight}\"" diff --git a/notes/hfppl_jac.ipynb b/notes/hfppl_jac.ipynb index ef3d6081..a1d72e9a 100644 --- a/notes/hfppl_jac.ipynb +++ b/notes/hfppl_jac.ipynb @@ -103,14 +103,14 @@ "\n", "if is_cuda_available():\n", " genparse_llm = AsyncGreedilyTokenizedLLM.from_name(\n", - " \"codellama/CodeLlama-7b-Instruct-hf\", batch_size=40\n", + " 'codellama/CodeLlama-7b-Instruct-hf', batch_size=40\n", " )\n", "\n", "else:\n", " import transformers\n", " from genparse.lm import LLM\n", "\n", - " MODEL_ID = \"gpt2\"\n", + " MODEL_ID = 'gpt2'\n", "\n", " MAX_TOKENS = 100\n", " BATCH_SIZE = 80\n", @@ -133,9 +133,7 @@ "from genparse.cfglm import EarleyBoolMaskCFGLM\n", "from genparse.util import LarkStuff\n", "\n", - "guide = EarleyBoolMaskCFGLM(\n", - " LarkStuff(very_restricted_sql).char_cfg(0.99, ignore=\"[ ]?\")\n", - ")\n", + "guide = EarleyBoolMaskCFGLM(LarkStuff(very_restricted_sql).char_cfg(0.99, ignore='[ ]?'))\n", "\n", "from genparse.steer import HFPPLSampler\n", "\n", @@ -162,7 +160,7 @@ "particle_approx, record = sampler.run_inference(\n", " prompt=prompt,\n", " proposal=proposal,\n", - " method=\"smc-standard\",\n", + " method='smc-standard',\n", " return_record=True, # use version of smc that keeps a record\n", " n_particles=12,\n", " max_tokens=60,\n", @@ -581,10 +579,10 @@ "READ_EXAMPLE = 0\n", "\n", "if WRITE_EXAMPLE:\n", - " with open(\"hfppl_jac_example_record.pickle\", \"wb\") as h:\n", + " with open('hfppl_jac_example_record.pickle', 'wb') as h:\n", " pickle.dump(record, h, protocol=pickle.HIGHEST_PROTOCOL)\n", "if READ_EXAMPLE:\n", - " with open(\"hfppl_jac_example_record.pickle\", \"rb\") as h:\n", + " with open('hfppl_jac_example_record.pickle', 'rb') as h:\n", " record = pickle.load(h)" ] }, @@ -634,23 +632,23 @@ "color": "rgb(255.0, 0.0, 0.0)", "opacity": 0.15, "size": [ - 14.142135623730951, + 14.142135623730953, 14.142135623730953, 14.457921873170436, 11.873882958685238, - 11.385552490469177, - 10.840011287523199, + 11.385552490469175, + 10.8400112875232, 8.105809481441536, 6.851187449489261, 9.827383887592362, 5.318365856516926, 14.561503537917536, 12.287227365169722, - 12.709088234399145, - 12.416948732939439, - 9.190073478275679, + 12.709088234399143, + 12.41694873293944, + 9.19007347827568, 0.35471287475231794, - 12.314263348816201 + 12.3142633488162 ] }, "mode": "markers+text", @@ -744,7 +742,7 @@ "color": "rgb(231.8181818181818, 11.363636363636363, 23.181818181818183)", "opacity": 0.15, "size": [ - 14.142135623730951, + 14.142135623730953, 14.142135623730953, 14.457921873170436, 16.823401969323964, @@ -758,14 +756,14 @@ 23.59244301718731, 3.182691823394595, 14.037910474841624, - 14.950528730813339, + 14.95052873081334, 15.48159936483578, 14.421468648915504, - 18.945608596247574, + 18.945608596247578, 15.803314668437338, 1.3912279252057005, 0.33226044433526886, - 12.314263348816201 + 12.3142633488162 ] }, "mode": "markers+text", @@ -886,20 +884,20 @@ "color": "rgb(208.63636363636363, 22.727272727272727, 46.36363636363637)", "opacity": 0.15, "size": [ - 14.142135623730951, + 14.142135623730953, 14.142135623730953, 14.913449281118025, 16.002365531418288, 15.182974673123722, 15.012323751039531, - 14.649651227979243, + 14.649651227979245, 15.024711875517289, 16.003574076322664, 14.037910474841624, - 14.950528730813339, + 14.95052873081334, 15.48159936483578, 14.421468648915504, - 18.945608596247574, + 18.945608596247578, 9.965274409008163, 15.803314668437338, 21.903367784151108, @@ -1062,7 +1060,7 @@ "color": "rgb(185.45454545454544, 34.09090909090909, 69.54545454545455)", "opacity": 0.15, "size": [ - 14.142135623730951, + 14.142135623730953, 14.142135623730953, 14.457921873170436, 16.823401969323964, @@ -1074,14 +1072,14 @@ 4.5267911018928455, 12.520238575555048, 14.037910474841624, - 14.950528730813339, + 14.95052873081334, 15.48159936483578, - 14.419778171926065, - 10.286253829351077, + 14.419778171926064, + 10.286253829351075, 15.803314668437338, 0.47058298637143886, 0.47058298637143886, - 12.314263348816201 + 12.3142633488162 ] }, "mode": "markers+text", @@ -1180,7 +1178,7 @@ "color": "rgb(162.27272727272725, 45.45454545454545, 92.72727272727273)", "opacity": 0.15, "size": [ - 14.142135623730951, + 14.142135623730953, 14.142135623730953, 6.250132732861252, 3.7720463352716895, @@ -1190,14 +1188,14 @@ 2.853313585048045, 2.878067976123044, 14.037910474841624, - 14.950528730813339, + 14.95052873081334, 15.48159936483578, - 14.419778171926065, - 10.286253829351077, + 14.419778171926064, + 10.286253829351075, 15.803314668437338, 0.47058298637143886, 0.47058298637143886, - 12.314263348816201 + 12.3142633488162 ] }, "mode": "markers+text", @@ -1293,27 +1291,27 @@ "color": "rgb(139.09090909090907, 56.81818181818181, 115.90909090909092)", "opacity": 0.15, "size": [ - 14.142135623730951, + 14.142135623730953, 14.142135623730953, 14.457921873170436, 11.873882958685238, - 11.385552490469177, + 11.385552490469175, 10.810517375557936, 11.3129664414692, 11.337861545239877, - 7.8391815325920025, + 7.839181532592003, 14.037910474841624, - 14.950528730813339, + 14.95052873081334, 15.48159936483578, - 14.419778171926065, - 10.286253829351073, + 14.419778171926064, + 10.286253829351072, 6.981390596950686, 6.981390596950686, 0.6384270822213649, 15.803314668437338, 0.3287918566041185, 0.3287918566041185, - 12.314263348816201 + 12.3142633488162 ] }, "mode": "markers+text", @@ -1420,20 +1418,20 @@ "color": "rgb(115.9090909090909, 68.18181818181819, 139.0909090909091)", "opacity": 0.15, "size": [ - 14.142135623730951, + 14.142135623730953, 14.142135623730953, 14.913449281118025, 16.002365531418288, 15.182974673123722, 15.012323751039531, - 14.649651227979243, + 14.649651227979245, 15.598038648562746, 15.785040732464894, 14.037910474841624, - 14.950528730813339, + 14.95052873081334, 15.48159936483578, - 14.419778171926065, - 10.286253829351073, + 14.419778171926064, + 10.286253829351072, 0.6384270822213649, 15.88558286313778, 8.725785645125885, @@ -1442,7 +1440,7 @@ 9.70660501864694, 9.70660501864694, 9.70660501864694, - 12.314263348816201 + 12.3142633488162 ] }, "mode": "markers+text", @@ -1548,7 +1546,7 @@ "color": "rgb(92.72727272727272, 79.54545454545455, 162.27272727272728)", "opacity": 0.15, "size": [ - 14.142135623730951, + 14.142135623730953, 14.142135623730953, 8.257423500494252, 0.32435677673491853, @@ -1557,13 +1555,13 @@ 0.00042230151833260696, 0.0004499470989768492, 0.0004538704330765998, - 14.471657855811605, + 14.471657855811603, 8.085033090031478, 8.738079047115395, - 10.948421124988391, + 10.948421124988393, 1.5528481130671423, 8.725785645125885, - 12.314263348816201 + 12.3142633488162 ] }, "mode": "markers+text", @@ -1651,12 +1649,12 @@ "color": "rgb(69.54545454545453, 90.9090909090909, 185.45454545454547)", "opacity": 0.15, "size": [ - 14.142135623730951, + 14.142135623730953, 14.142135623730953, 16.83717465239317, 20.099662725386494, - 21.443980782304674, - 18.876277927789282, + 21.44398078230467, + 18.87627792778928, 17.883132592668552, 17.590938425201635, 17.707023666499882, @@ -1669,7 +1667,7 @@ 15.88558286313778, 0.0477113839686589, 15.802042939131656, - 12.314263348816201 + 12.3142633488162 ] }, "mode": "markers+text", @@ -1768,7 +1766,7 @@ "color": "rgb(46.363636363636346, 102.27272727272727, 208.63636363636365)", "opacity": 0.15, "size": [ - 14.142135623730951, + 14.142135623730953, 14.142135623730953, 16.83717465239317, 1.1335788721113462, @@ -1788,7 +1786,7 @@ 15.802042939131656, 34.23492522202181, 1.3094098924826658, - 12.314263348816201 + 12.3142633488162 ] }, "mode": "markers+text", @@ -1893,27 +1891,27 @@ "color": "rgb(23.18181818181816, 113.63636363636363, 231.81818181818184)", "opacity": 0.15, "size": [ - 14.142135623730951, + 14.142135623730953, 14.142135623730953, 14.913449281118025, 17.60968072558521, - 14.306220277371315, - 14.646364274891125, - 14.679009478542559, + 14.306220277371317, + 14.646364274891123, + 14.67900947854256, 15.878930100964638, 15.968444595340683, 20.956889930483197, 0.004795697456429562, - 14.011882041391393, + 14.011882041391392, 14.922601814341911, 15.600288271375703, - 14.546011331905781, + 14.54601133190578, 10.753511485362427, 0.01727711852679148, 15.802042939131656, - 14.810845859606019, - 14.810845859606019, - 12.314263348816201 + 14.81084585960602, + 14.81084585960602, + 12.3142633488162 ] }, "mode": "markers+text", @@ -2017,26 +2015,26 @@ "color": "rgb(0.0, 125.0, 255.0)", "opacity": 0.15, "size": [ - 14.142135623730951, + 14.142135623730953, 14.142135623730953, 14.913449281118025, 17.60968072558521, - 18.071068363285406, + 18.07106836328541, 19.202786969683103, 20.038008516758698, 17.99283310333559, 17.705088506092654, 12.34104206849142, 20.27856636536855, - 14.011882041391393, + 14.011882041391392, 14.922601814341911, 7.906607755569021, 0.37212405200948695, 0.030751094037019283, - 15.803735407440861, - 14.810845859606019, + 15.80373540744086, + 14.81084585960602, 9.754253198735066, - 12.314263348816201 + 12.3142633488162 ] }, "mode": "markers+text", @@ -6804,18 +6802,18 @@ -0.16685074819284473, -0.16685074819284473, -0.16685074819284473, - -0.46702063795957116, - -0.46702063795957116, - -0.46702063795957116, - -0.46702063795957116, - -0.46702063795957116, - -0.46702063795957116, - -0.46702063795957116, - -0.46702063795957116, - -0.46702063795957116, - -0.46702063795957116, - -0.46702063795957116, - -0.46702063795957116, + -0.4670206379595712, + -0.4670206379595712, + -0.4670206379595712, + -0.4670206379595712, + -0.4670206379595712, + -0.4670206379595712, + -0.4670206379595712, + -0.4670206379595712, + -0.4670206379595712, + -0.4670206379595712, + -0.4670206379595712, + -0.4670206379595712, -0.5487402987194492, -0.5487402987194492, -0.5487402987194492, @@ -7318,18 +7316,18 @@ -0.16685074819284473, -0.16685074819284473, -0.16685074819284473, - -0.46702063795957116, - -0.46702063795957116, - -0.46702063795957116, - -0.46702063795957116, - -0.46702063795957116, - -0.46702063795957116, - -0.46702063795957116, - -0.46702063795957116, - -0.46702063795957116, - -0.46702063795957116, - -0.46702063795957116, - -0.46702063795957116, + -0.4670206379595712, + -0.4670206379595712, + -0.4670206379595712, + -0.4670206379595712, + -0.4670206379595712, + -0.4670206379595712, + -0.4670206379595712, + -0.4670206379595712, + -0.4670206379595712, + -0.4670206379595712, + -0.4670206379595712, + -0.4670206379595712, -0.5487402987194492, -0.5487402987194492, -0.5487402987194492, @@ -8405,13 +8403,13 @@ "import os\n", "\n", "\n", - "def write_images_scrollby(windowsize=10, outdir=\"figs\", height=800, width=1000):\n", + "def write_images_scrollby(windowsize=10, outdir='figs', height=800, width=1000):\n", " record.plotly2(height=height, width=width).write_image(\n", - " os.path.join(outdir, f\"EXAMPLE.png\")\n", + " os.path.join(outdir, f'EXAMPLE.png')\n", " )\n", - " for x_ in range(len(record[\"step\"])):\n", + " for x_ in range(len(record['step'])):\n", " record.plotly2(xrange=[0, x_], height=height, width=width).write_image(\n", - " os.path.join(outdir, f\"TMP{x_:02d}.png\")\n", + " os.path.join(outdir, f'TMP{x_:02d}.png')\n", " )\n", "\n", "\n", @@ -8483,7 +8481,7 @@ "color": "rgb(255.0, 0.0, 0.0)", "opacity": 0.15, "size": [ - 14.142135623730951, + 14.142135623730953, 14.142135623730953, 13.760187457708865, 14.025232200647846, @@ -8495,7 +8493,7 @@ 17.844293001325273, 18.670061600207696, 20.71745014148054, - 21.078602184484026, + 21.078602184484023, 18.44380438188891, 24.845980647785268, 29.521086659072292, @@ -8506,13 +8504,13 @@ 12.846940967600313, 0.34818164365681875, 0.34818164365681875, - 30.489249691855626, + 30.489249691855623, 0.2749380041770073, 0.2749380041770073, 0.2749380041770073, 0.2749380041770073, 14.145987065255758, - 14.158995761951541, + 14.15899576195154, 21.597664959206565, 21.597664959206565, 21.597664959206565 @@ -8652,7 +8650,7 @@ "color": "rgb(231.8181818181818, 11.363636363636363, 23.181818181818183)", "opacity": 0.15, "size": [ - 14.142135623730951, + 14.142135623730953, 14.142135623730953, 14.193731267149062, 14.680732327858076, @@ -8660,15 +8658,15 @@ 15.75612883288789, 15.428455754335374, 15.633213078982251, - 15.522114616511987, + 15.522114616511988, 16.881387097221744, 18.517338460472764, 13.815167021437944, - 14.738392715010217, + 14.738392715010216, 17.44838009523997, 10.803526202530191, 14.145987065255758, - 14.158995761951541 + 14.15899576195154 ] }, "mode": "markers+text", @@ -8760,7 +8758,7 @@ "color": "rgb(208.63636363636363, 22.727272727272727, 46.36363636363637)", "opacity": 0.15, "size": [ - 14.142135623730951, + 14.142135623730953, 14.142135623730953, 13.760187457708865, 14.025232200647846, @@ -8772,13 +8770,13 @@ 17.844293001325273, 18.670061600207696, 20.71745014148054, - 21.078602184484026, + 21.078602184484023, 18.44380438188891, 10.803526202530197, 12.846940967600313, 4.44016009078639, 14.145987065255758, - 14.158995761951541, + 14.15899576195154, 21.597664959206565 ] }, @@ -8884,23 +8882,23 @@ "color": "rgb(185.45454545454544, 34.09090909090909, 69.54545454545455)", "opacity": 0.15, "size": [ - 14.142135623730951, + 14.142135623730953, 14.142135623730953, 13.760187457708865, 14.025232200647846, 15.285762544554022, - 1.3091077604896781, + 1.309107760489678, 1.3893459153490075, 1.4814736823266503, 1.4201993902173953, 1.5298910258680254, 1.5377711811412484, - 1.6419978926288261, + 1.641997892628826, 1.7529894072396566, 1.9814612006972587, 10.803526202530191, 14.145987065255758, - 14.158995761951541, + 14.15899576195154, 21.597664959206565, 3.109991767313499, 3.109991767313499, @@ -9018,23 +9016,23 @@ "color": "rgb(162.27272727272725, 45.45454545454545, 92.72727272727273)", "opacity": 0.15, "size": [ - 14.142135623730951, + 14.142135623730953, 14.142135623730953, 15.636948613437644, - 12.683416057537427, + 12.683416057537428, 8.940367197023463, 8.29112162695904, 3.213082656992462, 0.02479088685576839, 0.0023517727922828606, - 7.674656973359717e-05, - 7.842988502138061e-05, - 8.515355953244925e-05, - 4.9729046579742855e-05, - 6.115419749076286e-08, + 0.00007674656973359717, + 0.00007842988502138061, + 0.00008515355953244925, + 0.000049729046579742855, + 6.115419749076286e-8, 10.803526202530191, 14.145987065255758, - 14.158995761951541, + 14.15899576195154, 3.109991767313499 ] }, @@ -9127,7 +9125,7 @@ "color": "rgb(139.09090909090907, 56.81818181818181, 115.90909090909092)", "opacity": 0.15, "size": [ - 14.142135623730951, + 14.142135623730953, 14.142135623730953, 13.760187457708865, 14.025232200647846, @@ -9233,7 +9231,7 @@ "color": "rgb(115.9090909090909, 68.18181818181819, 139.0909090909091)", "opacity": 0.15, "size": [ - 14.142135623730951, + 14.142135623730953, 14.142135623730953, 13.760187457708865, 14.025232200647846, @@ -9245,7 +9243,7 @@ 17.844293001325273, 18.670061600207696, 20.71745014148054, - 21.078602184484026, + 21.078602184484023, 18.44380438188891, 17.587349907111197, 10.803526202530191, @@ -9344,20 +9342,20 @@ "color": "rgb(92.72727272727272, 79.54545454545455, 162.27272727272728)", "opacity": 0.15, "size": [ - 14.142135623730951, + 14.142135623730953, 14.142135623730953, 14.193731267149062, 14.680732327858076, - 10.767303527780113, - 11.362749317687761, + 10.767303527780111, + 11.36274931768776, 11.08863577821977, 11.852253615725274, 11.548798088756843, 11.286983105524376, - 11.346569663258295, + 11.346569663258297, 12.58895119525174, 12.051466346828228, - 9.867192326354381, + 9.86719232635438, 10.803526202530191, 0.2749380041770073, 17.146513534967667, @@ -9463,15 +9461,15 @@ "color": "rgb(69.54545454545453, 90.9090909090909, 185.45454545454547)", "opacity": 0.15, "size": [ - 14.142135623730951, + 14.142135623730953, 14.142135623730953, 14.193731267149062, 13.340755499216302, 12.625070349598664, 12.693353402446863, 12.557499814426562, - 12.734033603701935, - 12.882704181997719, + 12.734033603701937, + 12.88270418199772, 6.901263631086172, 7.492526579923725, 7.5617941623819585, @@ -9605,7 +9603,7 @@ "color": "rgb(46.363636363636346, 102.27272727272727, 208.63636363636365)", "opacity": 0.15, "size": [ - 14.142135623730951, + 14.142135623730953, 14.142135623730953, 14.193731267149062, 14.680732327858076, @@ -9617,12 +9615,12 @@ 15.337821505252352, 16.017845068085414, 17.77414156136321, - 18.390048709064374, - 21.773526627488554, - 3.4148849558775534, + 18.390048709064377, + 21.77352662748855, + 3.414884955877554, 17.587349907111193, 2.5699505429676868, - 3.5948764828277963e-09, + 3.5948764828277963e-9, 2.1818119524102917, 17.587349907111193, 8.966869735242355, @@ -9765,7 +9763,7 @@ "color": "rgb(23.18181818181816, 113.63636363636363, 231.81818181818184)", "opacity": 0.15, "size": [ - 14.142135623730951, + 14.142135623730953, 14.142135623730953, 14.193731267149062, 14.680732327858076, @@ -9890,14 +9888,14 @@ "color": "rgb(0.0, 125.0, 255.0)", "opacity": 0.15, "size": [ - 14.142135623730951, + 14.142135623730953, 14.142135623730953, 14.193731267149062, 14.680732327858076, 14.87843464250492, 15.75612883288789, 15.428455754335374, - 15.622299546289641, + 15.62229954628964, 15.909286726897806, 16.897293160605635, 13.44411451756642, @@ -17079,12 +17077,12 @@ "WINDOWSIZE = 10\n", "\n", "d__ = d_\n", - "d__[\"context\"].apply(lambda r: \"\".join(r[:-1]))\n", - "d__[\"overflow_context\"] = d__.apply(\n", - " lambda r: \"\".join(r[\"context\"][: r[\"step\"] - (WINDOWSIZE + 1)]), axis=1\n", + "d__['context'].apply(lambda r: ''.join(r[:-1]))\n", + "d__['overflow_context'] = d__.apply(\n", + " lambda r: ''.join(r['context'][: r['step'] - (WINDOWSIZE + 1)]), axis=1\n", ").to_list()\n", - "d__[\"window_context\"] = d__.apply(\n", - " lambda r: \"\".join(r[\"context\"][r[\"step\"] - WINDOWSIZE : r[\"step\"]]), axis=1\n", + "d__['window_context'] = d__.apply(\n", + " lambda r: ''.join(r['context'][r['step'] - WINDOWSIZE : r['step']]), axis=1\n", ").to_list()\n", "d__" ] @@ -17269,7 +17267,7 @@ "size": [ 14.142135623730953, 15.636948613437644, - 12.683416057537427 + 12.683416057537428 ] }, "mode": "markers+text", @@ -19371,7 +19369,7 @@ ], [ " SELECT", - -12.044136238968767, + -12.044136238968768, false ], [ @@ -19386,7 +19384,7 @@ ], [ " SELECT age FROM data", - -12.890518994730765, + -12.890518994730764, false ], [ @@ -19457,13 +19455,13 @@ 1, 0.9999999999999996, 0.08345997096712078, - 3.144573361448564e-05, - 2.4956886495467274e-05, - 2.1462501915942082e-05, - 1.6268849815188344e-05, - 1.4757300352087234e-05, - 1.1698653730016005e-05, - 1.4753454812917209e-05, + 0.00003144573361448564, + 0.00002495688649546727, + 0.00002146250191594208, + 0.000016268849815188344, + 0.000014757300352087234, + 0.000011698653730016005, + 0.000014753454812917207, 0.0017675764341983872, 0.3297392910493373, 0.43085236169137026, @@ -19567,7 +19565,7 @@ ], [ " SELECT vote, zip", - -0.9403565687817279, + -0.940356568781728, false ], [ @@ -19577,7 +19575,7 @@ ], [ " SELECT vote, zipcode FROM", - -1.1918328268726919, + -1.191832826872692, false ], [ @@ -19722,7 +19720,7 @@ 1.165703252685702, 1.341546047879836, 1.658581020724092, - 1.8327405829377286, + 1.8327405829377288, 1.9151469672615904, 1.776976828408179, 1.4601356587730292, @@ -19734,7 +19732,7 @@ 0.000722149296949957, 0.005201221340267852, 1, - 3.6494517184788635, + 3.649451718478863, 0.0004472917528217814, 9.995974374224588, 0.0004472917528217814, @@ -19888,7 +19886,7 @@ ], [ " SELECT vote, vote, zip", - -1.6492045188518663, + -1.6492045188518665, false ], [ @@ -19996,16 +19994,16 @@ 1.0000000000000002, 1.165703252685702, 1.341546047879836, - 1.1054502310927365, + 1.1054502310927363, 0.9546201682781362, 1.0306565283252629, 1.124728662169629, - 0.9528559020218703, + 0.9528559020218704, 1.14477607748672, 1.717334448386174, 2.816856935026382, 0.7772670745842045, - 2.349754756014147e-09, + 2.349754756014147e-9, 0.89473665260685, 1.0842783427146367, 0.7606984589293397, @@ -20129,7 +20127,7 @@ ], [ " SELECT age FROM data", - -1.4672590691116039, + -1.467259069111604, false ], [ @@ -20239,7 +20237,7 @@ 1.0821068550913535, 1.0732693969922604, 1.052737328989911, - 0.9468694342361799, + 0.94686943423618, 0.3440935907206618, 0.5734037698910691, 1.0842783427146367, @@ -20355,7 +20353,7 @@ ], [ " SELECT vote FROM", - -1.8979466102495361, + -1.897946610249536, false ], [ @@ -20470,9 +20468,9 @@ 0.5642008988781192, 0.5268289341512757, 0.5341759231035361, - 0.49153354750180955, - 0.22901753434443822, - 1.0996645104191986e-06, + 0.4915335475018095, + 0.22901753434443825, + 1.0996645104191986e-6, 1.0842783427146367, 0.7606984589293397, 1.461192198434121, @@ -20690,7 +20688,7 @@ 1, 1.0000000000000002, 1.2403163373321546, - 1.4698766961409617, + 1.4698766961409615, 1.5713645707446335, 1.6673103705955654, 1.6201204504974391, @@ -20802,7 +20800,7 @@ ], [ " SELECT state", - -3.8397796375923234, + -3.839779637592323, false ], [ @@ -20847,7 +20845,7 @@ ], [ " SELECT age, gender FROM data ORDER BY age ASC ", - -0.9458901196953537, + -0.9458901196953536, false ], [ @@ -21038,7 +21036,7 @@ ], [ " SELECT age, gender FROM data ORDER BY age ASC ", - -1.6366235684784727, + -1.636623568478473, false ], [ @@ -21430,7 +21428,7 @@ ], [ " SELECT age, gender FROM data ORDER BY age ASC Context: {row['context_string']}
Weight: {row['weight']}
Resample?: {row['resample?']}\",\n", " axis=1,\n", " ),\n", - " textposition=\"top center\",\n", + " textposition='top center',\n", ")\n", "\n", "fig = go.Figure(data=[scatter])\n", "\n", "# Define frames\n", "frames = []\n", - "steps = sorted(recs[\"step\"].unique())\n", + "steps = sorted(recs['step'].unique())\n", "for i, step in enumerate(steps):\n", " # Filter data up to the current step\n", - " data_up_to_step = recs[recs[\"step\"] <= step]\n", + " data_up_to_step = recs[recs['step'] <= step]\n", "\n", " scatter = go.Scatter(\n", - " x=data_up_to_step[\"step\"],\n", - " y=data_up_to_step[\"particle\"],\n", - " mode=\"markers+text\",\n", + " x=data_up_to_step['step'],\n", + " y=data_up_to_step['particle'],\n", + " mode='markers+text',\n", " marker=dict(\n", - " size=data_up_to_step[\"prop_exp_weight\"],\n", - " color=data_up_to_step[\"resampled as\"].map(color_map),\n", + " size=data_up_to_step['prop_exp_weight'],\n", + " color=data_up_to_step['resampled as'].map(color_map),\n", " opacity=0.6,\n", " ),\n", - " text=data_up_to_step[\"token\"],\n", - " hoverinfo=\"text\",\n", + " text=data_up_to_step['token'],\n", + " hoverinfo='text',\n", " hovertext=data_up_to_step.apply(\n", " lambda row: f\"Token: {row['token']}
Context: {row['context_string']}
Weight: {row['weight']}
Resample?: {row['resample?']}\",\n", " axis=1,\n", " ),\n", - " textposition=\"top center\",\n", + " textposition='top center',\n", " )\n", "\n", " resampling_lines = []\n", - " for resampled_as in recs[\"resampled as\"].unique():\n", - " resampled_as_data = recs[recs[\"resampled as\"] == resampled_as]\n", + " for resampled_as in recs['resampled as'].unique():\n", + " resampled_as_data = recs[recs['resampled as'] == resampled_as]\n", "\n", " # Add resampling lines for the current step\n", - " for _, row in resampled_as_data[resampled_as_data[\"step\"] == step].iterrows():\n", + " for _, row in resampled_as_data[resampled_as_data['step'] == step].iterrows():\n", " resampling_lines.append(\n", " go.Scatter(\n", - " x=[row[\"step\"], row[\"step\"] + 1],\n", - " y=[row[\"resampled as\"], row[\"particle\"]],\n", - " mode=\"lines\",\n", + " x=[row['step'], row['step'] + 1],\n", + " y=[row['resampled as'], row['particle']],\n", + " mode='lines',\n", " line=dict(color=color_map[resampled_as]),\n", " opacity=0.3,\n", " name=resampled_as,\n", @@ -221311,13 +221311,13 @@ " )\n", "\n", " for s in steps[: i + 1]:\n", - " for _, row in resampled_as_data[resampled_as_data[\"step\"] == s].iterrows():\n", + " for _, row in resampled_as_data[resampled_as_data['step'] == s].iterrows():\n", " resampling_lines.append(\n", " go.Scatter(\n", - " x=[row[\"step\"], row[\"step\"] + 1],\n", - " y=[row[\"particle\"], row[\"particle\"]],\n", - " mode=\"lines\",\n", - " line=dict(color=\"gray\"),\n", + " x=[row['step'], row['step'] + 1],\n", + " y=[row['particle'], row['particle']],\n", + " mode='lines',\n", + " line=dict(color='gray'),\n", " opacity=0.2,\n", " name=resampled_as,\n", " showlegend=False,\n", @@ -221329,12 +221329,12 @@ " if step in resample_steps:\n", " vertical_lines.append(\n", " go.layout.Shape(\n", - " type=\"line\",\n", + " type='line',\n", " x0=step,\n", " x1=step,\n", " y0=0,\n", " y1=1,\n", - " line=dict(width=4, color=\"gray\"),\n", + " line=dict(width=4, color='gray'),\n", " opacity=0.15,\n", " )\n", " )\n", @@ -221351,74 +221351,74 @@ " width=1200,\n", " height=500,\n", " xaxis=dict(range=[-1, 36]),\n", - " yaxis=dict(type=\"category\"), # Ensure the y-axis is categorical\n", - " plot_bgcolor=\"#fff\",\n", + " yaxis=dict(type='category'), # Ensure the y-axis is categorical\n", + " plot_bgcolor='#fff',\n", " showlegend=False,\n", " updatemenus=[\n", " {\n", - " \"buttons\": [\n", + " 'buttons': [\n", " {\n", - " \"args\": [\n", + " 'args': [\n", " None,\n", " {\n", - " \"frame\": {\"duration\": 500, \"redraw\": True},\n", - " \"fromcurrent\": True,\n", + " 'frame': {'duration': 500, 'redraw': True},\n", + " 'fromcurrent': True,\n", " },\n", " ],\n", - " \"label\": \"Play\",\n", - " \"method\": \"animate\",\n", + " 'label': 'Play',\n", + " 'method': 'animate',\n", " },\n", " {\n", - " \"args\": [\n", + " 'args': [\n", " [None],\n", " {\n", - " \"frame\": {\"duration\": 0, \"redraw\": True},\n", - " \"mode\": \"immediate\",\n", - " \"transition\": {\"duration\": 0},\n", + " 'frame': {'duration': 0, 'redraw': True},\n", + " 'mode': 'immediate',\n", + " 'transition': {'duration': 0},\n", " },\n", " ],\n", - " \"label\": \"Pause\",\n", - " \"method\": \"animate\",\n", + " 'label': 'Pause',\n", + " 'method': 'animate',\n", " },\n", " ],\n", - " \"direction\": \"left\",\n", - " \"pad\": {\"r\": 10, \"t\": 87},\n", - " \"showactive\": False,\n", - " \"type\": \"buttons\",\n", - " \"x\": 0.1,\n", - " \"xanchor\": \"right\",\n", - " \"y\": 0,\n", - " \"yanchor\": \"top\",\n", + " 'direction': 'left',\n", + " 'pad': {'r': 10, 't': 87},\n", + " 'showactive': False,\n", + " 'type': 'buttons',\n", + " 'x': 0.1,\n", + " 'xanchor': 'right',\n", + " 'y': 0,\n", + " 'yanchor': 'top',\n", " }\n", " ],\n", " sliders=[\n", " {\n", - " \"steps\": [\n", + " 'steps': [\n", " {\n", - " \"args\": [\n", + " 'args': [\n", " [str(step)],\n", " {\n", - " \"frame\": {\"duration\": 300, \"redraw\": True},\n", - " \"mode\": \"immediate\",\n", - " \"transition\": {\"duration\": 300},\n", + " 'frame': {'duration': 300, 'redraw': True},\n", + " 'mode': 'immediate',\n", + " 'transition': {'duration': 300},\n", " },\n", " ],\n", - " \"label\": str(step),\n", - " \"method\": \"animate\",\n", + " 'label': str(step),\n", + " 'method': 'animate',\n", " }\n", " for step in steps\n", " ],\n", - " \"x\": 0.1,\n", - " \"xanchor\": \"left\",\n", - " \"y\": 0,\n", - " \"yanchor\": \"top\",\n", - " \"currentvalue\": {\n", - " \"font\": {\"size\": 20},\n", - " \"prefix\": \"Step:\",\n", - " \"visible\": True,\n", - " \"xanchor\": \"right\",\n", + " 'x': 0.1,\n", + " 'xanchor': 'left',\n", + " 'y': 0,\n", + " 'yanchor': 'top',\n", + " 'currentvalue': {\n", + " 'font': {'size': 20},\n", + " 'prefix': 'Step:',\n", + " 'visible': True,\n", + " 'xanchor': 'right',\n", " },\n", - " \"transition\": {\"duration\": 300, \"easing\": \"cubic-in-out\"},\n", + " 'transition': {'duration': 300, 'easing': 'cubic-in-out'},\n", " }\n", " ],\n", ")\n", diff --git a/notes/sql_debug.ipynb b/notes/sql_debug.ipynb index 647a7ec6..dcb0112b 100644 --- a/notes/sql_debug.ipynb +++ b/notes/sql_debug.ipynb @@ -21,8 +21,8 @@ "import sys\n", "import getpass\n", "\n", - "if getpass.getuser() == \"benjamin.lebrun\":\n", - " sys.path.append(\"/home/mila/b/benjamin.lebrun/genparse\")" + "if getpass.getuser() == 'benjamin.lebrun':\n", + " sys.path.append('/home/mila/b/benjamin.lebrun/genparse')" ] }, { @@ -45,7 +45,7 @@ "source": [ "from genparse.evaluation.dataset import Dataset\n", "\n", - "dataset = Dataset(\"spider\", \"validation\")" + "dataset = Dataset('spider', 'validation')" ] }, { @@ -55,9 +55,9 @@ "metadata": {}, "outputs": [], "source": [ - "cfg = LarkStuff(\n", - " open(\"../benchmark/grammars/sql_case_insensitive.lark\").read()\n", - ").char_cfg(0.99, ignore=\"[ ]?\")\n", + "cfg = LarkStuff(open('../benchmark/grammars/sql_case_insensitive.lark').read()).char_cfg(\n", + " 0.99, ignore='[ ]?'\n", + ")\n", "guide = EarleyBoolMaskCFGLM(cfg)" ] },