format ipynb
benlipkin committed Jun 15, 2024
1 parent 580d414 commit 491137e
Showing 15 changed files with 2,036 additions and 2,038 deletions.
16 changes: 8 additions & 8 deletions notes/A-Tale-of-Two-Transducers.ipynb
@@ -192,7 +192,7 @@
}
],
"source": [
"D = Delta(\"ab\")\n",
"D = Delta('ab')\n",
"D"
]
},
@@ -331,7 +331,7 @@
}
],
"source": [
"(D @ WFSA.from_string(\"ab\", Float)).trim"
"(D @ WFSA.from_string('ab', Float)).trim"
]
},
{
@@ -465,7 +465,7 @@
}
],
"source": [
"(cfg @ Delta1(\"ab\")).nullaryremove(binarize=False).trim()"
"(cfg @ Delta1('ab')).nullaryremove(binarize=False).trim()"
]
},
{
@@ -664,7 +664,7 @@
}
],
"source": [
"derivative(\"ab\", \"ab\").graphviz(fmt_node=lambda x: str(x))"
"derivative('ab', 'ab').graphviz(fmt_node=lambda x: str(x))"
]
},
{
@@ -688,7 +688,7 @@
}
],
"source": [
"(cfg @ derivative(\"a\", \"ab\") @ derivative(\"a\", \"ab\")).language(10)"
"(cfg @ derivative('a', 'ab') @ derivative('a', 'ab')).language(10)"
]
},
{
@@ -712,7 +712,7 @@
}
],
"source": [
"(cfg @ derivative(\"aa\", \"ab\")).language(10)"
"(cfg @ derivative('aa', 'ab')).language(10)"
]
},
{
@@ -748,7 +748,7 @@
}
],
"source": [
"cfg.derivative(\"a\").trim()"
"cfg.derivative('a').trim()"
]
},
{
@@ -778,7 +778,7 @@
}
],
"source": [
"(cfg @ derivative(\"a\", cfg.V)).nullaryremove(binarize=False).unaryremove().trim()"
"(cfg @ derivative('a', cfg.V)).nullaryremove(binarize=False).unaryremove().trim()"
]
},
{
20 changes: 10 additions & 10 deletions notes/Character-at-a-Time.ipynb
@@ -33,7 +33,7 @@
"metadata": {},
"outputs": [],
"source": [
"llm = GreedilyTokenizedLLM(\"gpt2\")"
"llm = GreedilyTokenizedLLM('gpt2')"
]
},
{
@@ -57,7 +57,7 @@
}
],
"source": [
"prompt = \"Hello my name is\"\n",
"prompt = 'Hello my name is'\n",
"pp = llm.p_next(prompt, 10).normalize()\n",
"pp"
]
@@ -174,7 +174,7 @@
}
],
"source": [
"print(repr(\"\".join(pcfg.sample())))"
"print(repr(''.join(pcfg.sample())))"
]
},
{
@@ -206,7 +206,7 @@
"tracer = TraceSWOR()\n",
"for _ in range(1):\n",
" with tracer:\n",
" print(\"----------------------------------\")\n",
" print('----------------------------------')\n",
" ys = token_trie_approx.sample(prompt, max_tokens=50, draw=tracer, verbosity=1)\n",
" print(ys)"
]
@@ -1486,7 +1486,7 @@
],
"source": [
"tracer.root.graphviz(\n",
" fmt_node=lambda x: f\"{x._mass:.3g}\", fmt_edge=lambda i, a, j: repr(a)\n",
" fmt_node=lambda x: f'{x._mass:.3g}', fmt_edge=lambda i, a, j: repr(a)\n",
")"
]
},
@@ -1524,7 +1524,7 @@
"ADJ: \"fruit\"\n",
"\n",
"\"\"\"\n",
" ).char_cfg(0.99, ignore=\"[ ]?\"),\n",
" ).char_cfg(0.99, ignore='[ ]?'),\n",
" tol=1e-100,\n",
" )\n",
")"
@@ -1548,7 +1548,7 @@
}
],
"source": [
"\"\".join(fruit.sample())"
"''.join(fruit.sample())"
]
},
{
@@ -1579,7 +1579,7 @@
}
],
"source": [
"fruit(\"fruit flies like a banana \" + EOS)"
"fruit('fruit flies like a banana ' + EOS)"
]
},
{
@@ -1599,12 +1599,12 @@
}
],
"source": [
"prompt = \"The following is a favorite sentence among linguists:\"\n",
"prompt = 'The following is a favorite sentence among linguists:'\n",
"token_trie_approx = TokenTrieApproximation(llm, fruit)\n",
"tracer = TraceSWOR()\n",
"for _ in range(1):\n",
" with tracer:\n",
" print(\"----------------------------------\")\n",
" print('----------------------------------')\n",
" ys = token_trie_approx.sample(prompt, max_tokens=50, draw=tracer)\n",
" print(ys)"
]
50 changes: 24 additions & 26 deletions notes/FST.ipynb
@@ -155,7 +155,7 @@
"metadata": {},
"outputs": [],
"source": [
"b2c = H.fst.prune_to_alphabet(None, foo.V | {\"\"}).renumber"
"b2c = H.fst.prune_to_alphabet(None, foo.V | {''}).renumber"
]
},
{
@@ -174,10 +174,10 @@
"outputs": [],
"source": [
"for x in foo.cnf.language(3):\n",
" display(HTML(\"<hr/>\"))\n",
" display(HTML('<hr/>'))\n",
" print(x)\n",
" bpe_x = b2c(None, x).epsremove.trim\n",
" print(\"total weight of BPE sequences (i.e., ambiguity):\", bpe_x.total_weight())\n",
" print('total weight of BPE sequences (i.e., ambiguity):', bpe_x.total_weight())\n",
" display(bpe_x)\n",
" print()"
]
@@ -253,7 +253,7 @@
"for x, w in foo.cnf.language(L + 2).items():\n",
" if len(x) > L:\n",
" continue\n",
" cc[\"\".join(x)] += w\n",
" cc[''.join(x)] += w\n",
"# cc"
]
},
@@ -426,7 +426,7 @@
"df = []\n",
"for x, w in sorted(normalize(lm2.p_next(context)).items(), key=lambda kv: -kv[1]):\n",
" df.append((x, (H.tokenizer.decode([x]) if x != EOS else EOS), w))\n",
"pd.DataFrame(df, columns=[\"token_id\", \"chars\", \"prob\"]).set_index(\"token_id\")"
"pd.DataFrame(df, columns=['token_id', 'chars', 'prob']).set_index('token_id')"
]
},
{
@@ -546,9 +546,9 @@
"source": [
"trace = TraceSWOR()\n",
"for _ in range(15):\n",
" print(\"mass=\", trace.root.mass)\n",
" print('mass=', trace.root.mass)\n",
" with trace:\n",
" print(\"\".join(lm.sample(draw=trace)))"
" print(''.join(lm.sample(draw=trace)))"
]
},
{
@@ -576,7 +576,7 @@
"metadata": {},
"outputs": [],
"source": [
"c2t = lark_stuff.transducer(ignore=\"\", decay=0.0125)\n",
"c2t = lark_stuff.transducer(ignore='', decay=0.0125)\n",
"len(c2t.states)"
]
},
@@ -595,7 +595,7 @@
"metadata": {},
"outputs": [],
"source": [
"x = \"SELECT * FROM data\""
"x = 'SELECT * FROM data'"
]
},
{
@@ -711,7 +711,7 @@
"metadata": {},
"outputs": [],
"source": [
"cfg_t(\"SELECT * FROM data </s>\")"
"cfg_t('SELECT * FROM data </s>')"
]
},
{
@@ -721,7 +721,7 @@
"metadata": {},
"outputs": [],
"source": [
"cfg_t(\"SELECT * FROM data </s>\")"
"cfg_t('SELECT * FROM data </s>')"
]
},
{
@@ -742,7 +742,7 @@
"outputs": [],
"source": [
"for _ in range(10):\n",
" print(\"\".join(lm.sample()))"
" print(''.join(lm.sample()))"
]
},
{
@@ -752,7 +752,7 @@
"metadata": {},
"outputs": [],
"source": [
"lm.p_next(\"SELECT * FROM \")"
"lm.p_next('SELECT * FROM ')"
]
},
{
@@ -800,7 +800,7 @@
"metadata": {},
"outputs": [],
"source": [
"x = \"SELECT * FROM data\"\n",
"x = 'SELECT * FROM data'\n",
"b = tokenizer.encode(x)\n",
"b"
]
@@ -822,7 +822,7 @@
"metadata": {},
"outputs": [],
"source": [
"with timeit(\"composition\"):\n",
"with timeit('composition'):\n",
" c = FST.from_string(tuple(b), Float) @ b2c\n",
"about(c)"
]
@@ -879,7 +879,7 @@
"metadata": {},
"outputs": [],
"source": [
"x = x = \"SELECT * FROM data\""
"x = x = 'SELECT * FROM data'"
]
},
{
@@ -889,9 +889,9 @@
"metadata": {},
"outputs": [],
"source": [
"with timeit(\"composition\"):\n",
"with timeit('composition'):\n",
" bs = b2c @ FST.from_string(x, Float)\n",
"with timeit(\"trim\"):\n",
"with timeit('trim'):\n",
" bs.trim\n",
"about(bs)"
]
@@ -1025,7 +1025,7 @@
"metadata": {},
"outputs": [],
"source": [
"print(\"\".join(lm.sample()))"
"print(''.join(lm.sample()))"
]
},
{
@@ -1053,9 +1053,7 @@
"metadata": {},
"outputs": [],
"source": [
"bpe_lm = CharAlignedCFGLM(\n",
" lm=lm, words={x for _, x in H.pairs}, eos=H.tokenizer.eos_token\n",
")"
"bpe_lm = CharAlignedCFGLM(lm=lm, words={x for _, x in H.pairs}, eos=H.tokenizer.eos_token)"
]
},
{
Expand All @@ -1065,7 +1063,7 @@
"metadata": {},
"outputs": [],
"source": [
"lm.p_next(\"\")"
"lm.p_next('')"
]
},
{
Expand All @@ -1083,7 +1081,7 @@
"metadata": {},
"outputs": [],
"source": [
"bpe_lm.p_next(\"\")"
"bpe_lm.p_next('')"
]
},
{
Expand All @@ -1093,7 +1091,7 @@
"metadata": {},
"outputs": [],
"source": [
"lm.p_next(\"SELECT \")"
"lm.p_next('SELECT ')"
]
},
{
Expand All @@ -1103,7 +1101,7 @@
"metadata": {},
"outputs": [],
"source": [
"bpe_lm.p_next(\"SELECT \")"
"bpe_lm.p_next('SELECT ')"
]
},
{
10 changes: 5 additions & 5 deletions notes/Inference-Playground.ipynb
@@ -142,7 +142,7 @@
],
"source": [
"ref = BruteForceGlobalProductOfExperts(lm1, lm2, MAX_LENGTH)\n",
"ref.target.project(\"\".join)"
"ref.target.project(''.join)"
]
},
{
@@ -165,7 +165,7 @@
" lm2,\n",
" MAX_LENGTH=MAX_LENGTH,\n",
" n_particles=N_PARTICLES,\n",
" METHOD=\"is\",\n",
" METHOD='is',\n",
" # METHOD = 'smc',\n",
")"
]
@@ -258,8 +258,8 @@
}
],
"source": [
"w.project(\"\".join).trim().compare(ref.target.project(\"\".join).trim()).sort_values(\n",
" \"key\", ascending=False\n",
"w.project(''.join).trim().compare(ref.target.project(''.join).trim()).sort_values(\n",
" 'key', ascending=False\n",
")"
]
},
@@ -328,7 +328,7 @@
"# truncate the reference distribution to the support set of the sample;\n",
"# renamed the keys to handle the minor discrepancy in the EOS symbol\n",
"tmp = ref.target.filter(lambda k: k[:-1] in R).normalize().sort()\n",
"tmp.project(lambda k: k[:-1]).compare(R).sort_values(\"key\")"
"tmp.project(lambda k: k[:-1]).compare(R).sort_values('key')"
]
},
{
