Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Sep 2, 2024
1 parent 3b1dbe4 commit d7e9a9b
Show file tree
Hide file tree
Showing 81 changed files with 12,125 additions and 9,644 deletions.
26 changes: 13 additions & 13 deletions scripts/_mkl/notebooks/00a - Types.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"metadata": {},
"outputs": [],
"source": [
"#|default_exp types"
"# |default_exp types"
]
},
{
Expand All @@ -15,7 +15,7 @@
"metadata": {},
"outputs": [],
"source": [
"#|export\n",
"# |export\n",
"from typing import Any, NamedTuple\n",
"import numpy as np\n",
"import jax\n",
Expand All @@ -29,18 +29,18 @@
"Int = Array\n",
"FaceIndex = int\n",
"FaceIndices = Array\n",
"ArrayN = Array\n",
"Array3 = Array\n",
"Array2 = Array\n",
"ArrayNx2 = Array\n",
"ArrayNx3 = Array\n",
"Matrix = jaxlib.xla_extension.ArrayImpl\n",
"PrecisionMatrix = Matrix\n",
"ArrayN = Array\n",
"Array3 = Array\n",
"Array2 = Array\n",
"ArrayNx2 = Array\n",
"ArrayNx3 = Array\n",
"Matrix = jaxlib.xla_extension.ArrayImpl\n",
"PrecisionMatrix = Matrix\n",
"CovarianceMatrix = Matrix\n",
"CholeskyMatrix = Matrix\n",
"SquareMatrix = Matrix\n",
"Vector = Array\n",
"Direction = Vector\n",
"CholeskyMatrix = Matrix\n",
"SquareMatrix = Matrix\n",
"Vector = Array\n",
"Direction = Vector\n",
"BaseVector = Vector"
]
},
Expand Down
121 changes: 70 additions & 51 deletions scripts/_mkl/notebooks/00b - Utils.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"metadata": {},
"outputs": [],
"source": [
"#|default_exp utils"
"# |default_exp utils"
]
},
{
Expand All @@ -22,9 +22,9 @@
"metadata": {},
"outputs": [],
"source": [
"#|export\n",
"# |export\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib.collections import LineCollection\n",
"from matplotlib.collections import LineCollection\n",
"import numpy as np\n",
"import jax\n",
"import jax.numpy as jnp\n",
Expand All @@ -44,8 +44,8 @@
"metadata": {},
"outputs": [],
"source": [
"#|export\n",
"key = jax.random.PRNGKey(0)\n",
"# |export\n",
"key = jax.random.PRNGKey(0)\n",
"logsumexp = jax.scipy.special.logsumexp"
]
},
Expand All @@ -55,18 +55,21 @@
"metadata": {},
"outputs": [],
"source": [
"#|export\n",
"# |export\n",
"def keysplit(key, *ns):\n",
" if len(ns) == 0: \n",
" if len(ns) == 0:\n",
" return jax.random.split(key, 1)[0]\n",
" elif len(ns) == 1:\n",
" n, = ns\n",
" if n == 1: return keysplit(key)\n",
" else: return jax.random.split(key, ns[0])\n",
" (n,) = ns\n",
" if n == 1:\n",
" return keysplit(key)\n",
" else:\n",
" return jax.random.split(key, ns[0])\n",
" else:\n",
" keys = []\n",
" for n in ns: keys.append(keysplit(key, n))\n",
" return keys\n"
" for n in ns:\n",
" keys.append(keysplit(key, n))\n",
" return keys"
]
},
{
Expand Down Expand Up @@ -122,13 +125,15 @@
"metadata": {},
"outputs": [],
"source": [
"#|export\n",
"# |export\n",
"def bounding_box(arr, pad=0):\n",
" \"\"\"Takes a euclidean-like arr (`arr.shape[-1] == 2`) and returns its bounding box.\"\"\"\n",
" return jnp.array([\n",
" [jnp.min(arr[...,0])-pad, jnp.min(arr[...,1])-pad],\n",
" [jnp.max(arr[...,0])+pad, jnp.max(arr[...,1])+pad]\n",
" ])"
" return jnp.array(\n",
" [\n",
" [jnp.min(arr[..., 0]) - pad, jnp.min(arr[..., 1]) - pad],\n",
" [jnp.max(arr[..., 0]) + pad, jnp.max(arr[..., 1]) + pad],\n",
" ]\n",
" )"
]
},
{
Expand All @@ -137,24 +142,27 @@
"metadata": {},
"outputs": [],
"source": [
"#|export\n",
"# |export\n",
"def argmax_axes(a, axes=None):\n",
" \"\"\"Argmax along specified axes\"\"\"\n",
" if axes is None: return jnp.argmax(a)\n",
" \n",
" n = len(axes) \n",
" axes_ = set(range(a.ndim))\n",
" if axes is None:\n",
" return jnp.argmax(a)\n",
"\n",
" n = len(axes)\n",
" axes_ = set(range(a.ndim))\n",
" axes_0 = axes\n",
" axes_1 = sorted(axes_ - set(axes_0)) \n",
" axes_ = axes_0 + axes_1\n",
" axes_1 = sorted(axes_ - set(axes_0))\n",
" axes_ = axes_0 + axes_1\n",
"\n",
" b = jnp.transpose(a, axes=axes_)\n",
" c = b.reshape(np.prod(b.shape[:n]), -1)\n",
"\n",
" I = jnp.argmax(c, axis=0)\n",
" I = jnp.array([jnp.unravel_index(i, b.shape[:n]) for i in I]).reshape(b.shape[n:] + (n,))\n",
" I = jnp.array([jnp.unravel_index(i, b.shape[:n]) for i in I]).reshape(\n",
" b.shape[n:] + (n,)\n",
" )\n",
"\n",
" return I"
" return I"
]
},
{
Expand All @@ -177,7 +185,7 @@
"test_shape = (3, 99, 5, 9)\n",
"a = jnp.arange(np.prod(test_shape)).reshape(test_shape)\n",
"\n",
"I = argmax_axes(a, axes=[0,1])\n",
"I = argmax_axes(a, axes=[0, 1])\n",
"I.shape"
]
},
Expand All @@ -194,9 +202,13 @@
"metadata": {},
"outputs": [],
"source": [
"#|export\n",
"def cam_to_screen(x): return jnp.array([x[0]/x[2], x[1]/x[2], jnp.linalg.norm(x)])\n",
"def screen_to_cam(y): return y[2]*jnp.array([y[0], y[1], 1.0])"
"# |export\n",
"def cam_to_screen(x):\n",
" return jnp.array([x[0] / x[2], x[1] / x[2], jnp.linalg.norm(x)])\n",
"\n",
"\n",
"def screen_to_cam(y):\n",
" return y[2] * jnp.array([y[0], y[1], 1.0])"
]
},
{
Expand All @@ -205,24 +217,26 @@
"metadata": {},
"outputs": [],
"source": [
"#|export\n",
"def rot2d(hd): return jnp.array([\n",
" [jnp.cos(hd), -jnp.sin(hd)], \n",
" [jnp.sin(hd), jnp.cos(hd)]\n",
" ]);\n",
"# |export\n",
"def rot2d(hd):\n",
" return jnp.array([[jnp.cos(hd), -jnp.sin(hd)], [jnp.sin(hd), jnp.cos(hd)]])\n",
"\n",
"\n",
"def pack_2dpose(x,hd): \n",
" return jnp.concatenate([x,jnp.array([hd])])\n",
"def pack_2dpose(x, hd):\n",
" return jnp.concatenate([x, jnp.array([hd])])\n",
"\n",
"def apply_2dpose(p, ys): \n",
" return ys@rot2d(p[2] - jnp.pi/2).T + p[:2]\n",
"\n",
"def unit_vec(hd): \n",
"def apply_2dpose(p, ys):\n",
" return ys @ rot2d(p[2] - jnp.pi / 2).T + p[:2]\n",
"\n",
"\n",
"def unit_vec(hd):\n",
" return jnp.array([jnp.cos(hd), jnp.sin(hd)])\n",
"\n",
"\n",
"def adjust_angle(hd):\n",
" \"\"\"Adjusts angle to lie in the interval [-pi,pi).\"\"\"\n",
" return (hd + jnp.pi)%(2*jnp.pi) - jnp.pi"
" return (hd + jnp.pi) % (2 * jnp.pi) - jnp.pi"
]
},
{
Expand All @@ -238,12 +252,12 @@
"metadata": {},
"outputs": [],
"source": [
"#|export\n",
"# |export\n",
"from genjax.incremental import UnknownChange, NoChange, Diff\n",
"\n",
"\n",
"def argdiffs(args, other=None):\n",
" return tuple(map(lambda v: Diff(v, UnknownChange), args))\n"
" return tuple(map(lambda v: Diff(v, UnknownChange), args))"
]
},
{
Expand All @@ -252,18 +266,18 @@
"metadata": {},
"outputs": [],
"source": [
"#|export\n",
"# |export\n",
"from builtins import property as _property, tuple as _tuple\n",
"from typing import Any\n",
"\n",
"\n",
"class Args(tuple):\n",
" def __new__(cls, *args, **kwargs):\n",
" return _tuple.__new__(cls, list(args) + list(kwargs.values()))\n",
" \n",
"\n",
" def __init__(self, *args, **kwargs):\n",
" self._d = dict()\n",
" for k,v in kwargs.items():\n",
" for k, v in kwargs.items():\n",
" self._d[k] = v\n",
" setattr(self, k, v)\n",
"\n",
Expand Down Expand Up @@ -297,30 +311,35 @@
"metadata": {},
"outputs": [],
"source": [
"#|export\n",
"# \n",
"# |export\n",
"#\n",
"# Monkey patching `sample` for `BuiltinGenerativeFunction`\n",
"# \n",
"#\n",
"cls = genjax._src.generative_functions.static.static_gen_fn.StaticGenerativeFunction\n",
"\n",
"\n",
"def genjax_sample(self, key, *args, **kwargs):\n",
" tr = self.simulate(key, args)\n",
" return tr.get_retval()\n",
"\n",
"\n",
"setattr(cls, \"sample\", genjax_sample)\n",
"\n",
"\n",
"# \n",
"#\n",
"# Monkey patching `sample` for `DeferredGenerativeFunctionCall`\n",
"# \n",
"#\n",
"cls = genjax._src.generative_functions.supports_callees.SugaredGenerativeFunctionCall\n",
"\n",
"\n",
"def deff_gen_func_call(self, key, **kwargs):\n",
" return self.gen_fn.sample(key, *self.args, **kwargs)\n",
"\n",
"\n",
"def deff_gen_func_logpdf(self, x, **kwargs):\n",
" return self.gen_fn.logpdf(x, *self.args, **kwargs)\n",
"\n",
"\n",
"setattr(cls, \"__call__\", deff_gen_func_call)\n",
"setattr(cls, \"sample\", deff_gen_func_call)\n",
"setattr(cls, \"logpdf\", deff_gen_func_logpdf)"
Expand Down
Loading

0 comments on commit d7e9a9b

Please sign in to comment.